Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-26 06:39:57 +00:00)

Compare commits: python-v0.… → python-v0.… (59 commits)
Commit SHAs in this range (the author, message, and date columns were empty in the captured table):

072adc41aa, c6f25ef1f0, 2f0c5baea2, a63dd66d41, d6b3ccb37b, c4f99e82e5, 979a2d3d9d, 7ac5f74c80, ecdee4d2b1, f391ed828a, a99a450f2b, 6fa1f37506, 544382df5e, 784f00ef6d, 96d7446f70, 99ea78fb55, 8eef4cdc28, 0f102f02c3, a33a0670f6, 14c9ff46d1, 1865f7decf, a608621476, 00514999ff, b3b597fef6, bf17144591, 09e110525f, 40f0dbb64d, 3b19e96ae7, 78a17ad54c, a8e6b491e2, cea541ca46, 873ffc1042, 83273ad997, d18d63c69d, c3e865e8d0, a7755cb313, 3490f3456f, 0a1d0693e1, fd330b4b4b, d4e9fc08e0, 3626f2f5e1, e64712cfa5, 3e3118f85c, 592598a333, 5ad21341c9, 6e08caa091, 7e259d8b0f, e84f747464, 998cd43fe6, 4bc7eebe61, 2e3b34e79b, e7574698eb, 801a9e5f6f, 4e5fbe6c99, 1a449fa49e, 6bf742c759, ef3093bc23, 16851389ea, c269524b2f
````diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.15.1-beta.3"
+current_version = "0.16.1-beta.3"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
````
.github/workflows/rust.yml (vendored, 7 changed lines)

````diff
@@ -61,7 +61,12 @@ jobs:
       CXX: clang++
     steps:
       - uses: actions/checkout@v4
-      # Remote cargo.lock to force a fresh build
+      # Building without a lock file often requires the latest Rust version since downstream
+      # dependencies may have updated their minimum Rust version.
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: "stable"
+      # Remove cargo.lock to force a fresh build
       - name: Remove Cargo.lock
         run: rm -f Cargo.lock
       - uses: rui314/setup-mold@v1
````
````diff
@@ -7,7 +7,7 @@ repos:
       - id: trailing-whitespace
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.2.2
+    rev: v0.8.4
     hooks:
       - id: ruff
   - repo: local
````
Cargo.lock (generated, 884 changed lines): diff suppressed because it is too large.
Cargo.toml (39 changed lines)

````diff
@@ -21,16 +21,14 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"
 
 [workspace.dependencies]
-lance = { "version" = "=0.23.0", "features" = [
-    "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-io = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-index = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-linalg = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-table = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-testing = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-datafusion = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
-lance-encoding = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.5" }
+lance = { "version" = "=0.23.2", "features" = ["dynamodb"] }
+lance-io = { version = "=0.23.2" }
+lance-index = { version = "=0.23.2" }
+lance-linalg = { version = "=0.23.2" }
+lance-table = { version = "=0.23.2" }
+lance-testing = { version = "=0.23.2" }
+lance-datafusion = { version = "=0.23.2" }
+lance-encoding = { version = "=0.23.2" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"
@@ -41,24 +39,33 @@ arrow-schema = "53.2"
 arrow-arith = "53.2"
 arrow-cast = "53.2"
 async-trait = "0"
-chrono = "0.4.35"
-datafusion-common = "44.0"
+datafusion = { version = "44.0", default-features = false }
+datafusion-catalog = "44.0"
+datafusion-common = { version = "44.0", default-features = false }
+datafusion-execution = "44.0"
+datafusion-expr = "44.0"
 datafusion-physical-plan = "44.0"
-env_logger = "0.10"
+env_logger = "0.11"
 half = { "version" = "=2.4.1", default-features = false, features = [
     "num-traits",
 ] }
 futures = "0"
 log = "0.4"
-moka = { version = "0.11", features = ["future"] }
-object_store = "0.10.2"
+moka = { version = "0.12", features = ["future"] }
+object_store = "0.11.0"
 pin-project = "1.0.7"
-snafu = "0.7.4"
+snafu = "0.8"
 url = "2"
 num-traits = "0.2"
 rand = "0.8"
 regex = "1.10"
 lazy_static = "1"
 
+# Temporary pins to work around downstream issues
+# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
+chrono = "=0.4.39"
+# https://github.com/RustCrypto/formats/issues/1684
+base64ct = "=1.6.0"
+
 # Workaround for: https://github.com/eira-fransham/crunchy/issues/13
 crunchy = "=0.2.2"
````
````diff
@@ -4,6 +4,9 @@ repo_url: https://github.com/lancedb/lancedb
 edit_uri: https://github.com/lancedb/lancedb/tree/main/docs/src
 repo_name: lancedb/lancedb
 docs_dir: src
+watch:
+  - src
+  - ../python/python
 
 theme:
   name: "material"
@@ -63,6 +66,7 @@ plugins:
       - https://arrow.apache.org/docs/objects.inv
       - https://pandas.pydata.org/docs/objects.inv
       - https://lancedb.github.io/lance/objects.inv
+      - https://docs.pydantic.dev/latest/objects.inv
   - mkdocs-jupyter
   - render_swagger:
       allow_arbitrary_locations: true
@@ -105,8 +109,8 @@ nav:
   - 📚 Concepts:
       - Vector search: concepts/vector_search.md
       - Indexing:
          - IVFPQ: concepts/index_ivfpq.md
          - HNSW: concepts/index_hnsw.md
      - Storage: concepts/storage.md
      - Data management: concepts/data_management.md
  - 🔨 Guides:
@@ -130,8 +134,8 @@ nav:
      - Adaptive RAG: rag/adaptive_rag.md
      - SFR RAG: rag/sfr_rag.md
      - Advanced Techniques:
          - HyDE: rag/advanced_techniques/hyde.md
          - FLARE: rag/advanced_techniques/flare.md
      - Reranking:
          - Quickstart: reranking/index.md
          - Cohere Reranker: reranking/cohere.md
@@ -146,7 +150,7 @@ nav:
          - Building Custom Rerankers: reranking/custom_reranker.md
          - Example: notebooks/lancedb_reranking.ipynb
      - Filtering: sql.md
      - Versioning & Reproducibility:
          - sync API: notebooks/reproducibility.ipynb
          - async API: notebooks/reproducibility_async.ipynb
      - Configuring Storage: guides/storage.md
@@ -178,6 +182,7 @@ nav:
          - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md
          - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md
      - User-defined embedding functions: embeddings/custom_embedding_function.md
+      - Variables and secrets: embeddings/variables_and_secrets.md
      - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
      - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
  - 🔌 Integrations:
@@ -240,8 +245,8 @@ nav:
  - Concepts:
      - Vector search: concepts/vector_search.md
      - Indexing:
          - IVFPQ: concepts/index_ivfpq.md
          - HNSW: concepts/index_hnsw.md
      - Storage: concepts/storage.md
      - Data management: concepts/data_management.md
  - Guides:
@@ -265,8 +270,8 @@ nav:
      - Adaptive RAG: rag/adaptive_rag.md
      - SFR RAG: rag/sfr_rag.md
      - Advanced Techniques:
          - HyDE: rag/advanced_techniques/hyde.md
          - FLARE: rag/advanced_techniques/flare.md
      - Reranking:
          - Quickstart: reranking/index.md
          - Cohere Reranker: reranking/cohere.md
@@ -280,7 +285,7 @@ nav:
          - Building Custom Rerankers: reranking/custom_reranker.md
          - Example: notebooks/lancedb_reranking.ipynb
      - Filtering: sql.md
      - Versioning & Reproducibility:
          - sync API: notebooks/reproducibility.ipynb
          - async API: notebooks/reproducibility_async.ipynb
      - Configuring Storage: guides/storage.md
@@ -311,6 +316,7 @@ nav:
          - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md
          - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md
      - User-defined embedding functions: embeddings/custom_embedding_function.md
+      - Variables and secrets: embeddings/variables_and_secrets.md
      - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
      - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
  - Integrations:
@@ -349,8 +355,8 @@ nav:
  - 🦀 Rust:
      - Overview: examples/examples_rust.md
  - Studies:
      - studies/overview.md
      - ↗Improve retrievers with hybrid search and reranking: https://blog.lancedb.com/hybrid-search-and-reranking-report/
  - API reference:
      - Overview: api_reference.md
      - Python: python/python.md
````
````diff
@@ -38,6 +38,13 @@ components:
       required: true
       schema:
         type: string
+    index_name:
+      name: index_name
+      in: path
+      description: name of the index
+      required: true
+      schema:
+        type: string
   responses:
     invalid_request:
       description: Invalid request
@@ -485,3 +492,22 @@ paths:
           $ref: "#/components/responses/unauthorized"
         "404":
           $ref: "#/components/responses/not_found"
+  /v1/table/{name}/index/{index_name}/drop/:
+    post:
+      description: Drop an index from the table
+      tags:
+        - Tables
+      summary: Drop an index from the table
+      operationId: dropIndex
+      parameters:
+        - $ref: "#/components/parameters/table_name"
+        - $ref: "#/components/parameters/index_name"
+      responses:
+        "200":
+          description: Index successfully dropped
+        "400":
+          $ref: "#/components/responses/invalid_request"
+        "401":
+          $ref: "#/components/responses/unauthorized"
+        "404":
+          $ref: "#/components/responses/not_found"
````
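The new `dropIndex` operation above is a plain POST with no request body, so it can also be exercised directly over HTTP. A minimal TypeScript sketch; the `baseUrl` value and any auth headers are deployment-specific assumptions, not part of the spec:

```ts
// Minimal sketch of calling the dropIndex endpoint defined above.
// `baseUrl` and authentication handling are assumptions; adjust for your deployment.
async function dropIndexViaRest(
  baseUrl: string,
  tableName: string,
  indexName: string,
): Promise<void> {
  const url =
    `${baseUrl}/v1/table/${encodeURIComponent(tableName)}` +
    `/index/${encodeURIComponent(indexName)}/drop/`;
  const res = await fetch(url, { method: "POST" });
  if (res.status !== 200) {
    // 400/401/404 map to the error responses declared in the spec
    throw new Error(`dropIndex failed: ${res.status} ${res.statusText}`);
  }
}
```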
````diff
@@ -3,6 +3,7 @@ import * as vectordb from "vectordb";
 // --8<-- [end:import]
 
 (async () => {
+  console.log("ann_indexes.ts: start");
   // --8<-- [start:ingest]
   const db = await vectordb.connect("data/sample-lancedb");
 
@@ -49,5 +50,5 @@ import * as vectordb from "vectordb";
     .execute();
   // --8<-- [end:search3]
 
-  console.log("Ann indexes: done");
+  console.log("ann_indexes.ts: done");
 })();
````
````diff
@@ -107,7 +107,6 @@ const example = async () => {
   // --8<-- [start:search]
   const query = await tbl.search([100, 100]).limit(2).execute();
   // --8<-- [end:search]
-  console.log(query);
 
   // --8<-- [start:delete]
   await tbl.delete('item = "fizz"');
@@ -119,8 +118,9 @@ const example = async () => {
 };
 
 async function main() {
+  console.log("basic_legacy.ts: start");
   await example();
-  console.log("Basic example: done");
+  console.log("basic_legacy.ts: done");
 }
 
 main();
````
````diff
@@ -55,6 +55,14 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
 
 This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings.
 
+!!! danger "Use sensitive keys to prevent leaking secrets"
+
+    To prevent leaking secrets, such as API keys, you should add any sensitive
+    parameters of an embedding function to the output of the
+    [sensitive_keys()][lancedb.embeddings.base.EmbeddingFunction.sensitive_keys] /
+    [getSensitiveKeys()](../../js/namespaces/embedding/classes/EmbeddingFunction/#getsensitivekeys)
+    method. This prevents users from accidentally instantiating the embedding
+    function with hard-coded secrets.
+
 Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs.
 
 === "Python"
````
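As a minimal sketch of the admonition added above, the TypeScript side might look like the following; the class name, options shape, and import path are illustrative assumptions, not LanceDB's shipped code:

```ts
// Sketch: marking an option as sensitive in a custom embedding function.
// Import path, class name, and the options type are assumptions.
import { EmbeddingFunction } from "@lancedb/lancedb/embedding";

type MyOptions = { apiKey: string; model: string };

abstract class MyApiEmbeddings extends EmbeddingFunction<string, MyOptions> {
  // Raw values passed for "apiKey" are rejected; users must supply the key
  // through an environment variable or a registry variable instead.
  protected getSensitiveKeys(): string[] {
    return ["apiKey"];
  }
  // A concrete subclass still implements ndims(), embeddingDataType(),
  // computeSourceEmbeddings(), and the other abstract members.
}
```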
docs/src/embeddings/variables_and_secrets.md (new file, 53 lines)

````diff
@@ -0,0 +1,53 @@
+# Variable and Secrets
+
+Most embedding configuration options are saved in the table's metadata. However,
+this isn't always appropriate. For example, API keys should never be stored in the
+metadata. Additionally, other configuration options might be best set at runtime,
+such as the `device` configuration that controls whether to use GPU or CPU for
+inference. If you hardcoded this to GPU, you wouldn't be able to run the code on
+a server without one.
+
+To handle these cases, you can set variables on the embedding registry and
+reference them in the embedding configuration. These variables will be available
+during the runtime of your program, but not saved in the table's metadata. When
+the table is loaded from a different process, the variables must be set again.
+
+To set a variable, use the `set_var()` / `setVar()` method on the embedding registry.
+To reference a variable, use the syntax `$env:VARIABLE_NAME`. If there is a default
+value, you can use the syntax `$env:VARIABLE_NAME:DEFAULT_VALUE`.
+
+## Using variables to set secrets
+
+Sensitive configuration, such as API keys, must either be set as environment
+variables or using variables on the embedding registry. If you pass in a hardcoded
+value, LanceDB will raise an error. Instead, if you want to set an API key via
+configuration, use a variable:
+
+=== "Python"
+
+    ```python
+    --8<-- "python/python/tests/docs/test_embeddings_optional.py:register_secret"
+    ```
+
+=== "Typescript"
+
+    ```typescript
+    --8<-- "nodejs/examples/embedding.test.ts:register_secret"
+    ```
+
+## Using variables to set the device parameter
+
+Many embedding functions that run locally have a `device` parameter that controls
+whether to use GPU or CPU for inference. Because not all computers have a GPU,
+it's helpful to be able to set the `device` parameter at runtime, rather than
+have it hard coded in the embedding configuration. To make it work even if the
+variable isn't set, you could provide a default value of `cpu` in the embedding
+configuration.
+
+Some embedding libraries even have a method to detect which devices are available,
+which could be used to dynamically set the device at runtime. For example, in Python
+you can check if a CUDA GPU is available using `torch.cuda.is_available()`.
+
+```python
+--8<-- "python/python/tests/docs/test_embeddings_optional.py:register_device"
+```
````
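A minimal TypeScript sketch of the secrets pattern the new page describes, assuming the embedding function that LanceDB registers under the `"openai"` key; the variable name is illustrative:

```ts
// Sketch: keep an API key out of table metadata by routing it through a
// registry variable. The import path and "openai" registry key are assumptions.
import { getRegistry } from "@lancedb/lancedb/embedding";

const registry = getRegistry();
// Lives only in this process; never written to the table's metadata.
registry.setVar("openai_api_key", process.env.OPENAI_API_KEY ?? "");

// Reference the variable instead of hard-coding the secret. This page spells
// the prefix "$env:"; the registry API reference later in this compare uses "$var:".
const openai = registry.get("openai")?.create({
  apiKey: "$env:openai_api_key",
});
```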
````diff
@@ -131,6 +131,20 @@ Return a brief description of the connection
 
 ***
 
+### dropAllTables()
+
+```ts
+abstract dropAllTables(): Promise<void>
+```
+
+Drop all tables in the database.
+
+#### Returns
+
+`Promise`<`void`>
+
+***
+
 ### dropTable()
 
 ```ts
````
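A short usage sketch of the new method; the database path is illustrative:

```ts
// Sketch: remove every table in a local database via the new API.
import { connect } from "@lancedb/lancedb";

const db = await connect("data/sample-lancedb");
await db.dropAllTables();
```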
````diff
@@ -22,8 +22,6 @@ when creating a table or adding data to it)
 This function converts an array of Record<String, any> (row-major JS objects)
 to an Arrow Table (a columnar structure)
 
-Note that it currently does not support nulls.
-
 If a schema is provided then it will be used to determine the resulting array
 types. Fields will also be reordered to fit the order defined by the schema.
 
@@ -31,6 +29,9 @@ If a schema is not provided then the types will be inferred and the field order
 will be controlled by the order of properties in the first record. If a type
 is inferred it will always be nullable.
 
+If not all fields are found in the data, then a subset of the schema will be
+returned.
+
 If the input is empty then a schema must be provided to create an empty table.
 
 When a schema is not specified then data types will be inferred. The inference
@@ -38,6 +39,7 @@ rules are as follows:
 
 - boolean => Bool
 - number => Float64
+- bigint => Int64
 - String => Utf8
 - Buffer => Binary
 - Record<String, any> => Struct
````
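A small sketch of the inference rules above, including the new `bigint => Int64` rule; it assumes `makeArrowTable` is re-exported from the package root:

```ts
// Sketch: schema inference without an explicit schema.
import { makeArrowTable } from "@lancedb/lancedb";

const table = makeArrowTable([
  { id: 1n, score: 0.5, name: "a", flag: true },
  { id: 2n, score: 0.25, name: "b", flag: false },
]);
// id => Int64 (bigint), score => Float64 (number),
// name => Utf8 (String), flag => Bool (boolean); inferred fields are nullable.
console.log(table.schema.toString());
```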
````diff
@@ -8,6 +8,14 @@
 
 ## Properties
 
+### extraHeaders?
+
+```ts
+optional extraHeaders: Record<string, string>;
+```
+
+***
+
 ### retryConfig?
 
 ```ts
````
````diff
@@ -8,7 +8,7 @@
 
 ## Properties
 
-### dataStorageVersion?
+### ~~dataStorageVersion?~~
 
 ```ts
 optional dataStorageVersion: string;
@@ -19,6 +19,10 @@ The version of the data storage format to use.
 The default is `stable`.
 Set to "legacy" to use the old format.
 
+#### Deprecated
+
+Pass `new_table_data_storage_version` to storageOptions instead.
+
 ***
 
 ### embeddingFunction?
@@ -29,7 +33,7 @@ optional embeddingFunction: EmbeddingFunctionConfig;
 
 ***
 
-### enableV2ManifestPaths?
+### ~~enableV2ManifestPaths?~~
 
 ```ts
 optional enableV2ManifestPaths: boolean;
@@ -41,6 +45,10 @@ turning this on will make the dataset unreadable for older versions
 of LanceDB (prior to 0.10.0). To migrate an existing dataset, instead
 use the LocalTable#migrateManifestPathsV2 method.
 
+#### Deprecated
+
+Pass `new_table_enable_v2_manifest_paths` to storageOptions instead.
+
 ***
 
 ### existOk
@@ -90,17 +98,3 @@ Options already set on the connection will be inherited by the table,
 but can be overridden here.
 
 The available options are described at https://lancedb.github.io/lancedb/guides/storage/
-
-***
-
-### useLegacyFormat?
-
-```ts
-optional useLegacyFormat: boolean;
-```
-
-If true then data files will be written with the legacy format
-
-The default is false.
-
-Deprecated. Use data storage version instead.
````
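Per the deprecation notes above, the same settings now travel through `storageOptions`. A sketch; the table name and data are illustrative, and the string-valued keys follow the deprecation notes:

```ts
// Sketch: replacing the deprecated dataStorageVersion /
// enableV2ManifestPaths options with storageOptions keys.
import { connect } from "@lancedb/lancedb";

const db = await connect("data/sample-lancedb");
const table = await db.createTable("my_table", [{ id: 1 }], {
  storageOptions: {
    new_table_data_storage_version: "stable",
    new_table_enable_v2_manifest_paths: "true", // option values are strings
  },
});
```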
````diff
@@ -8,6 +8,23 @@
 
 An embedding function that automatically creates vector representation for a given column.
 
+It's important subclasses pass the **original** options to the super constructor
+and then pass those options to `resolveVariables` to resolve any variables before
+using them.
+
+## Example
+
+```ts
+class MyEmbeddingFunction extends EmbeddingFunction {
+  constructor(options: {model: string, timeout: number}) {
+    super(optionsRaw);
+    const options = this.resolveVariables(optionsRaw);
+    this.model = options.model;
+    this.timeout = options.timeout;
+  }
+}
+```
+
 ## Extended by
 
 - [`TextEmbeddingFunction`](TextEmbeddingFunction.md)
@@ -82,12 +99,33 @@ The datatype of the embeddings
 
 ***
 
+### getSensitiveKeys()
+
+```ts
+protected getSensitiveKeys(): string[]
+```
+
+Provide a list of keys in the function options that should be treated as
+sensitive. If users pass raw values for these keys, they will be rejected.
+
+#### Returns
+
+`string`[]
+
+***
+
 ### init()?
 
 ```ts
 optional init(): Promise<void>
 ```
 
+Optionally load any resources needed for the embedding function.
+
+This method is called after the embedding function has been initialized
+but before any embeddings are computed. It is useful for loading local models
+or other resources that are needed for the embedding function to work.
+
 #### Returns
 
 `Promise`<`void`>
@@ -108,6 +146,24 @@ The number of dimensions of the embeddings
 
 ***
 
+### resolveVariables()
+
+```ts
+protected resolveVariables(config): Partial<M>
+```
+
+Apply variables to the config.
+
+#### Parameters
+
+* **config**: `Partial`<`M`>
+
+#### Returns
+
+`Partial`<`M`>
+
+***
+
 ### sourceField()
 
 ```ts
@@ -134,37 +190,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d
 ### toJSON()
 
 ```ts
-abstract toJSON(): Partial<M>
+toJSON(): Record<string, any>
 ```
 
-Convert the embedding function to a JSON object
-It is used to serialize the embedding function to the schema
-It's important that any object returned by this method contains all the necessary
-information to recreate the embedding function
-
-It should return the same object that was passed to the constructor
-If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
+Get the original arguments to the constructor, to serialize them so they
+can be used to recreate the embedding function later.
 
 #### Returns
 
-`Partial`<`M`>
-
-#### Example
-
-```ts
-class MyEmbeddingFunction extends EmbeddingFunction {
-  constructor(options: {model: string, timeout: number}) {
-    super();
-    this.model = options.model;
-    this.timeout = options.timeout;
-  }
-  toJSON() {
-    return {
-      model: this.model,
-      timeout: this.timeout,
-    };
-  }
-```
+`Record`<`string`, `any`>
 
 ***
 
````
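Note that the `## Example` added above names its parameter `options` but reads `optionsRaw`. A corrected sketch of the intended pattern, with illustrative names and an assumed import path:

```ts
// Sketch of the constructor contract described above: hand the raw options to
// super() (kept for toJSON()), then resolve variables before using them.
import { EmbeddingFunction } from "@lancedb/lancedb/embedding"; // path assumed

abstract class MyEmbeddingFunction extends EmbeddingFunction {
  model?: string;
  timeout?: number;
  constructor(optionsRaw: { model: string; timeout: number }) {
    super(optionsRaw); // original options are preserved for serialization
    const options = this.resolveVariables(optionsRaw);
    this.model = options.model;
    this.timeout = options.timeout;
  }
}
```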
````diff
@@ -80,6 +80,28 @@ getTableMetadata(functions): Map<string, string>
 
 ***
 
+### getVar()
+
+```ts
+getVar(name): undefined | string
+```
+
+Get a variable.
+
+#### Parameters
+
+* **name**: `string`
+
+#### Returns
+
+`undefined` \| `string`
+
+#### See
+
+[setVar](EmbeddingFunctionRegistry.md#setvar)
+
+***
+
 ### length()
 
 ```ts
@@ -145,3 +167,31 @@ reset the registry to the initial state
 #### Returns
 
 `void`
+
+***
+
+### setVar()
+
+```ts
+setVar(name, value): void
+```
+
+Set a variable. These can be accessed in the embedding function
+configuration using the syntax `$var:variable_name`. If they are not
+set, an error will be thrown letting you know which key is unset. If you
+want to supply a default value, you can add an additional part in the
+configuration like so: `$var:variable_name:default_value`. Default values
+can be used for runtime configurations that are not sensitive, such as
+whether to use a GPU for inference.
+
+The name must not contain colons. The default value can contain colons.
+
+#### Parameters
+
+* **name**: `string`
+
+* **value**: `string`
+
+#### Returns
+
+`void`
````
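A sketch of the non-sensitive default-value form that the `setVar()` doc above describes; the `"huggingface"` registry key and the `device` option name are illustrative assumptions:

```ts
// Sketch: runtime device selection with a default, per setVar() above.
import { getRegistry } from "@lancedb/lancedb/embedding";

const registry = getRegistry();
registry.setVar("device", "cuda"); // e.g. set only on machines with a GPU

const func = registry.get("huggingface")?.create({
  // "$var:device:cpu" resolves to the "device" variable when set, else "cpu".
  device: "$var:device:cpu",
});
```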
````diff
@@ -114,12 +114,37 @@ abstract generateEmbeddings(texts, ...args): Promise<number[][] | Float32Array[]
 
 ***
 
+### getSensitiveKeys()
+
+```ts
+protected getSensitiveKeys(): string[]
+```
+
+Provide a list of keys in the function options that should be treated as
+sensitive. If users pass raw values for these keys, they will be rejected.
+
+#### Returns
+
+`string`[]
+
+#### Inherited from
+
+[`EmbeddingFunction`](EmbeddingFunction.md).[`getSensitiveKeys`](EmbeddingFunction.md#getsensitivekeys)
+
+***
+
 ### init()?
 
 ```ts
 optional init(): Promise<void>
 ```
 
+Optionally load any resources needed for the embedding function.
+
+This method is called after the embedding function has been initialized
+but before any embeddings are computed. It is useful for loading local models
+or other resources that are needed for the embedding function to work.
+
 #### Returns
 
 `Promise`<`void`>
@@ -148,6 +173,28 @@ The number of dimensions of the embeddings
 
 ***
 
+### resolveVariables()
+
+```ts
+protected resolveVariables(config): Partial<M>
+```
+
+Apply variables to the config.
+
+#### Parameters
+
+* **config**: `Partial`<`M`>
+
+#### Returns
+
+`Partial`<`M`>
+
+#### Inherited from
+
+[`EmbeddingFunction`](EmbeddingFunction.md).[`resolveVariables`](EmbeddingFunction.md#resolvevariables)
+
+***
+
 ### sourceField()
 
 ```ts
@@ -173,37 +220,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d
 ### toJSON()
 
 ```ts
-abstract toJSON(): Partial<M>
+toJSON(): Record<string, any>
 ```
 
-Convert the embedding function to a JSON object
-It is used to serialize the embedding function to the schema
-It's important that any object returned by this method contains all the necessary
-information to recreate the embedding function
-
-It should return the same object that was passed to the constructor
-If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
+Get the original arguments to the constructor, to serialize them so they
+can be used to recreate the embedding function later.
 
 #### Returns
 
-`Partial`<`M`>
-
-#### Example
-
-```ts
-class MyEmbeddingFunction extends EmbeddingFunction {
-  constructor(options: {model: string, timeout: number}) {
-    super();
-    this.model = options.model;
-    this.timeout = options.timeout;
-  }
-  toJSON() {
-    return {
-      model: this.model,
-      timeout: this.timeout,
-    };
-  }
-```
-
 #### Inherited from
 
````
````diff
@@ -9,23 +9,50 @@ LanceDB supports [Polars](https://github.com/pola-rs/polars), a blazingly fast D
 
 First, we connect to a LanceDB database.
 
+=== "Sync API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-lancedb"
+    --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
+    ```
+
+=== "Async API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-lancedb"
+    --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb_async"
+    ```
+
-```py
---8<-- "python/python/tests/docs/test_python.py:import-lancedb"
---8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
-```
 
 We can load a Polars `DataFrame` to LanceDB directly.
 
-```py
---8<-- "python/python/tests/docs/test_python.py:import-polars"
---8<-- "python/python/tests/docs/test_python.py:create_table_polars"
-```
+=== "Sync API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-polars"
+    --8<-- "python/python/tests/docs/test_python.py:create_table_polars"
+    ```
+
+=== "Async API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-polars"
+    --8<-- "python/python/tests/docs/test_python.py:create_table_polars_async"
+    ```
 
 We can now perform similarity search via the LanceDB Python API.
 
-```py
---8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
-```
+=== "Sync API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
+    ```
+
+=== "Async API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:vector_search_polars_async"
+    ```
 
 In addition to the selected columns, LanceDB also returns a vector
 and also the `_distance` column which is the distance between the query
@@ -112,4 +139,3 @@ The reason it's beneficial to not convert the LanceDB Table
 to a DataFrame is because the table can potentially be way larger
 than memory, and Polars LazyFrames allow us to work with such
 larger-than-memory datasets by not loading it into memory all at once.
-
````
````diff
@@ -2,14 +2,19 @@
 
 [Pydantic](https://docs.pydantic.dev/latest/) is a data validation library in Python.
 LanceDB integrates with Pydantic for schema inference, data ingestion, and query result casting.
+Using [LanceModel][lancedb.pydantic.LanceModel], users can seamlessly
+integrate Pydantic with the rest of the LanceDB APIs.
 
-## Schema
+```python
 
-LanceDB supports to create Apache Arrow Schema from a
-[Pydantic BaseModel](https://docs.pydantic.dev/latest/api/main/#pydantic.main.BaseModel)
-via [pydantic_to_schema()](python.md#lancedb.pydantic.pydantic_to_schema) method.
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:imports"
 
-::: lancedb.pydantic.pydantic_to_schema
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:base_model"
+
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:set_url"
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:base_example"
+```
 
 ## Vector Field
 
@@ -34,3 +39,9 @@ Current supported type conversions:
 | `list` | `pyarrow.List` |
 | `BaseModel` | `pyarrow.Struct` |
 | `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
+
+LanceDB supports to create Apache Arrow Schema from a
+[Pydantic BaseModel][pydantic.BaseModel]
+via [pydantic_to_schema()](python.md#lancedb.pydantic.pydantic_to_schema) method.
+
+::: lancedb.pydantic.pydantic_to_schema
````
````diff
@@ -20,6 +20,7 @@ async function setup() {
 }
 
 async () => {
+  console.log("search_legacy.ts: start");
   await setup();
 
   // --8<-- [start:search1]
@@ -37,5 +38,5 @@ async () => {
     .execute();
   // --8<-- [end:search2]
 
-  console.log("search: done");
+  console.log("search_legacy.ts: done");
 };
````
````diff
@@ -1,6 +1,7 @@
 import * as vectordb from "vectordb";
 
 (async () => {
+  console.log("sql_legacy.ts: start");
   const db = await vectordb.connect("data/sample-lancedb");
 
   let data = [];
@@ -34,5 +35,5 @@ import * as vectordb from "vectordb";
   await tbl.filter("id = 10").limit(10).execute();
   // --8<-- [end:sql_search]
 
-  console.log("SQL search: done");
+  console.log("sql_legacy.ts: done");
 })();
````
````diff
@@ -15,6 +15,7 @@ excluded_globs = [
     "../src/python/duckdb.md",
     "../src/python/pandas_and_pyarrow.md",
     "../src/python/polars_arrow.md",
+    "../src/python/pydantic.md",
     "../src/embeddings/*.md",
     "../src/concepts/*.md",
     "../src/ann_indexes.md",
````
````diff
@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.15.1-beta.3</version>
+    <version>0.16.1-beta.3</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
 
````
````diff
@@ -6,7 +6,7 @@
 
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.15.1-beta.3</version>
+  <version>0.16.1-beta.3</version>
   <packaging>pom</packaging>
 
   <name>LanceDB Parent</name>
````
node/package-lock.json (generated, 68 changed lines)

````diff
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.15.1-beta.3",
+      "version": "0.16.1-beta.3",
       "cpu": [
         "x64",
         "arm64"
@@ -52,14 +52,14 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.15.1-beta.3",
-        "@lancedb/vectordb-darwin-x64": "0.15.1-beta.3",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.15.1-beta.3",
-        "@lancedb/vectordb-linux-arm64-musl": "0.15.1-beta.3",
-        "@lancedb/vectordb-linux-x64-gnu": "0.15.1-beta.3",
-        "@lancedb/vectordb-linux-x64-musl": "0.15.1-beta.3",
-        "@lancedb/vectordb-win32-arm64-msvc": "0.15.1-beta.3",
-        "@lancedb/vectordb-win32-x64-msvc": "0.15.1-beta.3"
+        "@lancedb/vectordb-darwin-arm64": "0.16.1-beta.3",
+        "@lancedb/vectordb-darwin-x64": "0.16.1-beta.3",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.16.1-beta.3",
+        "@lancedb/vectordb-linux-arm64-musl": "0.16.1-beta.3",
+        "@lancedb/vectordb-linux-x64-gnu": "0.16.1-beta.3",
+        "@lancedb/vectordb-linux-x64-musl": "0.16.1-beta.3",
+        "@lancedb/vectordb-win32-arm64-msvc": "0.16.1-beta.3",
+        "@lancedb/vectordb-win32-x64-msvc": "0.16.1-beta.3"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",
@@ -330,9 +330,9 @@
       }
     },
     "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.15.1-beta.3.tgz",
-      "integrity": "sha512-2GinbODdSsUc+zJQ4BFZPsdraPWHJpDpGf7CsZIqfokwxIRnzVzFfQy+SZhmNhKzFkmtW21yWw6wrJ4FgS7Qtw==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.16.1-beta.3.tgz",
+      "integrity": "sha512-k2dfDNvoFjZuF8RCkFX9yFkLIg292mFg+o6IUeXndlikhABi8F+NbRODGUxJf3QUioks2tGF831KFoV5oQyeEA==",
       "cpu": [
         "arm64"
       ],
@@ -343,9 +343,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.15.1-beta.3.tgz",
-      "integrity": "sha512-nRp5eN6yvx5kvfDEQuh3EHCmwjVNCIm7dXoV6BasepFkOoaHHmjKSIUFW7HjtJOfdFbb+r8UjBJx4cN6Jh2iFg==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.16.1-beta.3.tgz",
+      "integrity": "sha512-pYvwcAXBB3MXxa2kvK8PxMoEsaE+EFld5pky6dDo6qJQVepUz9pi/e1FTLxW6m0mgwtRj52P6xe55sj1Yln9Qw==",
       "cpu": [
         "x64"
       ],
@@ -356,9 +356,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.15.1-beta.3.tgz",
-      "integrity": "sha512-JOyD7Nt3RSfHGWNQjHbZMHsIw1cVWPySxbtDmDqk5QH5IfgDNZLiz/sNbROuQkNvc5SsC6wUmhBUwWBETzW7/g==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.16.1-beta.3.tgz",
+      "integrity": "sha512-BS4rnBtKGJlEdbYgOe85mGhviQaSfEXl8qw0fh0ml8E0qbi5RuLtwfTFMe3yAKSOnNAvaJISqXQyUN7hzkYkUQ==",
       "cpu": [
         "arm64"
       ],
@@ -369,9 +369,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-arm64-musl": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.15.1-beta.3.tgz",
-      "integrity": "sha512-4jTHl1i/4e7wP2U7RMjHr87/gsGJ9tfRJ4ljQIfV+LkA7ROMd/TA5XSnvPesQCDjPNRI4wAyb/BmK18V96VqBg==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.16.1-beta.3.tgz",
+      "integrity": "sha512-/F1mzpgSipfXjeaXJx5c0zLPOipPKnSPIpYviSdLU2Ahm1aHLweW1UsoiUoRkBkvEcVrZfHxL64vasey2I0P7Q==",
       "cpu": [
         "arm64"
       ],
@@ -382,9 +382,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.15.1-beta.3.tgz",
-      "integrity": "sha512-odrNqB/bGL+sweZi6ed9sKft/H5/bca/tDVG/Y39xCJ6swPWxXQK2Zpn7EjqbccI2p2zkrhKcOUBO/bEkOqQng==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.16.1-beta.3.tgz",
+      "integrity": "sha512-zGn2Oby8GAQYG7+dqFVi2DDzli2/GAAY7lwPoYbPlyVytcdTlXRsxea1XiT1jzZmyKIlrxA/XXSRsmRq4n1j1w==",
       "cpu": [
         "x64"
       ],
@@ -395,9 +395,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-x64-musl": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.15.1-beta.3.tgz",
-      "integrity": "sha512-Zml4KgQWzkkMBHZiD30Gs3N56BT5xO01efwO/Q2qB7JKw5Vy9pa6SgFf9woBvKFQRY73fiKqafy+BmGHTgozNg==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.16.1-beta.3.tgz",
+      "integrity": "sha512-MXYvI7dL+0QtWGDuliUUaEp/XQN+hSndtDc8wlAMyI0lOzmTvC7/C3OZQcMKf6JISZuNS71OVzVTYDYSab9aXw==",
      "cpu": [
         "x64"
       ],
@@ -408,9 +408,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-win32-arm64-msvc": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-arm64-msvc/-/vectordb-win32-arm64-msvc-0.15.1-beta.3.tgz",
-      "integrity": "sha512-3BWkK+8JP+js/KoTad7bm26NTR5pq2tvXJkrFB0eaFfsIuUXebS+LIBF22f39He2WMpq3YojT0bMnYxp8qvRkQ==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-arm64-msvc/-/vectordb-win32-arm64-msvc-0.16.1-beta.3.tgz",
+      "integrity": "sha512-1dbUSg+Mi+0W8JAUXqNWC+uCr0RUqVHhxFVGLSlprqZ8qFJYQ61jFSZr4onOYj9Ta1n6tUb3Nc4acxf3vXXPmw==",
       "cpu": [
         "arm64"
       ],
@@ -421,9 +421,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.15.1-beta.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.15.1-beta.3.tgz",
-      "integrity": "sha512-jr8SEisYAX7pQHIbxIDJPkANmxWh5Yohm8ELbMgu76IvLI7bsS7sB9ID+kcj1SiS5m4V6OG2BO1FrEYbPLZ6Dg==",
+      "version": "0.16.1-beta.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.16.1-beta.3.tgz",
+      "integrity": "sha512-K9oT47zKnFoCEB/JjVKG+w+L0GOMDsPPln+B2TvefAXAWrvweCN2H4LUdsBYCTnntzy80OJCwwH3OwX07M1Y3g==",
       "cpu": [
         "x64"
       ],
````
````diff
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",
@@ -92,13 +92,13 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.15.1-beta.3",
-    "@lancedb/vectordb-darwin-arm64": "0.15.1-beta.3",
-    "@lancedb/vectordb-linux-x64-gnu": "0.15.1-beta.3",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.15.1-beta.3",
-    "@lancedb/vectordb-linux-x64-musl": "0.15.1-beta.3",
-    "@lancedb/vectordb-linux-arm64-musl": "0.15.1-beta.3",
-    "@lancedb/vectordb-win32-x64-msvc": "0.15.1-beta.3",
-    "@lancedb/vectordb-win32-arm64-msvc": "0.15.1-beta.3"
+    "@lancedb/vectordb-darwin-x64": "0.16.1-beta.3",
+    "@lancedb/vectordb-darwin-arm64": "0.16.1-beta.3",
+    "@lancedb/vectordb-linux-x64-gnu": "0.16.1-beta.3",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.16.1-beta.3",
+    "@lancedb/vectordb-linux-x64-musl": "0.16.1-beta.3",
+    "@lancedb/vectordb-linux-arm64-musl": "0.16.1-beta.3",
+    "@lancedb/vectordb-win32-x64-msvc": "0.16.1-beta.3",
+    "@lancedb/vectordb-win32-arm64-msvc": "0.16.1-beta.3"
   }
 }
````
````diff
@@ -47,7 +47,8 @@ const {
   tableSchema,
   tableAddColumns,
   tableAlterColumns,
-  tableDropColumns
+  tableDropColumns,
+  tableDropIndex
   // eslint-disable-next-line @typescript-eslint/no-var-requires
 } = require("../native.js");
 
@@ -604,6 +605,13 @@ export interface Table<T = number[]> {
    */
   dropColumns(columnNames: string[]): Promise<void>
 
+  /**
+   * Drop an index from the table
+   *
+   * @param indexName The name of the index to drop
+   */
+  dropIndex(indexName: string): Promise<void>
+
   /**
    * Instrument the behavior of this Table with middleware.
    *
@@ -1206,6 +1214,10 @@ export class LocalTable<T = number[]> implements Table<T> {
     return tableDropColumns.call(this._tbl, columnNames);
   }
 
+  async dropIndex(indexName: string): Promise<void> {
+    return tableDropIndex.call(this._tbl, indexName);
+  }
+
   withMiddleware(middleware: HttpMiddleware): Table<T> {
     return this;
   }
````
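Usage of the new method mirrors the commented-out test later in this compare. A sketch; the paths and names are illustrative:

```ts
// Sketch: dropping a vector index through the new Table.dropIndex() API.
import * as vectordb from "vectordb";

const db = await vectordb.connect("data/sample-lancedb");
const tbl = await db.openTable("vectors");
await tbl.dropIndex("vector_idx"); // e.g. "vector_idx" for an index on "vector"
```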
````diff
@@ -471,6 +471,18 @@ export class RemoteTable<T = number[]> implements Table<T> {
       )
     }
   }
 
+  async dropIndex (index_name: string): Promise<void> {
+    const res = await this._client.post(
+      `/v1/table/${encodeURIComponent(this._name)}/index/${encodeURIComponent(index_name)}/drop/`
+    )
+    if (res.status !== 200) {
+      throw new Error(
+        `Server Error, status: ${res.status}, ` +
+          // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
+          `message: ${res.statusText}: ${await res.body()}`
+      )
+    }
+  }
+
   async countRows (filter?: string): Promise<number> {
     const result = await this._client.post(`/v1/table/${encodeURIComponent(this._name)}/count_rows/`, {
````
@@ -894,6 +894,27 @@ describe("LanceDB client", function () {
|
|||||||
expect(stats.distanceType).to.equal("l2");
|
expect(stats.distanceType).to.equal("l2");
|
||||||
expect(stats.numIndices).to.equal(1);
|
expect(stats.numIndices).to.equal(1);
|
||||||
}).timeout(50_000);
|
}).timeout(50_000);
|
||||||
|
|
||||||
|
// not yet implemented
|
||||||
|
// it("can drop index", async function () {
|
||||||
|
// const uri = await createTestDB(32, 300);
|
||||||
|
// const con = await lancedb.connect(uri);
|
||||||
|
// const table = await con.openTable("vectors");
|
||||||
|
// await table.createIndex({
|
||||||
|
// type: "ivf_pq",
|
||||||
|
// column: "vector",
|
||||||
|
// num_partitions: 2,
|
||||||
|
// max_iters: 2,
|
||||||
|
// num_sub_vectors: 2
|
||||||
|
// });
|
||||||
|
//
|
||||||
|
// const indices = await table.listIndices();
|
||||||
|
// expect(indices).to.have.lengthOf(1);
|
||||||
|
// expect(indices[0].name).to.equal("vector_idx");
|
||||||
|
//
|
||||||
|
// await table.dropIndex("vector_idx");
|
||||||
|
// expect(await table.listIndices()).to.have.lengthOf(0);
|
||||||
|
// }).timeout(50_000);
|
||||||
});
|
});
|
||||||
|
|
||||||
describe("when using a custom embedding function", function () {
|
describe("when using a custom embedding function", function () {
|
||||||
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.15.1-beta.3"
+version = "0.16.1-beta.3"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -55,6 +55,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
   Float64,
   Struct,
   List,
+  Int16,
   Int32,
   Int64,
   Float,

@@ -108,13 +109,16 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
         false,
       ),
     ]);

     const table = (await tableCreationMethod(
       records,
       recordsReversed,
       schema,
       // biome-ignore lint/suspicious/noExplicitAny: <explanation>
     )) as any;

+    // We expect deterministic ordering of the fields
+    expect(table.schema.names).toEqual(schema.names);
+
     schema.fields.forEach(
       (
         // biome-ignore lint/suspicious/noExplicitAny: <explanation>

@@ -141,13 +145,13 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
 describe("The function makeArrowTable", function () {
   it("will use data types from a provided schema instead of inference", async function () {
     const schema = new Schema([
-      new Field("a", new Int32()),
-      new Field("b", new Float32()),
+      new Field("a", new Int32(), false),
+      new Field("b", new Float32(), true),
       new Field(
         "c",
         new FixedSizeList(3, new Field("item", new Float16())),
       ),
-      new Field("d", new Int64()),
+      new Field("d", new Int64(), true),
     ]);
     const table = makeArrowTable(
       [

@@ -165,12 +169,15 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
     expect(actual.numRows).toBe(3);
     const actualSchema = actual.schema;
     expect(actualSchema).toEqual(schema);
+    expect(table.getChild("a")?.toJSON()).toEqual([1, 4, 7]);
+    expect(table.getChild("b")?.toJSON()).toEqual([2, 5, 8]);
+    expect(table.getChild("d")?.toJSON()).toEqual([9n, 10n, null]);
   });

   it("will assume the column `vector` is FixedSizeList<Float32> by default", async function () {
     const schema = new Schema([
       new Field("a", new Float(Precision.DOUBLE), true),
-      new Field("b", new Float(Precision.DOUBLE), true),
+      new Field("b", new Int64(), true),
       new Field(
         "vector",
         new FixedSizeList(

@@ -181,9 +188,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       ),
     ]);
     const table = makeArrowTable([
-      { a: 1, b: 2, vector: [1, 2, 3] },
-      { a: 4, b: 5, vector: [4, 5, 6] },
-      { a: 7, b: 8, vector: [7, 8, 9] },
+      { a: 1, b: 2n, vector: [1, 2, 3] },
+      { a: 4, b: 5n, vector: [4, 5, 6] },
+      { a: 7, b: 8n, vector: [7, 8, 9] },
     ]);

     const buf = await fromTableToBuffer(table);

@@ -193,6 +200,19 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
     expect(actual.numRows).toBe(3);
     const actualSchema = actual.schema;
     expect(actualSchema).toEqual(schema);
+
+    expect(table.getChild("a")?.toJSON()).toEqual([1, 4, 7]);
+    expect(table.getChild("b")?.toJSON()).toEqual([2n, 5n, 8n]);
+    expect(
+      table
+        .getChild("vector")
+        ?.toJSON()
+        .map((v) => v.toJSON()),
+    ).toEqual([
+      [1, 2, 3],
+      [4, 5, 6],
+      [7, 8, 9],
+    ]);
   });

   it("can support multiple vector columns", async function () {

@@ -206,7 +226,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       ),
       new Field(
         "vec2",
-        new FixedSizeList(3, new Field("item", new Float16(), true)),
+        new FixedSizeList(3, new Field("item", new Float64(), true)),
         true,
       ),
     ]);

@@ -219,7 +239,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       {
         vectorColumns: {
           vec1: { type: new Float16() },
-          vec2: { type: new Float16() },
+          vec2: { type: new Float64() },
         },
       },
     );

@@ -307,6 +327,53 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       false,
     );
   });

+  it("will allow subsets of columns if nullable", async function () {
+    const schema = new Schema([
+      new Field("a", new Int64(), true),
+      new Field(
+        "s",
+        new Struct([
+          new Field("x", new Int32(), true),
+          new Field("y", new Int32(), true),
+        ]),
+        true,
+      ),
+      new Field("d", new Int16(), true),
+    ]);
+
+    const table = makeArrowTable([{ a: 1n }], { schema });
+    expect(table.numCols).toBe(1);
+    expect(table.numRows).toBe(1);
+
+    const table2 = makeArrowTable([{ a: 1n, d: 2 }], { schema });
+    expect(table2.numCols).toBe(2);
+
+    const table3 = makeArrowTable([{ s: { y: 3 } }], { schema });
+    expect(table3.numCols).toBe(1);
+    const expectedSchema = new Schema([
+      new Field("s", new Struct([new Field("y", new Int32(), true)]), true),
+    ]);
+    expect(table3.schema).toEqual(expectedSchema);
+  });
+
+  it("will work even if columns are sparsely provided", async function () {
+    const sparseRecords = [{ a: 1n }, { b: 2n }, { c: 3n }, { d: 4n }];
+    const table = makeArrowTable(sparseRecords);
+    expect(table.numCols).toBe(4);
+    expect(table.numRows).toBe(4);
+
+    const schema = new Schema([
+      new Field("a", new Int64(), true),
+      new Field("b", new Int32(), true),
+      new Field("c", new Int64(), true),
+      new Field("d", new Int16(), true),
+    ]);
+    const table2 = makeArrowTable(sparseRecords, { schema });
+    expect(table2.numCols).toBe(4);
+    expect(table2.numRows).toBe(4);
+    expect(table2.schema).toEqual(schema);
+  });
 });

 class DummyEmbedding extends EmbeddingFunction<string> {
@@ -17,14 +17,14 @@ describe("when connecting", () => {
   it("should connect", async () => {
     const db = await connect(tmpDir.name);
     expect(db.display()).toBe(
-      `NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=None)`,
+      `ListingDatabase(uri=${tmpDir.name}, read_consistency_interval=None)`,
     );
   });

   it("should allow read consistency interval to be specified", async () => {
     const db = await connect(tmpDir.name, { readConsistencyInterval: 5 });
     expect(db.display()).toBe(
-      `NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=5s)`,
+      `ListingDatabase(uri=${tmpDir.name}, read_consistency_interval=5s)`,
     );
   });
 });

@@ -61,6 +61,26 @@ describe("given a connection", () => {
     await expect(tbl.countRows()).resolves.toBe(1);
   });

+  it("should be able to drop tables", async () => {
+    await db.createTable("test", [{ id: 1 }, { id: 2 }]);
+    await db.createTable("test2", [{ id: 1 }, { id: 2 }]);
+    await db.createTable("test3", [{ id: 1 }, { id: 2 }]);
+
+    await expect(db.tableNames()).resolves.toEqual(["test", "test2", "test3"]);
+
+    await db.dropTable("test2");
+
+    await expect(db.tableNames()).resolves.toEqual(["test", "test3"]);
+
+    await db.dropAllTables();
+
+    await expect(db.tableNames()).resolves.toEqual([]);
+
+    // Make sure we can still create more tables after dropping all
+    await db.createTable("test4", [{ id: 1 }, { id: 2 }]);
+  });
+
   it("should fail if creating table twice, unless overwrite is true", async () => {
     let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
     await expect(tbl.countRows()).resolves.toBe(2);

@@ -96,14 +116,15 @@ describe("given a connection", () => {
     const data = [...Array(10000).keys()].map((i) => ({ id: i }));

     // Create in v1 mode
-    let table = await db.createTable("test", data, { useLegacyFormat: true });
+    let table = await db.createTable("test", data, {
+      storageOptions: { newTableDataStorageVersion: "legacy" },
+    });

     const isV2 = async (table: Table) => {
       const data = await table
         .query()
         .limit(10000)
         .toArrow({ maxBatchLength: 100000 });
-      console.log(data.batches.length);
       return data.batches.length < 5;
     };

@@ -122,7 +143,7 @@ describe("given a connection", () => {
     const schema = new Schema([new Field("id", new Float64(), true)]);

     table = await db.createEmptyTable("test_v2_empty", schema, {
-      useLegacyFormat: false,
+      storageOptions: { newTableDataStorageVersion: "stable" },
     });

     await table.add(data);
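The test changes above track an options migration; as a rough before/after sketch (option names taken directly from the tests, not a definitive API reference):

// Deprecated spelling:
let table = await db.createTable("test", data, { useLegacyFormat: true });

// Current spelling, routed through storageOptions:
table = await db.createTable("test", data, {
  storageOptions: { newTableDataStorageVersion: "legacy" },
});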
@@ -17,6 +17,8 @@ import {
 import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
 import { getRegistry, register } from "../lancedb/embedding/registry";

+const testOpenAIInteg = process.env.OPENAI_API_KEY == null ? test.skip : test;
+
 describe("embedding functions", () => {
   let tmpDir: tmp.DirResult;
   beforeEach(() => {

@@ -29,9 +31,6 @@ describe("embedding functions", () => {

   it("should be able to create a table with an embedding function", async () => {
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {};
-      }
       ndims() {
         return 3;
       }

@@ -75,9 +74,6 @@ describe("embedding functions", () => {
   it("should be able to append and upsert using embedding function", async () => {
     @register()
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {};
-      }
       ndims() {
         return 3;
       }

@@ -143,9 +139,6 @@ describe("embedding functions", () => {
   it("should be able to create an empty table with an embedding function", async () => {
     @register()
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {};
-      }
       ndims() {
         return 3;
       }

@@ -194,9 +187,6 @@ describe("embedding functions", () => {
   it("should error when appending to a table with an unregistered embedding function", async () => {
     @register("mock")
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {};
-      }
       ndims() {
         return 3;
       }

@@ -241,13 +231,35 @@ describe("embedding functions", () => {
       `Function "mock" not found in registry`,
     );
   });

+  testOpenAIInteg("propagates variables through all methods", async () => {
+    delete process.env.OPENAI_API_KEY;
+    const registry = getRegistry();
+    registry.setVar("openai_api_key", "sk-...");
+    const func = registry.get("openai")?.create({
+      model: "text-embedding-ada-002",
+      apiKey: "$var:openai_api_key",
+    }) as EmbeddingFunction;
+
+    const db = await connect("memory://");
+    const wordsSchema = LanceSchema({
+      text: func.sourceField(new Utf8()),
+      vector: func.vectorField(),
+    });
+    const tbl = await db.createEmptyTable("words", wordsSchema, {
+      mode: "overwrite",
+    });
+    await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]);
+
+    const query = "greetings";
+    const actual = (await tbl.search(query).limit(1).toArray())[0];
+    expect(actual).toHaveProperty("text");
+  });
+
   test.each([new Float16(), new Float32(), new Float64()])(
     "should be able to provide manual embeddings with multiple float datatype",
     async (floatType) => {
       class MockEmbeddingFunction extends EmbeddingFunction<string> {
-        toJSON(): object {
-          return {};
-        }
         ndims() {
           return 3;
         }

@@ -292,10 +304,6 @@ describe("embedding functions", () => {
     async (floatType) => {
       @register("test1")
       class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
-        toJSON(): object {
-          return {};
-        }
-
         embeddingDataType(): Float {
           return floatType;
         }

@@ -310,9 +318,6 @@ describe("embedding functions", () => {
       }
       @register("test")
       class MockEmbeddingFunction extends EmbeddingFunction<string> {
-        toJSON(): object {
-          return {};
-        }
         ndims() {
           return 3;
         }
@@ -11,7 +11,11 @@ import * as arrow18 from "apache-arrow-18";
 import * as tmp from "tmp";

 import { connect } from "../lancedb";
-import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
+import {
+  EmbeddingFunction,
+  FunctionOptions,
+  LanceSchema,
+} from "../lancedb/embedding";
 import { getRegistry, register } from "../lancedb/embedding/registry";

 describe.each([arrow15, arrow16, arrow17, arrow18])("LanceSchema", (arrow) => {

@@ -39,11 +43,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
   it("should register a new item to the registry", async () => {
     @register("mock-embedding")
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {
-          someText: "hello",
-        };
-      }
       constructor() {
         super();
       }

@@ -89,11 +88,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
   });
   test("should error if registering with the same name", async () => {
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {
-          someText: "hello",
-        };
-      }
       constructor() {
         super();
       }

@@ -114,13 +108,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
   });
   test("schema should contain correct metadata", async () => {
     class MockEmbeddingFunction extends EmbeddingFunction<string> {
-      toJSON(): object {
-        return {
-          someText: "hello",
-        };
-      }
-      constructor() {
+      constructor(args: FunctionOptions = {}) {
         super();
+        this.resolveVariables(args);
       }
       ndims() {
         return 3;

@@ -132,7 +122,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
         return data.map(() => [1, 2, 3]);
       }
     }
-    const func = new MockEmbeddingFunction();
+    const func = new MockEmbeddingFunction({ someText: "hello" });

     const schema = LanceSchema({
       id: new arrow.Int32(),

@@ -155,3 +145,79 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
     expect(schema.metadata).toEqual(expectedMetadata);
   });
 });
+
+describe("Registry.setVar", () => {
+  const registry = getRegistry();
+
+  beforeEach(() => {
+    @register("mock-embedding")
+    // biome-ignore lint/correctness/noUnusedVariables :
+    class MockEmbeddingFunction extends EmbeddingFunction<string> {
+      constructor(optionsRaw: FunctionOptions = {}) {
+        super();
+        const options = this.resolveVariables(optionsRaw);
+
+        expect(optionsRaw["someKey"].startsWith("$var:someName")).toBe(true);
+        expect(options["someKey"]).toBe("someValue");
+
+        if (options["secretKey"]) {
+          expect(optionsRaw["secretKey"]).toBe("$var:secretKey");
+          expect(options["secretKey"]).toBe("mySecret");
+        }
+      }
+      async computeSourceEmbeddings(data: string[]) {
+        return data.map(() => [1, 2, 3]);
+      }
+      embeddingDataType() {
+        return new arrow18.Float32() as apiArrow.Float;
+      }
+      protected getSensitiveKeys() {
+        return ["secretKey"];
+      }
+    }
+  });
+  afterEach(() => {
+    registry.reset();
+  });
+
+  it("Should error if the variable is not set", () => {
+    console.log(registry.get("mock-embedding"));
+    expect(() =>
+      registry.get("mock-embedding")!.create({ someKey: "$var:someName" }),
+    ).toThrow('Variable "someName" not found');
+  });
+
+  it("should use default values if not set", () => {
+    registry
+      .get("mock-embedding")!
+      .create({ someKey: "$var:someName:someValue" });
+  });
+
+  it("should set a variable that the embedding function understand", () => {
+    registry.setVar("someName", "someValue");
+    registry.get("mock-embedding")!.create({ someKey: "$var:someName" });
+  });
+
+  it("should reject secrets that aren't passed as variables", () => {
+    registry.setVar("someName", "someValue");
+    expect(() =>
+      registry
+        .get("mock-embedding")!
+        .create({ secretKey: "someValue", someKey: "$var:someName" }),
+    ).toThrow(
+      'The key "secretKey" is sensitive and cannot be set directly. Please use the $var: syntax to set it.',
+    );
+  });
+
+  it("should not serialize secrets", () => {
+    registry.setVar("someName", "someValue");
+    registry.setVar("secretKey", "mySecret");
+    const func = registry
+      .get("mock-embedding")!
+      .create({ secretKey: "$var:secretKey", someKey: "$var:someName" });
+    expect(func.toJSON()).toEqual({
+      secretKey: "$var:secretKey",
+      someKey: "$var:someName",
+    });
+  });
+});
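Taken together, the tests above pin down the variable syntax: `setVar` registers a value, `$var:<name>` references it, `$var:<name>:<default>` supplies a fallback, and keys listed by `getSensitiveKeys` may only be set through variables so they never serialize in plain text. A condensed sketch (the `openai` function name comes from the integration test; treat the rest as illustrative):

const registry = getRegistry();
registry.setVar("api_key", "sk-...");

const func = registry.get("openai")?.create({
  apiKey: "$var:api_key", // resolved at create time, never serialized
  // model: "$var:model:text-embedding-ada-002" would fall back to the default
});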
@@ -104,4 +104,26 @@ describe("remote connection", () => {
       },
     );
   });

+  it("should pass on requested extra headers", async () => {
+    await withMockDatabase(
+      (req, res) => {
+        expect(req.headers["x-my-header"]).toEqual("my-value");
+
+        const body = JSON.stringify({ tables: [] });
+        res.writeHead(200, { "Content-Type": "application/json" }).end(body);
+      },
+      async (db) => {
+        const tableNames = await db.tableNames();
+        expect(tableNames).toEqual([]);
+      },
+      {
+        clientConfig: {
+          extraHeaders: {
+            "x-my-header": "my-value",
+          },
+        },
+      },
+    );
+  });
 });
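A connection-side sketch of the header plumbing exercised above (only `clientConfig.extraHeaders` is confirmed by this diff; the other remote connect options shown are assumptions):

const db = await lancedb.connect("db://my-database", {
  apiKey: "...",
  clientConfig: {
    // Sent verbatim on every request to the remote database.
    extraHeaders: { "x-my-header": "my-value" },
  },
});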
@@ -253,6 +253,31 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       const arrowTbl = await table.toArrow();
       expect(arrowTbl).toBeInstanceOf(ArrowTable);
     });

+    it("should be able to handle missing fields", async () => {
+      const schema = new arrow.Schema([
+        new arrow.Field("id", new arrow.Int32(), true),
+        new arrow.Field("y", new arrow.Int32(), true),
+        new arrow.Field("z", new arrow.Int64(), true),
+      ]);
+      const db = await connect(tmpDir.name);
+      const table = await db.createEmptyTable("testNull", schema);
+      await table.add([{ id: 1, y: 2 }]);
+      await table.add([{ id: 2 }]);
+
+      await table
+        .mergeInsert("id")
+        .whenNotMatchedInsertAll()
+        .execute([
+          { id: 3, z: 3 },
+          { id: 4, z: 5 },
+        ]);
+
+      const res = await table.query().toArrow();
+      expect(res.getChild("id")?.toJSON()).toEqual([1, 2, 3, 4]);
+      expect(res.getChild("y")?.toJSON()).toEqual([2, null, null, null]);
+      expect(res.getChild("z")?.toJSON()).toEqual([null, null, 3n, 5n]);
+    });
   },
 );

@@ -1013,9 +1038,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
     test("can search using a string", async () => {
       @register()
       class MockEmbeddingFunction extends EmbeddingFunction<string> {
-        toJSON(): object {
-          return {};
-        }
         ndims() {
           return 1;
         }
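The new test doubles as documentation: rows may omit nullable fields both on `add` and on `mergeInsert`, and the missing cells come back as nulls. Stripped of assertions, the pattern is:

await table.add([{ id: 1, y: 2 }]); // "z" omitted, stored as null
await table
  .mergeInsert("id")
  .whenNotMatchedInsertAll()
  .execute([{ id: 3, z: 3 }]); // "y" omitted, stored as null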
@@ -43,12 +43,17 @@ test("custom embedding function", async () => {

   @register("my_embedding")
   class MyEmbeddingFunction extends EmbeddingFunction<string> {
-    toJSON(): object {
-      return {};
+    constructor(optionsRaw = {}) {
+      super();
+      const options = this.resolveVariables(optionsRaw);
+      // Initialize using options
     }
     ndims() {
       return 3;
     }
+    protected getSensitiveKeys(): string[] {
+      return [];
+    }
     embeddingDataType(): Float {
       return new Float32();
     }

@@ -94,3 +99,14 @@ test("custom embedding function", async () => {
     expect(await table2.countRows()).toBe(2);
   });
 });
+
+test("embedding function api_key", async () => {
+  // --8<-- [start:register_secret]
+  const registry = getRegistry();
+  registry.setVar("api_key", "sk-...");
+
+  const func = registry.get("openai")!.create({
+    apiKey: "$var:api_key",
+  });
+  // --8<-- [end:register_secret]
+});
@@ -42,4 +42,4 @@ test("full text search", async () => {
     expect(result.length).toBe(10);
     // --8<-- [end:full_text_search]
   });
-});
+}, 10_000);
@@ -2,31 +2,37 @@
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

 import {
+  Data as ArrowData,
   Table as ArrowTable,
   Binary,
+  Bool,
   BufferType,
   DataType,
+  Dictionary,
   Field,
   FixedSizeBinary,
   FixedSizeList,
   Float,
   Float32,
+  Float64,
   Int,
+  Int32,
+  Int64,
   LargeBinary,
   List,
   Null,
   RecordBatch,
   RecordBatchFileReader,
   RecordBatchFileWriter,
-  RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
   Utf8,
   Vector,
+  makeVector as arrowMakeVector,
   makeBuilder,
   makeData,
-  type makeTable,
+  makeTable,
   vectorFromArray,
 } from "apache-arrow";
 import { Buffers } from "apache-arrow/data";
@@ -236,8 +242,6 @@ export class MakeArrowTableOptions {
  * This function converts an array of Record<String, any> (row-major JS objects)
  * to an Arrow Table (a columnar structure)
  *
- * Note that it currently does not support nulls.
- *
  * If a schema is provided then it will be used to determine the resulting array
  * types. Fields will also be reordered to fit the order defined by the schema.
  *

@@ -245,6 +249,9 @@ export class MakeArrowTableOptions {
  * will be controlled by the order of properties in the first record. If a type
  * is inferred it will always be nullable.
  *
+ * If not all fields are found in the data, then a subset of the schema will be
+ * returned.
+ *
  * If the input is empty then a schema must be provided to create an empty table.
  *
  * When a schema is not specified then data types will be inferred. The inference

@@ -252,6 +259,7 @@ export class MakeArrowTableOptions {
  *
  * - boolean => Bool
  * - number => Float64
+ * - bigint => Int64
  * - String => Utf8
  * - Buffer => Binary
  * - Record<String, any> => Struct
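A quick sketch of those inference rules in action (the inferred types in the comments follow the list above; treat this as illustrative, not exhaustive):

const table = makeArrowTable([
  { id: 1n, score: 0.5, name: "a", flags: [true], vector: [1, 2, 3] },
]);
// id     => Int64 (bigint)
// score  => Float64 (number, even when integral)
// name   => Utf8 (string)
// flags  => List<Bool> (array of booleans)
// vector => FixedSizeList<Float32> (float arrays in a column named "vector")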
@@ -322,126 +330,316 @@ export function makeArrowTable(
   options?: Partial<MakeArrowTableOptions>,
   metadata?: Map<string, string>,
 ): ArrowTable {
+  const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
+  let schema: Schema | undefined = undefined;
+  if (opt.schema !== undefined && opt.schema !== null) {
+    schema = sanitizeSchema(opt.schema);
+    schema = validateSchemaEmbeddings(
+      schema as Schema,
+      data,
+      options?.embeddingFunction,
+    );
+  }
+
+  let schemaMetadata = schema?.metadata || new Map<string, string>();
+  if (metadata !== undefined) {
+    schemaMetadata = new Map([...schemaMetadata, ...metadata]);
+  }
+
   if (
     data.length === 0 &&
     (options?.schema === undefined || options?.schema === null)
   ) {
     throw new Error("At least one record or a schema needs to be provided");
-  }
-
-  const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
-  if (opt.schema !== undefined && opt.schema !== null) {
-    opt.schema = sanitizeSchema(opt.schema);
-    opt.schema = validateSchemaEmbeddings(
-      opt.schema as Schema,
-      data,
-      options?.embeddingFunction,
-    );
-  }
-  const columns: Record<string, Vector> = {};
-  // TODO: sample dataset to find missing columns
-  // Prefer the field ordering of the schema, if present
-  const columnNames =
-    opt.schema != null ? (opt.schema.names as string[]) : Object.keys(data[0]);
-  for (const colName of columnNames) {
-    if (
-      data.length !== 0 &&
-      !Object.prototype.hasOwnProperty.call(data[0], colName)
-    ) {
-      // The field is present in the schema, but not in the data, skip it
-      continue;
-    }
-    // Extract a single column from the records (transpose from row-major to col-major)
-    let values = data.map((datum) => datum[colName]);
-
-    // By default (type === undefined) arrow will infer the type from the JS type
-    let type;
-    if (opt.schema !== undefined) {
-      // If there is a schema provided, then use that for the type instead
-      type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type;
-      if (DataType.isInt(type) && type.bitWidth === 64) {
-        // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
-        values = values.map((v) => {
-          if (v === null) {
-            return v;
-          }
-          if (typeof v === "bigint") {
-            return v;
-          }
-          if (typeof v === "number") {
-            return BigInt(v);
-          }
-          throw new Error(
-            `Expected BigInt or number for column ${colName}, got ${typeof v}`,
-          );
-        });
-      }
+  } else if (data.length === 0) {
+    if (schema === undefined) {
+      throw new Error("A schema must be provided if data is empty");
     } else {
-      // Otherwise, check to see if this column is one of the vector columns
-      // defined by opt.vectorColumns and, if so, use the fixed size list type
-      const vectorColumnOptions = opt.vectorColumns[colName];
-      if (vectorColumnOptions !== undefined) {
-        const firstNonNullValue = values.find((v) => v !== null);
-        if (Array.isArray(firstNonNullValue)) {
-          type = newVectorType(
-            firstNonNullValue.length,
-            vectorColumnOptions.type,
-          );
+      schema = new Schema(schema.fields, schemaMetadata);
+      return new ArrowTable(schema);
+    }
+  }
+
+  let inferredSchema = inferSchema(data, schema, opt);
+  inferredSchema = new Schema(inferredSchema.fields, schemaMetadata);
+
+  const finalColumns: Record<string, Vector> = {};
+  for (const field of inferredSchema.fields) {
+    finalColumns[field.name] = transposeData(data, field);
+  }
+
+  return new ArrowTable(inferredSchema, finalColumns);
+}
+
+function inferSchema(
+  data: Array<Record<string, unknown>>,
+  schema: Schema | undefined,
+  opts: MakeArrowTableOptions,
+): Schema {
+  // We will collect all fields we see in the data.
+  const pathTree = new PathTree<DataType>();
+
+  for (const [rowI, row] of data.entries()) {
+    for (const [path, value] of rowPathsAndValues(row)) {
+      if (!pathTree.has(path)) {
+        // First time seeing this field.
+        if (schema !== undefined) {
+          const field = getFieldForPath(schema, path);
+          if (field === undefined) {
+            throw new Error(
+              `Found field not in schema: ${path.join(".")} at row ${rowI}`,
+            );
+          } else {
+            pathTree.set(path, field.type);
+          }
         } else {
-          throw new Error(
-            `Column ${colName} is expected to be a vector column but first non-null value is not an array. Could not determine size of vector column`,
-          );
+          const inferredType = inferType(value, path, opts);
+          if (inferredType === undefined) {
+            throw new Error(`Failed to infer data type for field ${path.join(".")} at row ${rowI}. \
+Consider providing an explicit schema.`);
+          }
+          pathTree.set(path, inferredType);
+        }
+      } else if (schema === undefined) {
+        const currentType = pathTree.get(path);
+        const newType = inferType(value, path, opts);
+        if (currentType !== newType) {
+          new Error(`Failed to infer schema for data. Previously inferred type \
+${currentType} but found ${newType} at row ${rowI}. Consider \
+providing an explicit schema.`);
         }
       }
     }
   }

-    try {
-      // Convert an Array of JS values to an arrow vector
-      columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings);
-    } catch (error: unknown) {
-      // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
-      throw Error(`Could not convert column "${colName}" to Arrow: ${error}`);
-    }
-  }
+
-  if (opt.schema != null) {
-    // `new ArrowTable(columns)` infers a schema which may sometimes have
-    // incorrect nullability (it assumes nullable=true always)
-    //
-    // `new ArrowTable(schema, columns)` will also fail because it will create a
-    // batch with an inferred schema and then complain that the batch schema
-    // does not match the provided schema.
-    //
-    // To work around this we first create a table with the wrong schema and
-    // then patch the schema of the batches so we can use
-    // `new ArrowTable(schema, batches)` which does not do any schema inference
-    const firstTable = new ArrowTable(columns);
-    const batchesFixed = firstTable.batches.map(
-      (batch) => new RecordBatch(opt.schema as Schema, batch.data),
-    );
-    let schema: Schema;
-    if (metadata !== undefined) {
-      let schemaMetadata = opt.schema.metadata;
-      if (schemaMetadata.size === 0) {
-        schemaMetadata = metadata;
-      } else {
-        for (const [key, entry] of schemaMetadata.entries()) {
-          schemaMetadata.set(key, entry);
+  if (schema === undefined) {
+    function fieldsFromPathTree(pathTree: PathTree<DataType>): Field[] {
+      const fields = [];
+      for (const [name, value] of pathTree.map.entries()) {
+        if (value instanceof PathTree) {
+          const children = fieldsFromPathTree(value);
+          fields.push(new Field(name, new Struct(children), true));
+        } else {
+          fields.push(new Field(name, value, true));
         }
       }
+      return fields;
+    }
+    const fields = fieldsFromPathTree(pathTree);
+    return new Schema(fields);
+  } else {
+    function takeMatchingFields(
+      fields: Field[],
+      pathTree: PathTree<DataType>,
+    ): Field[] {
+      const outFields = [];
+      for (const field of fields) {
+        if (pathTree.map.has(field.name)) {
+          const value = pathTree.get([field.name]);
+          if (value instanceof PathTree) {
+            const struct = field.type as Struct;
+            const children = takeMatchingFields(struct.children, value);
+            outFields.push(
+              new Field(field.name, new Struct(children), field.nullable),
+            );
+          } else {
+            outFields.push(
+              new Field(field.name, value as DataType, field.nullable),
+            );
+          }
+        }
+      }
+      return outFields;
+    }
+    const fields = takeMatchingFields(schema.fields, pathTree);
+    return new Schema(fields);
+  }
+}
+
-      schema = new Schema(opt.schema.fields as Field[], schemaMetadata);
+function* rowPathsAndValues(
+  row: Record<string, unknown>,
+  basePath: string[] = [],
+): Generator<[string[], unknown]> {
+  for (const [key, value] of Object.entries(row)) {
+    if (isObject(value)) {
+      yield* rowPathsAndValues(value, [...basePath, key]);
     } else {
-      schema = opt.schema as Schema;
+      yield [[...basePath, key], value];
     }
-    return new ArrowTable(schema, batchesFixed);
   }
-  const tbl = new ArrowTable(columns);
-  if (metadata !== undefined) {
-    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
-    (<any>tbl.schema).metadata = metadata;
+}
+
+function isObject(value: unknown): value is Record<string, unknown> {
+  return (
+    typeof value === "object" &&
+    value !== null &&
+    !Array.isArray(value) &&
+    !(value instanceof RegExp) &&
+    !(value instanceof Date) &&
+    !(value instanceof Set) &&
+    !(value instanceof Map) &&
+    !(value instanceof Buffer)
+  );
+}
+
+function getFieldForPath(schema: Schema, path: string[]): Field | undefined {
+  let current: Field | Schema = schema;
+  for (const key of path) {
+    if (current instanceof Schema) {
+      const field: Field | undefined = current.fields.find(
+        (f) => f.name === key,
+      );
+      if (field === undefined) {
+        return undefined;
+      }
+      current = field;
+    } else if (current instanceof Field && DataType.isStruct(current.type)) {
+      const struct: Struct = current.type;
+      const field = struct.children.find((f) => f.name === key);
+      if (field === undefined) {
+        return undefined;
+      }
+      current = field;
+    } else {
+      return undefined;
+    }
+  }
+  if (current instanceof Field) {
+    return current;
+  } else {
+    return undefined;
+  }
+}
+
+/**
+ * Try to infer which Arrow type to use for a given value.
+ *
+ * May return undefined if the type cannot be inferred.
+ */
+function inferType(
+  value: unknown,
+  path: string[],
+  opts: MakeArrowTableOptions,
+): DataType | undefined {
+  if (typeof value === "bigint") {
+    return new Int64();
+  } else if (typeof value === "number") {
+    // Even if it's an integer, it's safer to assume Float64. Users can
+    // always provide an explicit schema or use BigInt if they mean integer.
+    return new Float64();
+  } else if (typeof value === "string") {
+    if (opts.dictionaryEncodeStrings) {
+      return new Dictionary(new Utf8(), new Int32());
+    } else {
+      return new Utf8();
+    }
+  } else if (typeof value === "boolean") {
+    return new Bool();
+  } else if (value instanceof Buffer) {
+    return new Binary();
+  } else if (Array.isArray(value)) {
+    if (value.length === 0) {
+      return undefined; // Without any values we can't infer the type
+    }
+    if (path.length === 1 && Object.hasOwn(opts.vectorColumns, path[0])) {
+      const floatType = sanitizeType(opts.vectorColumns[path[0]].type);
+      return new FixedSizeList(
+        value.length,
+        new Field("item", floatType, true),
+      );
+    }
+    const valueType = inferType(value[0], path, opts);
+    if (valueType === undefined) {
+      return undefined;
+    }
+    // Try to automatically detect embedding columns.
+    if (valueType instanceof Float && path[path.length - 1] === "vector") {
+      // We default to Float32 for vectors.
+      const child = new Field("item", new Float32(), true);
+      return new FixedSizeList(value.length, child);
+    } else {
+      const child = new Field("item", valueType, true);
+      return new List(child);
+    }
+  } else {
+    // TODO: timestamp
+    return undefined;
+  }
+}
+
+class PathTree<V> {
+  map: Map<string, V | PathTree<V>>;
+
+  constructor(entries?: [string[], V][]) {
+    this.map = new Map();
+    if (entries !== undefined) {
+      for (const [path, value] of entries) {
+        this.set(path, value);
+      }
+    }
+  }
+  has(path: string[]): boolean {
+    let ref: PathTree<V> = this;
+    for (const part of path) {
+      if (!(ref instanceof PathTree) || !ref.map.has(part)) {
+        return false;
+      }
+      ref = ref.map.get(part) as PathTree<V>;
+    }
+    return true;
+  }
+  get(path: string[]): V | undefined {
+    let ref: PathTree<V> = this;
+    for (const part of path) {
+      if (!(ref instanceof PathTree) || !ref.map.has(part)) {
+        return undefined;
+      }
+      ref = ref.map.get(part) as PathTree<V>;
+    }
+    return ref as V;
+  }
+  set(path: string[], value: V): void {
+    let ref: PathTree<V> = this;
+    for (const part of path.slice(0, path.length - 1)) {
+      if (!ref.map.has(part)) {
+        ref.map.set(part, new PathTree<V>());
+      }
+      ref = ref.map.get(part) as PathTree<V>;
+    }
+    ref.map.set(path[path.length - 1], value);
+  }
+}
+
+function transposeData(
+  data: Record<string, unknown>[],
+  field: Field,
+  path: string[] = [],
+): Vector {
+  if (field.type instanceof Struct) {
+    const childFields = field.type.children;
+    const childVectors = childFields.map((child) => {
+      return transposeData(data, child, [...path, child.name]);
+    });
+    const structData = makeData({
+      type: field.type,
+      children: childVectors as unknown as ArrowData<DataType>[],
+    });
+    return arrowMakeVector(structData);
+  } else {
+    const valuesPath = [...path, field.name];
+    const values = data.map((datum) => {
+      let current: unknown = datum;
+      for (const key of valuesPath) {
+        if (isObject(current) && Object.hasOwn(current, key)) {
+          current = current[key];
+        } else {
+          return null;
+        }
+      }
+      return current;
+    });
+    return makeVector(values, field.type);
   }
-  return tbl;
 }

 /**
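`PathTree` is an internal helper, but its semantics are easy to state: it is a trie keyed by field paths, which is what lets nested struct fields be collected independently of their siblings. An illustrative sketch (not a public API):

const tree = new PathTree<string>();
tree.set(["a"], "Int64");
tree.set(["s", "y"], "Int32");

tree.has(["s", "y"]); // true
tree.get(["s", "y"]); // "Int32"
tree.has(["s", "x"]); // false -- only fields actually seen are recorded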
@@ -491,6 +689,31 @@ function makeVector(
 ): Vector<any> {
   if (type !== undefined) {
     // No need for inference, let Arrow create it
+    if (type instanceof Int) {
+      if (DataType.isInt(type) && type.bitWidth === 64) {
+        // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
+        values = values.map((v) => {
+          if (v === null) {
+            return v;
+          } else if (typeof v === "bigint") {
+            return v;
+          } else if (typeof v === "number") {
+            return BigInt(v);
+          } else {
+            return v;
+          }
+        });
+      } else {
+        // Similarly, bigint isn't supported for 16 or 32-bit ints.
+        values = values.map((v) => {
+          if (typeof v == "bigint") {
+            return Number(v);
+          } else {
+            return v;
+          }
+        });
+      }
+    }
     return vectorFromArray(values, type);
   }
   if (values.length === 0) {

@@ -902,7 +1125,7 @@ function validateSchemaEmbeddings(
   schema: Schema,
   data: Array<Record<string, unknown>>,
   embeddings: EmbeddingFunctionConfig | undefined,
-) {
+): Schema {
   const fields = [];
   const missingEmbeddingFields = [];
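In effect the coercion lets numbers and bigints mix freely in integer columns: JS numbers are promoted to BigInt for 64-bit integer types (sidestepping apache/arrow#40051), and bigints are demoted to Number for narrower ones. Through the public surface, a sketch grounded in the tests above:

const table = makeArrowTable([{ d: 9n }, { d: 10 }, { d: null }], {
  // Schema and Field here come from apache-arrow.
  schema: new Schema([new Field("d", new Int64(), true)]),
});
// table.getChild("d")?.toJSON() => [9n, 10n, null]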
@@ -52,6 +52,8 @@ export interface CreateTableOptions {
|
|||||||
*
|
*
|
||||||
* The default is `stable`.
|
* The default is `stable`.
|
||||||
* Set to "legacy" to use the old format.
|
* Set to "legacy" to use the old format.
|
||||||
|
*
|
||||||
|
* @deprecated Pass `new_table_data_storage_version` to storageOptions instead.
|
||||||
*/
|
*/
|
||||||
dataStorageVersion?: string;
|
dataStorageVersion?: string;
|
||||||
|
|
||||||
@@ -61,17 +63,11 @@ export interface CreateTableOptions {
|
|||||||
* turning this on will make the dataset unreadable for older versions
|
* turning this on will make the dataset unreadable for older versions
|
||||||
* of LanceDB (prior to 0.10.0). To migrate an existing dataset, instead
|
* of LanceDB (prior to 0.10.0). To migrate an existing dataset, instead
|
||||||
* use the {@link LocalTable#migrateManifestPathsV2} method.
|
* use the {@link LocalTable#migrateManifestPathsV2} method.
|
||||||
|
*
|
||||||
|
* @deprecated Pass `new_table_enable_v2_manifest_paths` to storageOptions instead.
|
||||||
*/
|
*/
|
||||||
enableV2ManifestPaths?: boolean;
|
enableV2ManifestPaths?: boolean;
|
||||||
|
|
||||||
/**
|
|
||||||
* If true then data files will be written with the legacy format
|
|
||||||
*
|
|
||||||
* The default is false.
|
|
||||||
*
|
|
||||||
* Deprecated. Use data storage version instead.
|
|
||||||
*/
|
|
||||||
useLegacyFormat?: boolean;
|
|
||||||
schema?: SchemaLike;
|
schema?: SchemaLike;
|
||||||
embeddingFunction?: EmbeddingFunctionConfig;
|
embeddingFunction?: EmbeddingFunctionConfig;
|
||||||
}
|
}
|
||||||
@@ -215,6 +211,11 @@ export abstract class Connection {
|
|||||||
* @param {string} name The name of the table to drop.
|
* @param {string} name The name of the table to drop.
|
||||||
*/
|
*/
|
||||||
abstract dropTable(name: string): Promise<void>;
|
abstract dropTable(name: string): Promise<void>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
+   * Drop all tables in the database.
+   */
+  abstract dropAllTables(): Promise<void>;
 }

 /** @hideconstructor */

@@ -256,6 +257,28 @@ export class LocalConnection extends Connection {
     return new LocalTable(innerTable);
   }

+  private getStorageOptions(
+    options?: Partial<CreateTableOptions>,
+  ): Record<string, string> | undefined {
+    if (options?.dataStorageVersion !== undefined) {
+      if (options.storageOptions === undefined) {
+        options.storageOptions = {};
+      }
+      options.storageOptions["newTableDataStorageVersion"] =
+        options.dataStorageVersion;
+    }
+
+    if (options?.enableV2ManifestPaths !== undefined) {
+      if (options.storageOptions === undefined) {
+        options.storageOptions = {};
+      }
+      options.storageOptions["newTableEnableV2ManifestPaths"] =
+        options.enableV2ManifestPaths ? "true" : "false";
+    }
+
+    return cleanseStorageOptions(options?.storageOptions);
+  }
+
   async createTable(
     nameOrOptions:
       | string
@@ -272,20 +295,14 @@ export class LocalConnection extends Connection {
       throw new Error("data is required");
     }
     const { buf, mode } = await parseTableData(data, options);
-    let dataStorageVersion = "stable";
-    if (options?.dataStorageVersion !== undefined) {
-      dataStorageVersion = options.dataStorageVersion;
-    } else if (options?.useLegacyFormat !== undefined) {
-      dataStorageVersion = options.useLegacyFormat ? "legacy" : "stable";
-    }
-
+    const storageOptions = this.getStorageOptions(options);
+
     const innerTable = await this.inner.createTable(
       nameOrOptions,
       buf,
       mode,
-      cleanseStorageOptions(options?.storageOptions),
-      dataStorageVersion,
-      options?.enableV2ManifestPaths,
+      storageOptions,
     );

     return new LocalTable(innerTable);
@@ -309,22 +326,14 @@ export class LocalConnection extends Connection {
       metadata = registry.getTableMetadata([embeddingFunction]);
     }

-    let dataStorageVersion = "stable";
-    if (options?.dataStorageVersion !== undefined) {
-      dataStorageVersion = options.dataStorageVersion;
-    } else if (options?.useLegacyFormat !== undefined) {
-      dataStorageVersion = options.useLegacyFormat ? "legacy" : "stable";
-    }
-
+    const storageOptions = this.getStorageOptions(options);
     const table = makeEmptyTable(schema, metadata);
     const buf = await fromTableToBuffer(table);
     const innerTable = await this.inner.createEmptyTable(
       name,
       buf,
       mode,
-      cleanseStorageOptions(options?.storageOptions),
-      dataStorageVersion,
-      options?.enableV2ManifestPaths,
+      storageOptions,
     );
     return new LocalTable(innerTable);
   }
@@ -332,6 +341,10 @@ export class LocalConnection extends Connection {
   async dropTable(name: string): Promise<void> {
     return this.inner.dropTable(name);
   }
+
+  async dropAllTables(): Promise<void> {
+    return this.inner.dropAllTables();
+  }
 }

 /**
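The deprecated `dataStorageVersion` / `enableV2ManifestPaths` arguments are now folded into plain storage options (`newTableDataStorageVersion`, `newTableEnableV2ManifestPaths`). The Python side expresses the same intent through `storage_options`; a minimal sketch, assuming a local `./db` URI:

```python
import lancedb
import pyarrow as pa

# Table-format settings now travel as storage options instead of
# dedicated create_table arguments.
db = lancedb.connect(
    "./db",
    storage_options={
        "new_table_data_storage_version": "stable",
        "new_table_enable_v2_manifest_paths": "true",
    },
)
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
tbl = db.create_table("vectors", schema=schema)
```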
@@ -15,6 +15,7 @@ import {
   newVectorType,
 } from "../arrow";
 import { sanitizeType } from "../sanitize";
+import { getRegistry } from "./registry";

 /**
  * Options for a given embedding function
@@ -32,6 +33,22 @@ export interface EmbeddingFunctionConstructor<

 /**
  * An embedding function that automatically creates vector representation for a given column.
+ *
+ * It's important that subclasses pass the **original** options to the super constructor
+ * and then pass those options to `resolveVariables` to resolve any variables before
+ * using them.
+ *
+ * @example
+ * ```ts
+ * class MyEmbeddingFunction extends EmbeddingFunction {
+ *   constructor(optionsRaw: {model: string, timeout: number}) {
+ *     super(optionsRaw);
+ *     const options = this.resolveVariables(optionsRaw);
+ *     this.model = options.model;
+ *     this.timeout = options.timeout;
+ *   }
+ * }
+ * ```
  */
 export abstract class EmbeddingFunction<
   // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
@@ -44,33 +61,74 @@ export abstract class EmbeddingFunction<
    */
   // biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
   readonly TOptions!: M;
-  /**
-   * Convert the embedding function to a JSON object
-   * It is used to serialize the embedding function to the schema
-   * It's important that any object returned by this method contains all the necessary
-   * information to recreate the embedding function
-   *
-   * It should return the same object that was passed to the constructor
-   * If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
-   *
-   * @example
-   * ```ts
-   * class MyEmbeddingFunction extends EmbeddingFunction {
-   *   constructor(options: {model: string, timeout: number}) {
-   *     super();
-   *     this.model = options.model;
-   *     this.timeout = options.timeout;
-   *   }
-   *   toJSON() {
-   *     return {
-   *       model: this.model,
-   *       timeout: this.timeout,
-   *     };
-   *   }
-   * ```
-   */
-  abstract toJSON(): Partial<M>;

+  #config: Partial<M>;
+
+  /**
+   * Get the original arguments to the constructor, to serialize them so they
+   * can be used to recreate the embedding function later.
+   */
+  // biome-ignore lint/suspicious/noExplicitAny :
+  toJSON(): Record<string, any> {
+    return JSON.parse(JSON.stringify(this.#config));
+  }
+
+  constructor() {
+    this.#config = {};
+  }
+
+  /**
+   * Provide a list of keys in the function options that should be treated as
+   * sensitive. If users pass raw values for these keys, they will be rejected.
+   */
+  protected getSensitiveKeys(): string[] {
+    return [];
+  }
+
+  /**
+   * Apply variables to the config.
+   */
+  protected resolveVariables(config: Partial<M>): Partial<M> {
+    this.#config = config;
+    const registry = getRegistry();
+    const newConfig = { ...config };
+    for (const [key_, value] of Object.entries(newConfig)) {
+      if (
+        this.getSensitiveKeys().includes(key_) &&
+        !value.startsWith("$var:")
+      ) {
+        throw new Error(
+          `The key "${key_}" is sensitive and cannot be set directly. Please use the $var: syntax to set it.`,
+        );
+      }
+      // Makes TS happy (https://stackoverflow.com/a/78391854)
+      const key = key_ as keyof M;
+      if (typeof value === "string" && value.startsWith("$var:")) {
+        const [name, defaultValue] = value.slice(5).split(":", 2);
+        const variableValue = registry.getVar(name);
+        if (!variableValue) {
+          if (defaultValue) {
+            // biome-ignore lint/suspicious/noExplicitAny:
+            newConfig[key] = defaultValue as any;
+          } else {
+            throw new Error(`Variable "${name}" not found`);
+          }
+        } else {
+          // biome-ignore lint/suspicious/noExplicitAny:
+          newConfig[key] = variableValue as any;
+        }
+      }
+    }
+    return newConfig;
+  }
+
+  /**
+   * Optionally load any resources needed for the embedding function.
+   *
+   * This method is called after the embedding function has been initialized
+   * but before any embeddings are computed. It is useful for loading local models
+   * or other resources that are needed for the embedding function to work.
+   */
   async init?(): Promise<void>;

   /**
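`resolveVariables` walks the supplied config, rejects raw values for sensitive keys, and substitutes `$var:name[:default]` references from the registry. Roughly this logic, as a standalone Python sketch (not the library API):

```python
def resolve_variables(config: dict, variables: dict, sensitive: set) -> dict:
    """Mirror of resolveVariables above: substitute "$var:name[:default]"."""
    resolved = dict(config)
    for key, value in config.items():
        if key in sensitive and isinstance(value, str) and not value.startswith("$var:"):
            raise ValueError(f'The key "{key}" is sensitive; use the $var: syntax')
        if isinstance(value, str) and value.startswith("$var:"):
            # Everything after the first colon following the name is the default.
            name, _, default = value[5:].partition(":")
            if name in variables:
                resolved[key] = variables[name]
            elif default:
                resolved[key] = default
            else:
                raise KeyError(f'Variable "{name}" not found')
    return resolved


assert resolve_variables({"device": "$var:device:cpu"}, {}, set()) == {"device": "cpu"}
```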
@@ -21,11 +21,13 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
   #modelName: OpenAIOptions["model"];

   constructor(
-    options: Partial<OpenAIOptions> = {
+    optionsRaw: Partial<OpenAIOptions> = {
       model: "text-embedding-ada-002",
     },
   ) {
     super();
+    const options = this.resolveVariables(optionsRaw);
+
     const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
     if (!openAIKey) {
       throw new Error("OpenAI API key is required");
@@ -52,10 +54,8 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
     this.#modelName = modelName;
   }

-  toJSON() {
-    return {
-      model: this.#modelName,
-    };
+  protected getSensitiveKeys(): string[] {
+    return ["apiKey"];
   }

   ndims(): number {
@@ -23,6 +23,7 @@ export interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
  */
 export class EmbeddingFunctionRegistry {
   #functions = new Map<string, EmbeddingFunctionConstructor>();
+  #variables = new Map<string, string>();

   /**
    * Get the number of registered functions
@@ -82,10 +83,7 @@ export class EmbeddingFunctionRegistry {
       };
     } else {
       // biome-ignore lint/suspicious/noExplicitAny: <explanation>
-      create = function (options?: any) {
-        const instance = new factory(options);
-        return instance;
-      };
+      create = (options?: any) => new factory(options);
     }

     return {
@@ -164,6 +162,37 @@ export class EmbeddingFunctionRegistry {

     return metadata;
   }
+
+  /**
+   * Set a variable. These can be accessed in the embedding function
+   * configuration using the syntax `$var:variable_name`. If they are not
+   * set, an error will be thrown letting you know which key is unset. If you
+   * want to supply a default value, you can add an additional part in the
+   * configuration like so: `$var:variable_name:default_value`. Default values
+   * can be used for runtime configurations that are not sensitive, such as
+   * whether to use a GPU for inference.
+   *
+   * The name must not contain colons. The default value can contain colons.
+   *
+   * @param name
+   * @param value
+   */
+  setVar(name: string, value: string): void {
+    if (name.includes(":")) {
+      throw new Error("Variable names cannot contain colons");
+    }
+    this.#variables.set(name, value);
+  }
+
+  /**
+   * Get a variable.
+   * @param name
+   * @returns
+   * @see {@link setVar}
+   */
+  getVar(name: string): string | undefined {
+    return this.#variables.get(name);
+  }
 }

 const _REGISTRY = new EmbeddingFunctionRegistry();
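`setVar`/`getVar` back the `$var:` syntax above; the Python registry gains the same pair (`set_var`/`get_var`) later in this diff. A usage sketch, assuming `EmbeddingFunctionRegistry.get_instance()` as used elsewhere in this changeset:

```python
from lancedb.embeddings.registry import EmbeddingFunctionRegistry

registry = EmbeddingFunctionRegistry.get_instance()
registry.set_var("device", "cuda")          # names may not contain colons
assert registry.get_var("device") == "cuda"

# In a config, "$var:device:cpu" resolves to the variable when it is set
# and falls back to "cpu" otherwise; the default portion may itself
# contain colons.
```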
@@ -44,11 +44,12 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
   #ndims?: number;

   constructor(
-    options: Partial<XenovaTransformerOptions> = {
+    optionsRaw: Partial<XenovaTransformerOptions> = {
       model: "Xenova/all-MiniLM-L6-v2",
     },
   ) {
     super();
+    const options = this.resolveVariables(optionsRaw);
+
     const modelName = options?.model ?? "Xenova/all-MiniLM-L6-v2";
     this.#tokenizerOptions = {
@@ -59,22 +60,6 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
     this.#ndims = options.ndims;
     this.#modelName = modelName;
   }
-  toJSON() {
-    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
-    const obj: Record<string, any> = {
-      model: this.#modelName,
-    };
-    if (this.#ndims) {
-      obj["ndims"] = this.#ndims;
-    }
-    if (this.#tokenizerOptions) {
-      obj["tokenizerOptions"] = this.#tokenizerOptions;
-    }
-    if (this.#tokenizer) {
-      obj["tokenizer"] = this.#tokenizer.name;
-    }
-    return obj;
-  }

   async init() {
     let transformers;
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": [
     "win32"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
nodejs/package-lock.json (generated)
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.15.1-beta.3",
+      "version": "0.16.1-beta.3",
       "cpu": [
         "x64",
         "arm64"

@@ -11,7 +11,7 @@
     "ann"
   ],
   "private": false,
-  "version": "0.15.1-beta.3",
+  "version": "0.16.1-beta.3",
   "main": "dist/index.js",
   "exports": {
     ".": "./dist/index.js",
@@ -2,17 +2,15 @@
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

 use std::collections::HashMap;
-use std::str::FromStr;

+use lancedb::database::CreateTableMode;
 use napi::bindgen_prelude::*;
 use napi_derive::*;

-use crate::error::{convert_error, NapiErrorExt};
+use crate::error::NapiErrorExt;
 use crate::table::Table;
 use crate::ConnectionOptions;
-use lancedb::connection::{
-    ConnectBuilder, Connection as LanceDBConnection, CreateTableMode, LanceFileVersion,
-};
+use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
 use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema};

 #[napi]
@@ -124,8 +122,6 @@ impl Connection {
         buf: Buffer,
         mode: String,
         storage_options: Option<HashMap<String, String>>,
-        data_storage_options: Option<String>,
-        enable_v2_manifest_paths: Option<bool>,
     ) -> napi::Result<Table> {
         let batches = ipc_file_to_batches(buf.to_vec())
             .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
@@ -137,14 +133,6 @@ impl Connection {
                 builder = builder.storage_option(key, value);
             }
         }
-        if let Some(data_storage_option) = data_storage_options.as_ref() {
-            builder = builder.data_storage_version(
-                LanceFileVersion::from_str(data_storage_option).map_err(|e| convert_error(&e))?,
-            );
-        }
-        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
-            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
-        }
         let tbl = builder.execute().await.default_error()?;
         Ok(Table::new(tbl))
     }
@@ -156,8 +144,6 @@ impl Connection {
         schema_buf: Buffer,
         mode: String,
         storage_options: Option<HashMap<String, String>>,
-        data_storage_options: Option<String>,
-        enable_v2_manifest_paths: Option<bool>,
     ) -> napi::Result<Table> {
         let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
             napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
@@ -172,14 +158,6 @@ impl Connection {
                 builder = builder.storage_option(key, value);
             }
         }
-        if let Some(data_storage_option) = data_storage_options.as_ref() {
-            builder = builder.data_storage_version(
-                LanceFileVersion::from_str(data_storage_option).map_err(|e| convert_error(&e))?,
-            );
-        }
-        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
-            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
-        }
         let tbl = builder.execute().await.default_error()?;
         Ok(Table::new(tbl))
     }
@@ -209,4 +187,9 @@ impl Connection {
     pub async fn drop_table(&self, name: String) -> napi::Result<()> {
         self.get_inner()?.drop_table(&name).await.default_error()
     }
+
+    #[napi(catch_unwind)]
+    pub async fn drop_all_tables(&self) -> napi::Result<()> {
+        self.get_inner()?.drop_all_tables().await.default_error()
+    }
 }
@@ -1,6 +1,8 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

+use std::collections::HashMap;
+
 use napi_derive::*;

 /// Timeout configuration for remote HTTP client.
@@ -67,6 +69,7 @@ pub struct ClientConfig {
     pub user_agent: Option<String>,
     pub retry_config: Option<RetryConfig>,
     pub timeout_config: Option<TimeoutConfig>,
+    pub extra_headers: Option<HashMap<String, String>>,
 }

 impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
@@ -104,6 +107,7 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
                 .unwrap_or(concat!("LanceDB-Node-Client/", env!("CARGO_PKG_VERSION")).to_string()),
             retry_config: config.retry_config.map(Into::into).unwrap_or_default(),
             timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(),
+            extra_headers: config.extra_headers.unwrap_or_default(),
         }
     }
 }
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.18.1-beta.4"
+current_version = "0.20.0"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.

python/.gitignore (vendored, new file)
@@ -0,0 +1,2 @@
+# Test data created by some example tests
+data/

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.18.1-beta.4"
+version = "0.20.0"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -29,4 +29,4 @@ doctest: ## Run documentation tests.

 .PHONY: test
 test: ## Run tests.
-	pytest python/tests -vv --durations=10 -m "not slow"
+	pytest python/tests -vv --durations=10 -m "not slow and not s3_test"
@@ -4,7 +4,7 @@ name = "lancedb"
 dynamic = ["version"]
 dependencies = [
     "deprecation",
-    "pylance==0.23.0b5",
+    "pylance~=0.23.2",
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "packaging",
@@ -55,7 +55,12 @@ tests = [
     "tantivy",
     "pyarrow-stubs",
 ]
-dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"']
+dev = [
+    "ruff",
+    "pre-commit",
+    "pyright",
+    'typing-extensions>=4.0.0; python_version < "3.11"',
+]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -15,8 +15,6 @@ class Connection(object):
         mode: str,
         data: pa.RecordBatchReader,
         storage_options: Optional[Dict[str, str]] = None,
-        data_storage_version: Optional[str] = None,
-        enable_v2_manifest_paths: Optional[bool] = None,
     ) -> Table: ...
     async def create_empty_table(
         self,
@@ -24,8 +22,6 @@ class Connection(object):
         mode: str,
         schema: pa.Schema,
         storage_options: Optional[Dict[str, str]] = None,
-        data_storage_version: Optional[str] = None,
-        enable_v2_manifest_paths: Optional[bool] = None,
    ) -> Table: ...
    async def rename_table(self, old_name: str, new_name: str) -> None: ...
    async def drop_table(self, name: str) -> None: ...
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import pyarrow as pa

@@ -66,3 +66,17 @@ class AsyncRecordBatchReader:
             batches = table.to_batches(max_chunksize=max_batch_length)
             for batch in batches:
                 yield batch
+
+
+def peek_reader(
+    reader: pa.RecordBatchReader,
+) -> Tuple[pa.RecordBatch, pa.RecordBatchReader]:
+    if not isinstance(reader, pa.RecordBatchReader):
+        raise TypeError("reader must be a RecordBatchReader")
+    batch = reader.read_next_batch()
+
+    def all_batches():
+        yield batch
+        yield from reader
+
+    return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
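`peek_reader` reads the first batch (for example, to inspect the schema) and hands back a new reader that replays it, so no data is lost. For instance:

```python
import pyarrow as pa
from lancedb.arrow import peek_reader

batches = [
    pa.RecordBatch.from_pydict({"x": [1, 2]}),
    pa.RecordBatch.from_pydict({"x": [3]}),
]
reader = pa.RecordBatchReader.from_batches(batches[0].schema, iter(batches))

first, reader = peek_reader(reader)
assert first.num_rows == 2
assert reader.read_all().num_rows == 3  # the peeked batch is replayed
```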
@@ -14,6 +14,7 @@ from overrides import EnforceOverrides, override  # type: ignore
 from lancedb.common import data_to_reader, sanitize_uri, validate_schema
 from lancedb.background_loop import LOOP

+from . import __version__
 from ._lancedb import connect as lancedb_connect  # type: ignore
 from .table import (
     AsyncTable,
@@ -26,6 +27,8 @@ from .util import (
     validate_table_name,
 )

+import deprecation
+
 if TYPE_CHECKING:
     import pyarrow as pa
     from .pydantic import LanceModel
@@ -119,19 +122,11 @@ class DBConnection(EnforceOverrides):
             See available options at
             <https://lancedb.github.io/lancedb/guides/storage/>
         data_storage_version: optional, str, default "stable"
-            The version of the data storage format to use. Newer versions are more
-            efficient but require newer versions of lance to read. The default is
-            "stable" which will use the legacy v2 version. See the user guide
-            for more details.
-        enable_v2_manifest_paths: bool, optional, default False
-            Use the new V2 manifest paths. These paths provide more efficient
-            opening of datasets with many versions on object stores. WARNING:
-            turning this on will make the dataset unreadable for older versions
-            of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
-            use the
-            [Table.migrate_manifest_paths_v2][lancedb.table.Table.migrate_v2_manifest_paths]
-            method.
+            Deprecated. Set `storage_options` when connecting to the database and set
+            `new_table_data_storage_version` in the options.
+        enable_v2_manifest_paths: optional, bool, default False
+            Deprecated. Set `storage_options` when connecting to the database and set
+            `new_table_enable_v2_manifest_paths` in the options.

         Returns
         -------
         LanceTable
@@ -302,6 +297,12 @@ class DBConnection(EnforceOverrides):
         """
         raise NotImplementedError

+    def drop_all_tables(self):
+        """
+        Drop all tables from the database
+        """
+        raise NotImplementedError
+
     @property
     def uri(self) -> str:
         return self._uri
@@ -452,8 +453,6 @@ class LanceDBConnection(DBConnection):
             fill_value=fill_value,
             embedding_functions=embedding_functions,
             storage_options=storage_options,
-            data_storage_version=data_storage_version,
-            enable_v2_manifest_paths=enable_v2_manifest_paths,
         )
         return tbl

@@ -496,9 +495,19 @@ class LanceDBConnection(DBConnection):
         """
         LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing))

+    @override
+    def drop_all_tables(self):
+        LOOP.run(self._conn.drop_all_tables())
+
+    @deprecation.deprecated(
+        deprecated_in="0.15.1",
+        removed_in="0.17",
+        current_version=__version__,
+        details="Use drop_all_tables() instead",
+    )
     @override
     def drop_database(self):
-        LOOP.run(self._conn.drop_database())
+        LOOP.run(self._conn.drop_all_tables())


 class AsyncConnection(object):
@@ -595,9 +604,6 @@ class AsyncConnection(object):
         storage_options: Optional[Dict[str, str]] = None,
         *,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
-        data_storage_version: Optional[str] = None,
-        use_legacy_format: Optional[bool] = None,
-        enable_v2_manifest_paths: Optional[bool] = None,
     ) -> AsyncTable:
         """Create an [AsyncTable][lancedb.table.AsyncTable] in the database.

@@ -640,23 +646,6 @@ class AsyncConnection(object):
             connection will be inherited by the table, but can be overridden here.
             See available options at
             <https://lancedb.github.io/lancedb/guides/storage/>
-        data_storage_version: optional, str, default "stable"
-            The version of the data storage format to use. Newer versions are more
-            efficient but require newer versions of lance to read. The default is
-            "stable" which will use the legacy v2 version. See the user guide
-            for more details.
-        use_legacy_format: bool, optional, default False. (Deprecated)
-            If True, use the legacy format for the table. If False, use the new format.
-            This method is deprecated, use `data_storage_version` instead.
-        enable_v2_manifest_paths: bool, optional, default False
-            Use the new V2 manifest paths. These paths provide more efficient
-            opening of datasets with many versions on object stores. WARNING:
-            turning this on will make the dataset unreadable for older versions
-            of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
-            use the
-            [AsyncTable.migrate_manifest_paths_v2][lancedb.table.AsyncTable.migrate_manifest_paths_v2]
-            method.

         Returns
         -------
@@ -795,17 +784,12 @@ class AsyncConnection(object):
         if mode == "create" and exist_ok:
             mode = "exist_ok"

-        if not data_storage_version:
-            data_storage_version = "legacy" if use_legacy_format else "stable"
-
         if data is None:
             new_table = await self._inner.create_empty_table(
                 name,
                 mode,
                 schema,
                 storage_options=storage_options,
-                data_storage_version=data_storage_version,
-                enable_v2_manifest_paths=enable_v2_manifest_paths,
             )
         else:
             data = data_to_reader(data, schema)
@@ -814,8 +798,6 @@ class AsyncConnection(object):
                 mode,
                 data,
                 storage_options=storage_options,
-                data_storage_version=data_storage_version,
-                enable_v2_manifest_paths=enable_v2_manifest_paths,
             )

         return AsyncTable(new_table)
@@ -885,9 +867,19 @@ class AsyncConnection(object):
             if f"Table '{name}' was not found" not in str(e):
                 raise e

+    async def drop_all_tables(self):
+        """Drop all tables from the database."""
+        await self._inner.drop_all_tables()
+
+    @deprecation.deprecated(
+        deprecated_in="0.15.1",
+        removed_in="0.17",
+        current_version=__version__,
+        details="Use drop_all_tables() instead",
+    )
     async def drop_database(self):
         """
         Drop database
         This is the same thing as dropping all the tables
         """
-        await self._inner.drop_db()
+        await self._inner.drop_all_tables()
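`drop_database` now simply delegates to the new `drop_all_tables` and is flagged for removal in 0.17, so migration is mechanical:

```python
import lancedb

db = lancedb.connect("./db")
db.drop_all_tables()  # preferred; drop_database() now emits a deprecation warning

# The async connection gains the same method:
#   async_db = await lancedb.connect_async("./db")
#   await async_db.drop_all_tables()
```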
@@ -2,8 +2,10 @@
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

 from abc import ABC, abstractmethod
+import copy
 from typing import List, Union

+from lancedb.util import add_note
 import numpy as np
 import pyarrow as pa
 from pydantic import BaseModel, Field, PrivateAttr
@@ -28,13 +30,67 @@ class EmbeddingFunction(BaseModel, ABC):
         7  # Setting 0 disables retries. Maybe this should not be enabled by default,
     )
     _ndims: int = PrivateAttr()
+    _original_args: dict = PrivateAttr()

     @classmethod
     def create(cls, **kwargs):
         """
         Create an instance of the embedding function
         """
-        return cls(**kwargs)
+        resolved_kwargs = cls.__resolveVariables(kwargs)
+        instance = cls(**resolved_kwargs)
+        instance._original_args = kwargs
+        return instance
+
+    @classmethod
+    def __resolveVariables(cls, args: dict) -> dict:
+        """
+        Resolve variables in the args
+        """
+        from .registry import EmbeddingFunctionRegistry
+
+        new_args = copy.deepcopy(args)
+
+        registry = EmbeddingFunctionRegistry.get_instance()
+        sensitive_keys = cls.sensitive_keys()
+        for k, v in new_args.items():
+            if isinstance(v, str) and not v.startswith("$var:") and k in sensitive_keys:
+                exc = ValueError(
+                    f"Sensitive key '{k}' cannot be set to a hardcoded value"
+                )
+                add_note(exc, "Help: Use $var: to set sensitive keys to variables")
+                raise exc
+
+            if isinstance(v, str) and v.startswith("$var:"):
+                parts = v[5:].split(":", maxsplit=1)
+                if len(parts) == 1:
+                    try:
+                        new_args[k] = registry.get_var(parts[0])
+                    except KeyError:
+                        exc = ValueError(
+                            "Variable '{}' not found in registry".format(parts[0])
+                        )
+                        add_note(
+                            exc,
+                            "Help: Variables are reset in new Python sessions. "
+                            "Use `registry.set_var` to set variables.",
+                        )
+                        raise exc
+                else:
+                    name, default = parts
+                    try:
+                        new_args[k] = registry.get_var(name)
+                    except KeyError:
+                        new_args[k] = default
+        return new_args
+
+    @staticmethod
+    def sensitive_keys() -> List[str]:
+        """
+        Return a list of keys that are sensitive and should not be allowed
+        to be set to hardcoded values in the config. For example, API keys.
+        """
+        return []

     @abstractmethod
     def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
@@ -103,20 +159,14 @@ class EmbeddingFunction(BaseModel, ABC):
         return texts

     def safe_model_dump(self):
-        from ..pydantic import PYDANTIC_VERSION
-
-        if PYDANTIC_VERSION.major < 2:
-            return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
-        return self.model_dump(
-            exclude={
-                field_name
-                for field_name in self.model_fields
-                if field_name.startswith("_")
-            }
-        )
+        if not hasattr(self, "_original_args"):
+            raise ValueError(
+                "EmbeddingFunction was not created with EmbeddingFunction.create()"
+            )
+        return self._original_args

     @abstractmethod
-    def ndims(self):
+    def ndims(self) -> int:
         """
         Return the dimensions of the vector column
         """
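Because `safe_model_dump` now returns the recorded `_original_args`, instances must be built through `create()` rather than the bare constructor. A sketch with a hypothetical subclass:

```python
from lancedb.embeddings.base import EmbeddingFunction

class MyFunc(EmbeddingFunction):  # hypothetical minimal subclass
    model: str = "small"

    def ndims(self) -> int:
        return 2

    def compute_query_embeddings(self, *args, **kwargs):
        return []

    def compute_source_embeddings(self, *args, **kwargs):
        return []

fn = MyFunc.create(model="large")
assert fn.safe_model_dump() == {"model": "large"}  # the original kwargs

try:
    MyFunc(model="large").safe_model_dump()
except ValueError:
    pass  # not created via create(), so no _original_args were recorded
```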
@@ -57,6 +57,10 @@ class JinaEmbeddings(EmbeddingFunction):
         # TODO: fix hardcoding
         return 768

+    @staticmethod
+    def sensitive_keys() -> List[str]:
+        return ["api_key"]
+
     def sanitize_input(
         self, inputs: Union[TEXT, IMAGES]
     ) -> Union[List[Any], np.ndarray]:
@@ -54,6 +54,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
     def ndims(self):
         return self._ndims

+    @staticmethod
+    def sensitive_keys():
+        return ["api_key"]
+
     @staticmethod
     def model_names():
         return [
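With `api_key` marked sensitive, `create()` now refuses hardcoded keys; they must be routed through a registry variable. A sketch (`sk-...` is a placeholder):

```python
from lancedb.embeddings import get_registry

registry = get_registry()
registry.set_var("openai_key", "sk-...")  # placeholder secret

func = registry.get("openai").create(api_key="$var:openai_key")

# registry.get("openai").create(api_key="sk-...") would instead raise:
#   ValueError: Sensitive key 'api_key' cannot be set to a hardcoded value
```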
@@ -41,6 +41,7 @@ class EmbeddingFunctionRegistry:

     def __init__(self):
         self._functions = {}
+        self._variables = {}

     def register(self, alias: str = None):
         """
@@ -156,6 +157,28 @@ class EmbeddingFunctionRegistry:
         metadata = json.dumps(json_data, indent=2).encode("utf-8")
         return {"embedding_functions": metadata}

+    def set_var(self, name: str, value: str) -> None:
+        """
+        Set a variable. These can be accessed in embedding configuration using
+        the syntax `$var:variable_name`. If they are not set, an error will be
+        thrown letting you know which variable is missing. If you want to supply
+        a default value, you can add an additional part in the configuration
+        like so: `$var:variable_name:default_value`. Default values can be
+        used for runtime configurations that are not sensitive, such as
+        whether to use a GPU for inference.
+
+        The name must not contain a colon. Default values can contain colons.
+        """
+        if ":" in name:
+            raise ValueError("Variable names cannot contain colons")
+        self._variables[name] = value
+
+    def get_var(self, name: str) -> str:
+        """
+        Get a variable.
+        """
+        return self._variables[name]
+

 # Global instance
 __REGISTRY__ = EmbeddingFunctionRegistry()
@@ -40,6 +40,10 @@ class WatsonxEmbeddings(TextEmbeddingFunction):
     url: Optional[str] = None
     params: Optional[Dict] = None

+    @staticmethod
+    def sensitive_keys():
+        return ["api_key"]
+
     @staticmethod
     def model_names():
         return [
@@ -199,18 +199,29 @@ else:
     ]


+def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
+    if inspect.isclass(tp):
+        if issubclass(tp, pydantic.BaseModel):
+            # Struct
+            fields = _pydantic_model_to_fields(tp)
+            return pa.struct(fields)
+        if issubclass(tp, FixedSizeListMixin):
+            return pa.list_(tp.value_arrow_type(), tp.dim())
+    return _py_type_to_arrow_type(tp, field)
+
+
 def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
     """Convert a Pydantic FieldInfo to Arrow DataType"""

     if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
         origin = field.annotation.__origin__
         args = field.annotation.__args__

         if origin is list:
             child = args[0]
             return pa.list_(_py_type_to_arrow_type(child, field))
         elif origin == Union:
             if len(args) == 2 and args[1] is type(None):
-                return _py_type_to_arrow_type(args[0], field)
+                return _pydantic_type_to_arrow_type(args[0], field)
     elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
         args = field.annotation.__args__
         if len(args) == 2:
@@ -218,14 +229,7 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
             if typ is type(None):
                 continue
             return _py_type_to_arrow_type(typ, field)
-    elif inspect.isclass(field.annotation):
-        if issubclass(field.annotation, pydantic.BaseModel):
-            # Struct
-            fields = _pydantic_model_to_fields(field.annotation)
-            return pa.struct(fields)
-        elif issubclass(field.annotation, FixedSizeListMixin):
-            return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
-    return _py_type_to_arrow_type(field.annotation, field)
+    return _pydantic_type_to_arrow_type(field.annotation, field)


 def is_nullable(field: FieldInfo) -> bool:
@@ -255,7 +259,8 @@ def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:


 def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
-    """Convert a Pydantic model to a PyArrow Schema.
+    """Convert a [Pydantic Model][pydantic.BaseModel] to a
+    [PyArrow Schema][pyarrow.Schema].

     Parameters
     ----------
@@ -265,24 +270,25 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
     Returns
     -------
     pyarrow.Schema
+        The Arrow Schema

     Examples
     --------

     >>> from typing import List, Optional
     >>> import pydantic
-    >>> from lancedb.pydantic import pydantic_to_schema
+    >>> from lancedb.pydantic import pydantic_to_schema, Vector
     >>> class FooModel(pydantic.BaseModel):
     ...     id: int
     ...     s: str
-    ...     vec: List[float]
+    ...     vec: Vector(1536)  # fixed_size_list<item: float32>[1536]
     ...     li: List[int]
     ...
     >>> schema = pydantic_to_schema(FooModel)
     >>> assert schema == pa.schema([
     ...     pa.field("id", pa.int64(), False),
     ...     pa.field("s", pa.utf8(), False),
-    ...     pa.field("vec", pa.list_(pa.float64()), False),
+    ...     pa.field("vec", pa.list_(pa.float32(), 1536)),
     ...     pa.field("li", pa.list_(pa.int64()), False),
     ...     ])
     """
@@ -304,7 +310,7 @@ class LanceModel(pydantic.BaseModel):
     ...     vector: Vector(2)
     ...
     >>> db = lancedb.connect("./example")
-    >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
+    >>> table = db.create_table("test", schema=TestModel)
     >>> table.add([
     ...     TestModel(name="test", vector=[1.0, 2.0])
     ... ])
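Routing class annotations through the new `_pydantic_type_to_arrow_type` means `Optional[...]` wrappers around vectors and nested models convert correctly. A quick check, assuming this is the intended behavior of the refactor:

```python
from typing import Optional

import pyarrow as pa
import pydantic
from lancedb.pydantic import Vector, pydantic_to_schema

class Item(pydantic.BaseModel):
    vec: Optional[Vector(2)] = None  # previously hit the generic fallback

schema = pydantic_to_schema(Item)
assert schema.field("vec").type == pa.list_(pa.float32(), 2)
```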
@@ -109,6 +109,7 @@ class ClientConfig:
     user_agent: str = f"LanceDB-Python-Client/{__version__}"
     retry_config: RetryConfig = field(default_factory=RetryConfig)
     timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
+    extra_headers: Optional[dict] = None

     def __post_init__(self):
         if isinstance(self.retry_config, dict):
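The new field mirrors the Rust-side `extra_headers` above; headers given here are attached to requests made by the remote client. A hedged sketch (the header name is illustrative):

```python
from lancedb.remote import ClientConfig

config = ClientConfig(extra_headers={"x-request-source": "batch-job"})
# Typically passed along when opening a remote connection, e.g.
# lancedb.connect(uri, api_key=..., client_config=config)
```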
@@ -526,6 +526,9 @@ class RemoteTable(Table):
     def drop_columns(self, columns: Iterable[str]):
         return LOOP.run(self._table.drop_columns(columns))

+    def drop_index(self, index_name: str):
+        return LOOP.run(self._table.drop_index(index_name))
+
     def uses_v2_manifest_paths(self) -> bool:
         raise NotImplementedError(
             "uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
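`drop_index` on remote tables now proxies to the underlying async table, matching the local API. Usage (the index name is hypothetical):

```python
# Assuming `table` is a RemoteTable with an index named "vector_idx":
table.drop_index("vector_idx")
```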
|
|||||||
@@ -3,7 +3,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@@ -23,11 +25,13 @@ from typing import (
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import lance
|
import lance
|
||||||
|
from lancedb.arrow import peek_reader
|
||||||
from lancedb.background_loop import LOOP
|
from lancedb.background_loop import LOOP
|
||||||
from .dependencies import _check_for_pandas
|
from .dependencies import _check_for_pandas
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
import pyarrow.fs as pa_fs
|
import pyarrow.fs as pa_fs
|
||||||
|
import numpy as np
|
||||||
from lance import LanceDataset
|
from lance import LanceDataset
|
||||||
from lance.dependencies import _check_for_hugging_face
|
from lance.dependencies import _check_for_hugging_face
|
||||||
|
|
||||||
@@ -37,6 +41,8 @@ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
|||||||
from .merge import LanceMergeInsertBuilder
|
from .merge import LanceMergeInsertBuilder
|
||||||
from .pydantic import LanceModel, model_to_dict
|
from .pydantic import LanceModel, model_to_dict
|
||||||
from .query import (
|
from .query import (
|
||||||
|
AsyncFTSQuery,
|
||||||
|
AsyncHybridQuery,
|
||||||
AsyncQuery,
|
AsyncQuery,
|
||||||
AsyncVectorQuery,
|
AsyncVectorQuery,
|
||||||
LanceEmptyQueryBuilder,
|
LanceEmptyQueryBuilder,
|
||||||
@@ -73,17 +79,19 @@ pl = safe_import_polars()
|
|||||||
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
||||||
|
|
||||||
|
|
||||||
def _into_pyarrow_table(data) -> pa.Table:
|
def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
|
||||||
if _check_for_hugging_face(data):
|
if _check_for_hugging_face(data):
|
||||||
# Huggingface datasets
|
# Huggingface datasets
|
||||||
from lance.dependencies import datasets
|
from lance.dependencies import datasets
|
||||||
|
|
||||||
if isinstance(data, datasets.Dataset):
|
if isinstance(data, datasets.Dataset):
|
||||||
schema = data.features.arrow_schema
|
schema = data.features.arrow_schema
|
||||||
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
|
return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
|
||||||
elif isinstance(data, datasets.dataset_dict.DatasetDict):
|
elif isinstance(data, datasets.dataset_dict.DatasetDict):
|
||||||
schema = _schema_from_hf(data, schema)
|
schema = _schema_from_hf(data, schema)
|
||||||
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
|
return pa.RecordBatchReader.from_batches(
|
||||||
|
schema, _to_batches_with_split(data)
|
||||||
|
)
|
||||||
if isinstance(data, LanceModel):
|
if isinstance(data, LanceModel):
|
||||||
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
||||||
|
|
||||||
@@ -95,41 +103,41 @@ def _into_pyarrow_table(data) -> pa.Table:
         if isinstance(data[0], LanceModel):
             schema = data[0].__class__.to_arrow_schema()
             data = [model_to_dict(d) for d in data]
-            return pa.Table.from_pylist(data, schema=schema)
+            return pa.Table.from_pylist(data, schema=schema).to_reader()
         elif isinstance(data[0], pa.RecordBatch):
-            return pa.Table.from_batches(data)
+            return pa.Table.from_batches(data).to_reader()
         else:
-            return pa.Table.from_pylist(data)
+            return pa.Table.from_pylist(data).to_reader()
     elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
         table = pa.Table.from_pandas(data, preserve_index=False)
         # Do not serialize Pandas metadata
         meta = table.schema.metadata if table.schema.metadata is not None else {}
         meta = {k: v for k, v in meta.items() if k != b"pandas"}
-        return table.replace_schema_metadata(meta)
+        return table.replace_schema_metadata(meta).to_reader()
     elif isinstance(data, pa.Table):
-        return data
+        return data.to_reader()
     elif isinstance(data, pa.RecordBatch):
-        return pa.Table.from_batches([data])
+        return pa.RecordBatchReader.from_batches(data.schema, [data])
     elif isinstance(data, LanceDataset):
-        return data.scanner().to_table()
+        return data.scanner().to_reader()
     elif isinstance(data, pa.dataset.Dataset):
-        return data.to_table()
+        return data.scanner().to_reader()
     elif isinstance(data, pa.dataset.Scanner):
-        return data.to_table()
+        return data.to_reader()
     elif isinstance(data, pa.RecordBatchReader):
-        return data.read_all()
+        return data
     elif (
         type(data).__module__.startswith("polars")
         and data.__class__.__name__ == "DataFrame"
     ):
-        return data.to_arrow()
+        return data.to_arrow().to_reader()
     elif (
         type(data).__module__.startswith("polars")
         and data.__class__.__name__ == "LazyFrame"
     ):
-        return data.collect().to_arrow()
+        return data.collect().to_arrow().to_reader()
     elif isinstance(data, Iterable):
-        return _iterator_to_table(data)
+        return _iterator_to_reader(data)
     else:
         raise TypeError(
             f"Unknown data type {type(data)}. "
@@ -139,30 +147,28 @@ def _into_pyarrow_table(data) -> pa.Table:
         )


-def _iterator_to_table(data: Iterable) -> pa.Table:
-    batches = []
-    schema = None  # Will get schema from first batch
-    for batch in data:
-        batch_table = _into_pyarrow_table(batch)
-        if schema is not None:
-            if batch_table.schema != schema:
+def _iterator_to_reader(data: Iterable) -> pa.RecordBatchReader:
+    # Each batch is treated as it's own reader, mainly so we can
+    # re-use the _into_pyarrow_reader logic.
+    first = _into_pyarrow_reader(next(data))
+    schema = first.schema
+
+    def gen():
+        yield from first
+        for batch in data:
+            table: pa.Table = _into_pyarrow_reader(batch).read_all()
+            if table.schema != schema:
                 try:
-                    batch_table = batch_table.cast(schema)
+                    table = table.cast(schema)
                 except pa.lib.ArrowInvalid:
                     raise ValueError(
                         f"Input iterator yielded a batch with schema that "
                         f"does not match the schema of other batches.\n"
-                        f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
+                        f"Expected:\n{schema}\nGot:\n{batch.schema}"
                     )
-        else:
-            # Use the first schema for the remainder of the batches
-            schema = batch_table.schema
-        batches.append(batch_table)
+            yield from table.to_batches()

-    if batches:
-        return pa.concat_tables(batches)
-    else:
-        raise ValueError("Input iterable is empty")
+    return pa.RecordBatchReader.from_batches(schema, gen())


 def _sanitize_data(
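The `_iterator_to_reader` rewrite above trades eager `pa.concat_tables` buffering for a streaming generator. A minimal standalone sketch of the same pattern (a hypothetical `iterator_to_reader` helper, not the library function itself):

```python
import pyarrow as pa

def iterator_to_reader(batches):
    # Pull the first batch eagerly to learn the schema; like the code above,
    # this raises StopIteration if the input iterable is empty.
    it = iter(batches)
    first = next(it)
    schema = first.schema

    def gen():
        yield first
        for batch in it:
            # Casting keeps later batches consistent with the first schema.
            yield pa.Table.from_batches([batch]).cast(schema).to_batches()[0]

    return pa.RecordBatchReader.from_batches(schema, gen())

reader = iterator_to_reader(
    pa.RecordBatch.from_pydict({"x": [i, i + 1]}) for i in range(3)
)
print(reader.read_all().num_rows)  # 6
```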
@@ -173,7 +179,7 @@ def _sanitize_data(
     fill_value: float = 0.0,
     *,
     allow_subschema: bool = False,
-) -> pa.Table:
+) -> pa.RecordBatchReader:
     """
     Handle input data, applying all standard transformations.

@@ -206,20 +212,20 @@ def _sanitize_data(
     # 1. There might be embedding columns missing that will be added
     # in the add_embeddings step.
     # 2. If `allow_subschemas` is True, there might be columns missing.
-    table = _into_pyarrow_table(data)
+    reader = _into_pyarrow_reader(data)

-    table = _append_vector_columns(table, target_schema, metadata=metadata)
+    reader = _append_vector_columns(reader, target_schema, metadata=metadata)

     # This happens before the cast so we can fix vector columns with
     # incorrect lengths before they are cast to FSL.
-    table = _handle_bad_vectors(
-        table,
+    reader = _handle_bad_vectors(
+        reader,
         on_bad_vectors=on_bad_vectors,
         fill_value=fill_value,
     )

     if target_schema is None:
-        target_schema = _infer_target_schema(table)
+        target_schema, reader = _infer_target_schema(reader)

     if metadata:
         new_metadata = target_schema.metadata or {}
@@ -228,25 +234,25 @@ def _sanitize_data(

     _validate_schema(target_schema)

-    table = _cast_to_target_schema(table, target_schema, allow_subschema)
+    reader = _cast_to_target_schema(reader, target_schema, allow_subschema)

-    return table
+    return reader


 def _cast_to_target_schema(
-    table: pa.Table,
+    reader: pa.RecordBatchReader,
     target_schema: pa.Schema,
     allow_subschema: bool = False,
-) -> pa.Table:
+) -> pa.RecordBatchReader:
     # pa.Table.cast expects field order not to be changed.
     # Lance doesn't care about field order, so we don't need to rearrange fields
     # to match the target schema. We just need to correctly cast the fields.
-    if table.schema == target_schema:
+    if reader.schema == target_schema:
         # Fast path when the schemas are already the same
-        return table
+        return reader

     fields = []
-    for field in table.schema:
+    for field in reader.schema:
         target_field = target_schema.field(field.name)
         if target_field is None:
             raise ValueError(f"Field {field.name} not found in target schema")
@@ -259,12 +265,16 @@ def _cast_to_target_schema(

     if allow_subschema and len(reordered_schema) != len(target_schema):
         fields = _infer_subschema(
-            list(iter(table.schema)), list(iter(reordered_schema))
+            list(iter(reader.schema)), list(iter(reordered_schema))
         )
-        subschema = pa.schema(fields, metadata=target_schema.metadata)
-        return table.cast(subschema)
-    else:
-        return table.cast(reordered_schema)
+        reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
+
+    def gen():
+        for batch in reader:
+            # Table but not RecordBatch has cast.
+            yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]
+
+    return pa.RecordBatchReader.from_batches(reordered_schema, gen())


 def _infer_subschema(
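The inner `gen()` above routes each batch through a one-batch `pa.Table` because, on the pyarrow versions this code targets, `cast` exists on `Table` but not on `RecordBatch`. The trick in isolation (a sketch with made-up data):

```python
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})  # inferred as int64
target = pa.schema([pa.field("x", pa.int32())])

# Wrap, cast, unwrap: a RecordBatch-level cast via Table.
cast_batch = pa.Table.from_batches([batch]).cast(target).to_batches()[0]
assert cast_batch.schema == target
```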
@@ -343,7 +353,10 @@ def sanitize_create_table(
     if metadata:
         schema = schema.with_metadata(metadata)
         # Need to apply metadata to the data as well
-        data = data.replace_schema_metadata(metadata)
+        if isinstance(data, pa.Table):
+            data = data.replace_schema_metadata(metadata)
+        elif isinstance(data, pa.RecordBatchReader):
+            data = pa.RecordBatchReader.from_batches(schema, data)

     return data, schema

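The branch added to `sanitize_create_table` works around `RecordBatchReader` having no `replace_schema_metadata`: re-wrapping the same stream under a metadata-bearing schema has the same effect, since pyarrow's batch/schema check ignores metadata. In isolation (a sketch):

```python
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"x": [1, 2]})
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

# Attach metadata by re-wrapping rather than mutating the reader.
schema = reader.schema.with_metadata({b"source": b"example"})
reader = pa.RecordBatchReader.from_batches(schema, reader)
assert reader.schema.metadata == {b"source": b"example"}
```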
@@ -380,11 +393,11 @@ def _to_batches_with_split(data):


 def _append_vector_columns(
-    data: pa.Table,
+    reader: pa.RecordBatchReader,
     schema: Optional[pa.Schema] = None,
     *,
     metadata: Optional[dict] = None,
-) -> pa.Table:
+) -> pa.RecordBatchReader:
     """
     Use the embedding function to automatically embed the source columns and add the
     vector columns to the table.
@@ -395,28 +408,43 @@ def _append_vector_columns(
     metadata = schema.metadata or metadata or {}
     functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)

+    if not functions:
+        return reader
+
+    fields = list(reader.schema)
     for vector_column, conf in functions.items():
-        func = conf.function
-        no_vector_column = vector_column not in data.column_names
-        if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
-            col_data = func.compute_source_embeddings_with_retry(
-                data[conf.source_column]
-            )
+        if vector_column not in reader.schema.names:
             if schema is not None:
-                dtype = schema.field(vector_column).type
+                field = schema.field(vector_column)
             else:
-                dtype = pa.list_(pa.float32(), len(col_data[0]))
-            if no_vector_column:
-                data = data.append_column(
-                    pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
-                )
-            else:
-                data = data.set_column(
-                    data.column_names.index(vector_column),
-                    pa.field(vector_column, type=dtype),
-                    pa.array(col_data, type=dtype),
-                )
-    return data
+                dtype = pa.list_(pa.float32(), conf.function.ndims())
+                field = pa.field(vector_column, type=dtype, nullable=True)
+            fields.append(field)
+    schema = pa.schema(fields, metadata=reader.schema.metadata)
+
+    def gen():
+        for batch in reader:
+            for vector_column, conf in functions.items():
+                func = conf.function
+                no_vector_column = vector_column not in batch.column_names
+                if no_vector_column or pc.all(pc.is_null(batch[vector_column])).as_py():
+                    col_data = func.compute_source_embeddings_with_retry(
+                        batch[conf.source_column]
+                    )
+                    if no_vector_column:
+                        batch = batch.append_column(
+                            schema.field(vector_column),
+                            pa.array(col_data, type=schema.field(vector_column).type),
+                        )
+                    else:
+                        batch = batch.set_column(
+                            batch.column_names.index(vector_column),
+                            schema.field(vector_column),
+                            pa.array(col_data, type=schema.field(vector_column).type),
+                        )
+            yield batch
+
+    return pa.RecordBatchReader.from_batches(schema, gen())


 def _table_path(base: str, table_name: str) -> str:
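Stripped of the registry plumbing, the streaming shape of the new `_append_vector_columns` looks like this (a sketch; `embed` is a stand-in for the configured embedding function, not a LanceDB API):

```python
import pyarrow as pa

def with_vectors(reader: pa.RecordBatchReader, embed, dim: int) -> pa.RecordBatchReader:
    vec_type = pa.list_(pa.float32(), dim)
    schema = reader.schema.append(pa.field("vector", vec_type, nullable=True))

    def gen():
        for batch in reader:
            # Embed one batch at a time instead of materializing the table.
            vectors = pa.array(embed(batch["text"].to_pylist()), type=vec_type)
            yield batch.append_column(schema.field("vector"), vectors)

    return pa.RecordBatchReader.from_batches(schema, gen())

batch = pa.RecordBatch.from_pydict({"text": ["a", "b"]})
reader = with_vectors(
    pa.RecordBatchReader.from_batches(batch.schema, [batch]),
    embed=lambda texts: [[0.0, 0.0] for _ in texts],  # dummy embedder
    dim=2,
)
print(reader.read_all())
```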
@@ -2085,10 +2113,37 @@ class LanceTable(Table):
             The value to use when filling vectors. Only used if on_bad_vectors="fill".
         embedding_functions: list of EmbeddingFunctionModel, default None
             The embedding functions to use when creating the table.
+        data_storage_version: optional, str, default "stable"
+            Deprecated. Set `storage_options` when connecting to the database and set
+            `new_table_data_storage_version` in the options.
+        enable_v2_manifest_paths: optional, bool, default False
+            Deprecated. Set `storage_options` when connecting to the database and set
+            `new_table_enable_v2_manifest_paths` in the options.
         """
         self = cls.__new__(cls)
         self._conn = db

+        if data_storage_version is not None:
+            warnings.warn(
+                "setting data_storage_version directly on create_table is deprecated. ",
+                "Use database_options instead.",
+                DeprecationWarning,
+            )
+            if storage_options is None:
+                storage_options = {}
+            storage_options["new_table_data_storage_version"] = data_storage_version
+        if enable_v2_manifest_paths is not None:
+            warnings.warn(
+                "setting enable_v2_manifest_paths directly on create_table is ",
+                "deprecated. Use database_options instead.",
+                DeprecationWarning,
+            )
+            if storage_options is None:
+                storage_options = {}
+            storage_options["new_table_enable_v2_manifest_paths"] = (
+                enable_v2_manifest_paths
+            )
+
         self._table = LOOP.run(
             self._conn._conn.create_table(
                 name,
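The deprecation shim above folds the old per-call arguments into `storage_options`. The replacement workflow it points at (a sketch; the option keys match the ones exercised by the tests later in this diff):

```python
import lancedb

# Configure new-table behavior once, on the connection.
db = lancedb.connect(
    "data/sample-lancedb",
    storage_options={
        "new_table_data_storage_version": "legacy",
        "new_table_enable_v2_manifest_paths": "true",
    },
)
tbl = db.create_table("my_table", data=[{"id": 0}])
```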
@@ -2100,8 +2155,6 @@ class LanceTable(Table):
                 fill_value=fill_value,
                 embedding_functions=embedding_functions,
                 storage_options=storage_options,
-                data_storage_version=data_storage_version,
-                enable_v2_manifest_paths=enable_v2_manifest_paths,
             )
         )
         return self
@@ -2332,11 +2385,13 @@ class LanceTable(Table):


 def _handle_bad_vectors(
-    table: pa.Table,
+    reader: pa.RecordBatchReader,
     on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
     fill_value: float = 0.0,
-) -> pa.Table:
-    for field in table.schema:
+) -> pa.RecordBatchReader:
+    vector_columns = []
+
+    for field in reader.schema:
         # They can provide a 'vector' column that isn't yet a FSL
         named_vector_col = (
             (
@@ -2356,22 +2411,28 @@ def _handle_bad_vectors(
         )

         if named_vector_col or likely_vector_col:
-            table = _handle_bad_vector_column(
-                table,
-                vector_column_name=field.name,
-                on_bad_vectors=on_bad_vectors,
-                fill_value=fill_value,
-            )
-
-    return table
+            vector_columns.append(field.name)
+
+    def gen():
+        for batch in reader:
+            for name in vector_columns:
+                batch = _handle_bad_vector_column(
+                    batch,
+                    vector_column_name=name,
+                    on_bad_vectors=on_bad_vectors,
+                    fill_value=fill_value,
+                )
+            yield batch
+
+    return pa.RecordBatchReader.from_batches(reader.schema, gen())


 def _handle_bad_vector_column(
-    data: pa.Table,
+    data: pa.RecordBatch,
     vector_column_name: str,
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-) -> pa.Table:
+) -> pa.RecordBatch:
     """
     Ensure that the vector column exists and has type fixed_size_list(float)

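For callers, the `on_bad_vectors` knob threaded through here behaves as before; only the plumbing is now per-batch. A small usage sketch (assuming the documented "fill" semantics, where malformed vectors are replaced using `fill_value`):

```python
import lancedb

db = lancedb.connect("memory://")
tbl = db.create_table(
    "vecs",
    data=[{"vector": [1.0, 2.0]}, {"vector": [1.0]}],  # second vector too short
    on_bad_vectors="fill",
    fill_value=0.0,
)
assert tbl.count_rows() == 2
```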
@@ -2459,8 +2520,11 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
     return pc.is_in(indices, has_nan_indices)


-def _infer_target_schema(table: pa.Table) -> pa.Schema:
-    schema = table.schema
+def _infer_target_schema(
+    reader: pa.RecordBatchReader,
+) -> Tuple[pa.Schema, pa.RecordBatchReader]:
+    schema = reader.schema
+    peeked = None

     for i, field in enumerate(schema):
         if (
@@ -2468,8 +2532,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
             and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
             and pa.types.is_floating(field.type.value_type)
         ):
+            if peeked is None:
+                peeked, reader = peek_reader(reader)
             # Use the most common length of the list as the dimensions
-            dim = _modal_list_size(table.column(i))
+            dim = _modal_list_size(peeked.column(i))

             new_field = pa.field(
                 VECTOR_COLUMN_NAME,
@@ -2483,8 +2549,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
             and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
             and pa.types.is_integer(field.type.value_type)
         ):
+            if peeked is None:
+                peeked, reader = peek_reader(reader)
             # Use the most common length of the list as the dimensions
-            dim = _modal_list_size(table.column(i))
+            dim = _modal_list_size(peeked.column(i))
             new_field = pa.field(
                 VECTOR_COLUMN_NAME,
                 pa.list_(pa.uint8(), dim),
@@ -2493,7 +2561,7 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:

             schema = schema.set(i, new_field)

-    return schema
+    return schema, reader


 def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
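The `peek_reader` helper used above, which inspects the first batch without consuming the stream, plausibly has this shape; a sketch under that assumption, not the library's actual implementation:

```python
import pyarrow as pa
from typing import Tuple

def peek_reader(
    reader: pa.RecordBatchReader,
) -> Tuple[pa.RecordBatch, pa.RecordBatchReader]:
    first = reader.read_next_batch()

    def gen():
        # Stitch the peeked batch back in front of the remaining stream.
        yield first
        yield from reader

    return first, pa.RecordBatchReader.from_batches(reader.schema, gen())
```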
@@ -2615,7 +2683,7 @@ class AsyncTable:
         self.close()

     def is_open(self) -> bool:
-        """Return True if the table is closed."""
+        """Return True if the table is open."""
         return self._inner.is_open()

     def close(self):
@@ -2638,6 +2706,19 @@ class AsyncTable:
         """
         return await self._inner.schema()

+    async def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
+        """
+        Get the embedding functions for the table
+
+        Returns
+        -------
+        funcs: Dict[str, EmbeddingFunctionConfig]
+            A mapping of the vector column to the embedding function
+            or empty dict if not configured.
+        """
+        schema = await self.schema()
+        return EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
+
     async def count_rows(self, filter: Optional[str] = None) -> int:
         """
         Count the number of rows in the table.
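Usage of the new accessor (a sketch; assumes `tbl` is an open `AsyncTable` whose schema metadata carries embedding configurations):

```python
async def show_embedding_columns(tbl):
    configs = await tbl.embedding_functions()
    for vector_column, conf in configs.items():
        print(f"{conf.source_column} -> {vector_column}")
```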
@@ -2867,6 +2948,234 @@ class AsyncTable:

         return LanceMergeInsertBuilder(self, on)

+    @overload
+    async def search(
+        self,
+        query: Optional[Union[str]] = None,
+        vector_column_name: Optional[str] = None,
+        query_type: Literal["auto"] = ...,
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+    ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ...
+
+    @overload
+    async def search(
+        self,
+        query: Optional[Union[str]] = None,
+        vector_column_name: Optional[str] = None,
+        query_type: Literal["hybrid"] = ...,
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+    ) -> AsyncHybridQuery: ...
+
+    @overload
+    async def search(
+        self,
+        query: Optional[Union[VEC, "PIL.Image.Image", Tuple]] = None,
+        vector_column_name: Optional[str] = None,
+        query_type: Literal["auto"] = ...,
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+    ) -> AsyncVectorQuery: ...
+
+    @overload
+    async def search(
+        self,
+        query: Optional[str] = None,
+        vector_column_name: Optional[str] = None,
+        query_type: Literal["fts"] = ...,
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+    ) -> AsyncFTSQuery: ...
+
+    @overload
+    async def search(
+        self,
+        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
+        vector_column_name: Optional[str] = None,
+        query_type: Literal["vector"] = ...,
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+    ) -> AsyncVectorQuery: ...
+
+    async def search(
+        self,
+        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
+        vector_column_name: Optional[str] = None,
+        query_type: QueryType = "auto",
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+    ) -> AsyncQuery:
+        """Create a search query to find the nearest neighbors
+        of the given query vector. We currently support [vector search][search]
+        and [full-text search][experimental-full-text-search].
+
+        All query options are defined in [AsyncQuery][lancedb.query.AsyncQuery].
+
+        Parameters
+        ----------
+        query: list/np.ndarray/str/PIL.Image.Image, default None
+            The targetted vector to search for.
+
+            - *default None*.
+            Acceptable types are: list, np.ndarray, PIL.Image.Image
+
+            - If None then the select/where/limit clauses are applied to filter
+            the table
+        vector_column_name: str, optional
+            The name of the vector column to search.
+
+            The vector column needs to be a pyarrow fixed size list type
+
+            - If not specified then the vector column is inferred from
+            the table schema
+
+            - If the table has multiple vector columns then the *vector_column_name*
+            needs to be specified. Otherwise, an error is raised.
+        query_type: str
+            *default "auto"*.
+            Acceptable types are: "vector", "fts", "hybrid", or "auto"
+
+            - If "auto" then the query type is inferred from the query;
+
+                - If `query` is a list/np.ndarray then the query type is
+                "vector";
+
+                - If `query` is a PIL.Image.Image then either do vector search,
+                or raise an error if no corresponding embedding function is found.
+
+            - If `query` is a string, then the query type is "vector" if the
+            table has embedding functions else the query type is "fts"
+
+        Returns
+        -------
+        LanceQueryBuilder
+            A query builder object representing the query.
+        """
+
+        def is_embedding(query):
+            return isinstance(query, (list, np.ndarray, pa.Array, pa.ChunkedArray))
+
+        async def get_embedding_func(
+            vector_column_name: Optional[str],
+            query_type: QueryType,
+            query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
+        ) -> Tuple[str, EmbeddingFunctionConfig]:
+            schema = await self.schema()
+            vector_column_name = infer_vector_column_name(
+                schema=schema,
+                query_type=query_type,
+                query=query,
+                vector_column_name=vector_column_name,
+            )
+            funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(
+                schema.metadata
+            )
+            func = funcs.get(vector_column_name)
+            if func is None:
+                error = ValueError(
+                    f"Column '{vector_column_name}' has no registered "
+                    "embedding function."
+                )
+                if len(funcs) > 0:
+                    add_note(
+                        error,
+                        "Embedding functions are registered for columns: "
+                        f"{list(funcs.keys())}",
+                    )
+                else:
+                    add_note(
+                        error, "No embedding functions are registered for any columns."
+                    )
+                raise error
+            return vector_column_name, func
+
+        async def make_embedding(embedding, query):
+            if embedding is not None:
+                loop = asyncio.get_running_loop()
+                # This function is likely to block, since it either calls an expensive
+                # function or makes an HTTP request to an embeddings REST API.
+                return (
+                    await loop.run_in_executor(
+                        None,
+                        embedding.function.compute_query_embeddings_with_retry,
+                        query,
+                    )
+                )[0]
+            else:
+                return None
+
+        if query_type == "auto":
+            # Infer the query type.
+            if is_embedding(query):
+                vector_query = query
+                query_type = "vector"
+            elif isinstance(query, str):
+                try:
+                    (
+                        indices,
+                        (vector_column_name, embedding_conf),
+                    ) = await asyncio.gather(
+                        self.list_indices(),
+                        get_embedding_func(vector_column_name, "auto", query),
+                    )
+                except ValueError as e:
+                    if "Column" in str(
+                        e
+                    ) and "has no registered embedding function" in str(e):
+                        # If the column has no registered embedding function,
+                        # then it's an FTS query.
+                        query_type = "fts"
+                    else:
+                        raise e
+                else:
+                    if embedding_conf is not None:
+                        vector_query = await make_embedding(embedding_conf, query)
+                        if any(
+                            i.columns[0] == embedding_conf.source_column
+                            and i.index_type == "FTS"
+                            for i in indices
+                        ):
+                            query_type = "hybrid"
+                        else:
+                            query_type = "vector"
+                    else:
+                        query_type = "fts"
+            else:
+                # it's an image or something else embeddable.
+                query_type = "vector"
+        elif query_type == "vector":
+            if is_embedding(query):
+                vector_query = query
+            else:
+                vector_column_name, embedding_conf = await get_embedding_func(
+                    vector_column_name, query_type, query
+                )
+                vector_query = await make_embedding(embedding_conf, query)
+        elif query_type == "hybrid":
+            if is_embedding(query):
+                raise ValueError("Hybrid search requires a text query")
+            else:
+                vector_column_name, embedding_conf = await get_embedding_func(
+                    vector_column_name, query_type, query
+                )
+                vector_query = await make_embedding(embedding_conf, query)
+
+        if query_type == "vector":
+            builder = self.query().nearest_to(vector_query)
+            if vector_column_name:
+                builder = builder.column(vector_column_name)
+            return builder
+        elif query_type == "fts":
+            return self.query().nearest_to_text(query, columns=fts_columns or [])
+        elif query_type == "hybrid":
+            builder = self.query().nearest_to(vector_query)
+            if vector_column_name:
+                builder = builder.column(vector_column_name)
+            return builder.nearest_to_text(query, columns=fts_columns or [])
+        else:
+            raise ValueError(f"Unknown query type: '{query_type}'")
+
     def vector_search(
         self,
         query_vector: Union[VEC, Tuple],
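Because the new `search()` is itself a coroutine that returns a builder, call sites await twice: once for the builder, once for the results. A sketch mirroring the doc tests updated later in this diff:

```python
async def demo(tbl):
    # A list/ndarray query is inferred as query_type="vector".
    nearest = await (await tbl.search([0.1, 0.2])).limit(5).to_list()
    # A string query against an FTS-indexed column runs full-text search.
    matches = await (await tbl.search("puppy", query_type="fts")).to_list()
    return nearest, matches
```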
@@ -75,6 +75,6 @@ async def test_binary_vector_async():

     query = np.random.randint(0, 2, size=256)
     packed_query = np.packbits(query)
-    await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow()
+    await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
     # --8<-- [end:async_binary_vector]
     await db.drop_table("my_binary_vectors")
@@ -53,13 +53,13 @@ async def test_binary_vector_async():
     query = np.random.random(256)

     # Search for the vectors within the range of [0.1, 0.5)
-    await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow()
+    await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()

     # Search for the vectors with the distance less than 0.5
-    await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow()
+    await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()

     # Search for the vectors with the distance greater or equal to 0.1
-    await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow()
+    await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()

     # --8<-- [end:async_distance_range]
     await db.drop_table("my_table")
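The `distance_range` comments above describe a half-open interval; the check below spells that out (a sketch; `_distance` is the distance column LanceDB attaches to results, default L2 metric assumed):

```python
async def assert_in_range(tbl, query):
    rows = await (await tbl.search(query)).distance_range(0.1, 0.5).to_list()
    assert all(0.1 <= row["_distance"] < 0.5 for row in rows)
```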
@@ -28,3 +28,49 @@ def test_embeddings_openai():
     actual = table.search(query).limit(1).to_pydantic(Words)[0]
     print(actual.text)
     # --8<-- [end:openai_embeddings]
+
+
+@pytest.mark.slow
+@pytest.mark.asyncio
+async def test_embeddings_openai_async():
+    uri = "memory://"
+    # --8<-- [start:async_openai_embeddings]
+    db = await lancedb.connect_async(uri)
+    func = get_registry().get("openai").create(name="text-embedding-ada-002")
+
+    class Words(LanceModel):
+        text: str = func.SourceField()
+        vector: Vector(func.ndims()) = func.VectorField()
+
+    table = await db.create_table("words", schema=Words, mode="overwrite")
+    await table.add([{"text": "hello world"}, {"text": "goodbye world"}])
+
+    query = "greetings"
+    actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0]
+    print(actual.text)
+    # --8<-- [end:async_openai_embeddings]
+
+
+def test_embeddings_secret():
+    # --8<-- [start:register_secret]
+    registry = get_registry()
+    registry.set_var("api_key", "sk-...")
+
+    func = registry.get("openai").create(api_key="$var:api_key")
+    # --8<-- [end:register_secret]
+
+    try:
+        import torch
+    except ImportError:
+        pytest.skip("torch not installed")
+
+    # --8<-- [start:register_device]
+    import torch
+
+    registry = get_registry()
+    if torch.cuda.is_available():
+        registry.set_var("device", "cuda")
+
+    func = registry.get("huggingface").create(device="$var:device:cpu")
+    # --8<-- [end:register_device]
+    assert func.device == "cuda" if torch.cuda.is_available() else "cpu"
@@ -72,8 +72,7 @@ async def test_ann_index_async():
     # --8<-- [end:create_ann_index_async]
     # --8<-- [start:vector_search_async]
     await (
-        async_tbl.query()
-        .nearest_to(np.random.random((32)))
+        (await async_tbl.search(np.random.random((32))))
         .limit(2)
         .nprobes(20)
         .refine_factor(10)
@@ -82,18 +81,14 @@ async def test_ann_index_async():
     # --8<-- [end:vector_search_async]
     # --8<-- [start:vector_search_async_with_filter]
     await (
-        async_tbl.query()
-        .nearest_to(np.random.random((32)))
+        (await async_tbl.search(np.random.random((32))))
         .where("item != 'item 1141'")
         .to_pandas()
     )
     # --8<-- [end:vector_search_async_with_filter]
     # --8<-- [start:vector_search_async_with_select]
     await (
-        async_tbl.query()
-        .nearest_to(np.random.random((32)))
-        .select(["vector"])
-        .to_pandas()
+        (await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
     )
     # --8<-- [end:vector_search_async_with_select]

@@ -164,7 +159,7 @@ async def test_scalar_index_async():
         {"book_id": 3, "vector": [5.0, 6]},
     ]
     async_tbl = await async_db.create_table("book_with_embeddings_async", data)
-    (await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
+    (await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
     # --8<-- [end:vector_search_with_scalar_index_async]
     # --8<-- [start:update_scalar_index_async]
     await async_tbl.add([{"vector": [7, 8], "book_id": 4}])
python/python/tests/docs/test_pydantic_integration.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+# --8<-- [start:imports]
+import lancedb
+from lancedb.pydantic import Vector, LanceModel
+# --8<-- [end:imports]
+
+
+def test_pydantic_model(tmp_path):
+    # --8<-- [start:base_model]
+    class PersonModel(LanceModel):
+        name: str
+        age: int
+        vector: Vector(2)
+
+    # --8<-- [end:base_model]
+
+    # --8<-- [start:set_url]
+    url = "./example"
+    # --8<-- [end:set_url]
+    url = tmp_path
+
+    # --8<-- [start:base_example]
+    db = lancedb.connect(url)
+    table = db.create_table("person", schema=PersonModel)
+    table.add(
+        [
+            PersonModel(name="bob", age=1, vector=[1.0, 2.0]),
+            PersonModel(name="alice", age=2, vector=[3.0, 4.0]),
+        ]
+    )
+    assert table.count_rows() == 2
+    person = table.search([0.0, 0.0]).limit(1).to_pydantic(PersonModel)
+    assert person[0].name == "bob"
+    # --8<-- [end:base_example]
@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():

     query_vector = [100, 100]
     # Pandas DataFrame
-    df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
+    df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
     print(df)
     # --8<-- [end:vector_search_async]
     # --8<-- [start:vector_search_with_filter_async]
     # Apply the filter via LanceDB
-    results = (
-        await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
-    )
+    results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
     assert len(results) == 1
     assert results["item"].iloc[0] == "foo"

     # Apply the filter via Pandas
-    df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas()
+    df = results = await (await async_tbl.search([100, 100])).to_pandas()
     results = df[df.price < 15]
     assert len(results) == 1
     assert results["item"].iloc[0] == "foo"
@@ -188,3 +186,26 @@ def test_polars():
     # --8<-- [start:print_table_lazyform]
     print(ldf.first().collect())
     # --8<-- [end:print_table_lazyform]
+
+
+@pytest.mark.asyncio
+async def test_polars_async():
+    uri = "data/sample-lancedb"
+    db = await lancedb.connect_async(uri)
+
+    # --8<-- [start:create_table_polars_async]
+    data = pl.DataFrame(
+        {
+            "vector": [[3.1, 4.1], [5.9, 26.5]],
+            "item": ["foo", "bar"],
+            "price": [10.0, 20.0],
+        }
+    )
+    table = await db.create_table("pl_table_async", data=data)
+    # --8<-- [end:create_table_polars_async]
+    # --8<-- [start:vector_search_polars_async]
+    query = [3.0, 4.0]
+    result = await (await table.search(query)).limit(1).to_polars()
+    print(result)
+    print(type(result))
+    # --8<-- [end:vector_search_polars_async]
@@ -117,12 +117,11 @@ async def test_vector_search_async():
         for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
     ]
     async_tbl = await async_db.create_table("vector_search_async", data=data)
-    (await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list())
+    (await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
     # --8<-- [end:exhaustive_search_async]
     # --8<-- [start:exhaustive_search_async_cosine]
     (
-        await async_tbl.query()
-        .nearest_to(np.random.random((1536)))
+        await (await async_tbl.search(np.random.random((1536))))
         .distance_type("cosine")
         .limit(10)
         .to_list()
@@ -145,13 +144,13 @@ async def test_vector_search_async():
     async_tbl = await async_db.create_table("documents_async", data=data)
     # --8<-- [end:create_table_async_with_nested_schema]
     # --8<-- [start:search_result_async_as_pyarrow]
-    await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow()
+    await (await async_tbl.search(np.random.randn(1536))).to_arrow()
     # --8<-- [end:search_result_async_as_pyarrow]
     # --8<-- [start:search_result_async_as_pandas]
-    await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas()
+    await (await async_tbl.search(np.random.randn(1536))).to_pandas()
     # --8<-- [end:search_result_async_as_pandas]
     # --8<-- [start:search_result_async_as_list]
-    await async_tbl.query().nearest_to(np.random.randn(1536)).to_list()
+    await (await async_tbl.search(np.random.randn(1536))).to_list()
     # --8<-- [end:search_result_async_as_list]


@@ -219,9 +218,7 @@ async def test_fts_native_async():

     # async API uses our native FTS algorithm
     await async_tbl.create_index("text", config=FTS())
-    await (
-        async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
-    )
+    await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
     # [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
     # ...
     # --8<-- [end:basic_fts_async]
@@ -235,18 +232,11 @@ async def test_fts_native_async():
     )
     # --8<-- [end:fts_config_folding_async]
     # --8<-- [start:fts_prefiltering_async]
-    await (
-        async_tbl.query()
-        .nearest_to_text("puppy")
-        .limit(10)
-        .where("text='foo'")
-        .to_list()
-    )
+    await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
     # --8<-- [end:fts_prefiltering_async]
     # --8<-- [start:fts_postfiltering_async]
     await (
-        async_tbl.query()
-        .nearest_to_text("puppy")
+        (await async_tbl.search("puppy"))
         .limit(10)
         .where("text='foo'")
         .postfilter()
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
     # Create a fts index before the hybrid search
     await async_tbl.create_index("text", config=FTS())
     text_query = "flower moon"
-    vector_query = embeddings.compute_query_embeddings(text_query)[0]
     # hybrid search with default re-ranker
-    await (
-        async_tbl.query()
-        .nearest_to(vector_query)
-        .nearest_to_text(text_query)
-        .to_pandas()
-    )
+    await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
     # --8<-- [end:basic_hybrid_search_async]
     # --8<-- [start:hybrid_search_pass_vector_text_async]
     vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]
@@ -299,12 +299,12 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
 @pytest.mark.asyncio
 async def test_connect(tmp_path):
     db = await lancedb.connect_async(tmp_path)
-    assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=None)"
+    assert str(db) == f"ListingDatabase(uri={tmp_path}, read_consistency_interval=None)"

     db = await lancedb.connect_async(
         tmp_path, read_consistency_interval=timedelta(seconds=5)
     )
-    assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=5s)"
+    assert str(db) == f"ListingDatabase(uri={tmp_path}, read_consistency_interval=5s)"


 @pytest.mark.asyncio
@@ -396,13 +396,16 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):

 @pytest.mark.asyncio
 async def test_create_table_v2_manifest_paths_async(tmp_path):
-    db = await lancedb.connect_async(tmp_path)
+    db_with_v2_paths = await lancedb.connect_async(
+        tmp_path, storage_options={"new_table_enable_v2_manifest_paths": "true"}
+    )
+    db_no_v2_paths = await lancedb.connect_async(
+        tmp_path, storage_options={"new_table_enable_v2_manifest_paths": "false"}
+    )
     # Create table in v2 mode with v2 manifest paths enabled
-    tbl = await db.create_table(
+    tbl = await db_with_v2_paths.create_table(
         "test_v2_manifest_paths",
         data=[{"id": 0}],
-        use_legacy_format=False,
-        enable_v2_manifest_paths=True,
     )
     assert await tbl.uses_v2_manifest_paths()
     manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"
@@ -410,11 +413,9 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
         assert re.match(r"\d{20}\.manifest", manifest)

     # Start a table in V1 mode then migrate
-    tbl = await db.create_table(
+    tbl = await db_no_v2_paths.create_table(
         "test_v2_migration",
         data=[{"id": 0}],
-        use_legacy_format=False,
-        enable_v2_manifest_paths=False,
     )
     assert not await tbl.uses_v2_manifest_paths()
     manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"
@@ -498,6 +499,10 @@ def test_delete_table(tmp_db: lancedb.DBConnection):
     # if ignore_missing=True
     tmp_db.drop_table("does_not_exist", ignore_missing=True)

+    tmp_db.drop_all_tables()
+
+    assert tmp_db.table_names() == []
+

 @pytest.mark.asyncio
 async def test_delete_table_async(tmp_db: lancedb.DBConnection):
@@ -583,7 +588,7 @@ def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):


 @pytest.mark.asyncio
-async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
+async def test_create_in_v2_mode():
     def make_data():
         for i in range(10):
             yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
@@ -594,10 +599,13 @@ async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
     schema = pa.schema([pa.field("x", pa.int64())])

     # Create table in v1 mode
-    tbl = await mem_db_async.create_table(
-        "test", data=make_data(), schema=schema, data_storage_version="legacy"
+    v1_db = await lancedb.connect_async(
+        "memory://", storage_options={"new_table_data_storage_version": "legacy"}
     )
+
+    tbl = await v1_db.create_table("test", data=make_data(), schema=schema)

     async def is_in_v2_mode(tbl):
         batches = (
             await tbl.query().limit(10 * 1024).to_batches(max_batch_length=1024 * 10)
@@ -610,10 +618,12 @@ async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
     assert not await is_in_v2_mode(tbl)

     # Create table in v2 mode
-    tbl = await mem_db_async.create_table(
-        "test_v2", data=make_data(), schema=schema, use_legacy_format=False
+    v2_db = await lancedb.connect_async(
+        "memory://", storage_options={"new_table_data_storage_version": "stable"}
     )
+
+    tbl = await v2_db.create_table("test_v2", data=make_data(), schema=schema)

     assert await is_in_v2_mode(tbl)

     # Add data (should remain in v2 mode)
@@ -622,20 +632,18 @@ async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
     assert await is_in_v2_mode(tbl)

     # Create empty table in v2 mode and add data
-    tbl = await mem_db_async.create_table(
-        "test_empty_v2", data=None, schema=schema, use_legacy_format=False
-    )
+    tbl = await v2_db.create_table("test_empty_v2", data=None, schema=schema)
     await tbl.add(make_table())

     assert await is_in_v2_mode(tbl)

-    # Create empty table uses v1 mode by default
-    tbl = await mem_db_async.create_table(
-        "test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
-    )
+    # Db uses v2 mode by default
+    db = await lancedb.connect_async("memory://")
+    tbl = await db.create_table("test_empty_v2_default", data=None, schema=schema)
     await tbl.add(make_table())

-    assert not await is_in_v2_mode(tbl)
+    assert await is_in_v2_mode(tbl)


 def test_replace_index(mem_db: lancedb.DBConnection):
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

-from typing import List, Union
+import os
+from typing import List, Optional, Union
 from unittest.mock import MagicMock, patch

 import lance
@@ -56,7 +57,7 @@ def test_embedding_function(tmp_path):
     conf = EmbeddingFunctionConfig(
         source_column="text",
         vector_column="vector",
-        function=MockTextEmbeddingFunction(),
+        function=MockTextEmbeddingFunction.create(),
     )
     metadata = registry.get_table_metadata([conf])
     table = table.replace_schema_metadata(metadata)
@@ -80,6 +81,57 @@ def test_embedding_function(tmp_path):
     assert np.allclose(actual, expected)


+def test_embedding_function_variables():
+    @register("variable-testing")
+    class VariableTestingFunction(TextEmbeddingFunction):
+        key1: str
+        secret_key: Optional[str] = None
+
+        @staticmethod
+        def sensitive_keys():
+            return ["secret_key"]
+
+        def ndims():
+            pass
+
+        def generate_embeddings(self, _texts):
+            pass
+
+    registry = EmbeddingFunctionRegistry.get_instance()
+
+    # Should error if variable is not set
+    with pytest.raises(ValueError, match="Variable 'test' not found"):
+        registry.get("variable-testing").create(
+            key1="$var:test",
+        )
+
+    # Should use default values if not set
+    func = registry.get("variable-testing").create(key1="$var:test:some_value")
+    assert func.key1 == "some_value"
+
+    # Should set a variable that the embedding function understands
+    registry.set_var("test", "some_value")
+    func = registry.get("variable-testing").create(key1="$var:test")
+    assert func.key1 == "some_value"
+
+    # Should reject secrets that aren't passed in as variables
+    with pytest.raises(
+        ValueError,
+        match="Sensitive key 'secret_key' cannot be set to a hardcoded value",
+    ):
+        registry.get("variable-testing").create(
+            key1="whatever", secret_key="some_value"
+        )
+
+    # Should not serialize secrets.
+    registry.set_var("secret", "secret_value")
+    func = registry.get("variable-testing").create(
+        key1="whatever", secret_key="$var:secret"
+    )
+    assert func.secret_key == "secret_value"
+    assert func.safe_model_dump()["secret_key"] == "$var:secret"
+
+
 def test_embedding_with_bad_results(tmp_path):
     @register("null-embedding")
     class NullEmbeddingFunction(TextEmbeddingFunction):
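The variable syntax exercised by `test_embedding_function_variables` above, in one place (a sketch): `"$var:NAME"` resolves against `registry.set_var` at `create()` time, `"$var:NAME:DEFAULT"` falls back to the default, and keys listed in `sensitive_keys()` must arrive as variables so that `safe_model_dump()` serializes the reference rather than the secret.

```python
from lancedb.embeddings import get_registry

registry = get_registry()
registry.set_var("device", "cuda")

# Resolves to "cuda" here; would fall back to "cpu" if the var were unset.
# (Actually running this requires the optional huggingface dependencies.)
func = registry.get("huggingface").create(device="$var:device:cpu")
```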
@@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path):
         ) -> list[Union[np.array, None]]:
             # Return None, which is bad if field is non-nullable
             a = [
-                np.full(self.ndims(), np.nan)
-                if i % 2 == 0
-                else np.random.randn(self.ndims())
+                (
+                    np.full(self.ndims(), np.nan)
+                    if i % 2 == 0
+                    else np.random.randn(self.ndims())
+                )
                 for i in range(len(texts))
             ]
             return a
@@ -107,7 +161,7 @@ def test_embedding_with_bad_results(tmp_path):
         vector: Vector(model.ndims()) = model.VectorField()

     table = db.create_table("test", schema=Schema, mode="overwrite")
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         # Default on_bad_vectors is "error"
         table.add([{"text": "hello world"}])

@@ -341,6 +395,7 @@ def test_add_optional_vector(tmp_path):
     assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()


+@pytest.mark.slow
 @pytest.mark.parametrize(
     "embedding_type",
     [
@@ -358,7 +413,7 @@ def test_embedding_function_safe_model_dump(embedding_type):

     # Note: Some embedding types might require specific parameters
     try:
-        model = registry.get(embedding_type).create()
+        model = registry.get(embedding_type).create({"max_retries": 1})
     except Exception as e:
         pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")

@@ -391,3 +446,33 @@ def test_retry(mock_sleep):
     result = test_function()
     assert mock_sleep.call_count == 9
     assert result == "result"
+
+
+@pytest.mark.skipif(
+    os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set"
+)
+def test_openai_propagates_api_key(monkeypatch):
+    # Make sure that if we set it as a variable, the API key is propagated
+    api_key = os.environ["OPENAI_API_KEY"]
+    monkeypatch.delenv("OPENAI_API_KEY")
+
+    uri = "memory://"
+    registry = get_registry()
+    registry.set_var("open_api_key", api_key)
+    func = registry.get("openai").create(
+        name="text-embedding-ada-002",
+        max_retries=0,
+        api_key="$var:open_api_key",
+    )
+
+    class Words(LanceModel):
+        text: str = func.SourceField()
+        vector: Vector(func.ndims()) = func.VectorField()
+
+    db = lancedb.connect(uri)
+    table = db.create_table("words", schema=Words, mode="overwrite")
+    table.add([{"text": "hello world"}, {"text": "goodbye world"}])
+
+    query = "greetings"
+    actual = table.search(query).limit(1).to_pydantic(Words)[0]
+    assert len(actual.text) > 0
@@ -10,6 +10,7 @@ import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import BaseModel
from pydantic import Field


@@ -252,3 +253,104 @@ def test_lance_model():

    t = TestModel()
    assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])


def test_optional_nested_model():
    class WAMedia(BaseModel):
        url: str
        mimetype: str
        filename: Optional[str]
        error: Optional[str]
        data: bytes

    class WALocation(BaseModel):
        description: Optional[str]
        latitude: str
        longitude: str

    class ReplyToMessage(BaseModel):
        id: str
        participant: str
        body: str

    class Message(BaseModel):
        id: str
        timestamp: int
        from_: str
        fromMe: bool
        to: str
        body: str
        hasMedia: Optional[bool]
        media: WAMedia
        mediaUrl: Optional[str]
        ack: Optional[int]
        ackName: Optional[str]
        author: Optional[str]
        location: Optional[WALocation]
        vCards: Optional[List[str]]
        replyTo: Optional[ReplyToMessage]

    class AnyEvent(LanceModel):
        id: str
        session: str
        metadata: Optional[str] = None
        engine: str
        event: str

    class MessageEvent(AnyEvent):
        payload: Message

    schema = pydantic_to_schema(MessageEvent)

    payload = schema.field("payload")
    assert payload.type == pa.struct(
        [
            pa.field("id", pa.utf8(), False),
            pa.field("timestamp", pa.int64(), False),
            pa.field("from_", pa.utf8(), False),
            pa.field("fromMe", pa.bool_(), False),
            pa.field("to", pa.utf8(), False),
            pa.field("body", pa.utf8(), False),
            pa.field("hasMedia", pa.bool_(), True),
            pa.field(
                "media",
                pa.struct(
                    [
                        pa.field("url", pa.utf8(), False),
                        pa.field("mimetype", pa.utf8(), False),
                        pa.field("filename", pa.utf8(), True),
                        pa.field("error", pa.utf8(), True),
                        pa.field("data", pa.binary(), False),
                    ]
                ),
                False,
            ),
            pa.field("mediaUrl", pa.utf8(), True),
            pa.field("ack", pa.int64(), True),
            pa.field("ackName", pa.utf8(), True),
            pa.field("author", pa.utf8(), True),
            pa.field(
                "location",
                pa.struct(
                    [
                        pa.field("description", pa.utf8(), True),
                        pa.field("latitude", pa.utf8(), False),
                        pa.field("longitude", pa.utf8(), False),
                    ]
                ),
                True,  # Optional
            ),
            pa.field("vCards", pa.list_(pa.utf8()), True),
            pa.field(
                "replyTo",
                pa.struct(
                    [
                        pa.field("id", pa.utf8(), False),
                        pa.field("participant", pa.utf8(), False),
                        pa.field("body", pa.utf8(), False),
                    ]
                ),
                True,
            ),
        ]
    )
@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path

import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.db import AsyncConnection
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import (
    AsyncFTSQuery,
    AsyncHybridQuery,
    AsyncQueryBase,
    AsyncVectorQuery,
    LanceVectorQueryBuilder,
    Query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output


@pytest.fixture(scope="module")
@@ -232,6 +242,71 @@ async def test_distance_range_async(table_async: AsyncTable):
    assert res["_distance"].to_pylist() == [min_dist, max_dist]


@pytest.mark.asyncio
async def test_distance_range_with_new_rows_async():
    conn = await lancedb.connect_async(
        "memory://", read_consistency_interval=timedelta(seconds=0)
    )
    data = pa.table(
        {
            "vector": pa.FixedShapeTensorArray.from_numpy_ndarray(
                np.random.rand(256, 2)
            ),
        }
    )
    table = await conn.create_table("test", data)
    table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))

    q = [0, 0]
    rs = await table.query().nearest_to(q).to_arrow()
    dists = rs["_distance"].to_pylist()
    min_dist = dists[0]
    max_dist = dists[-1]

    # append more rows so that execution plan would be mixed with ANN & Flat KNN
    new_data = pa.table(
        {
            "vector": pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(4, 2)),
        }
    )
    await table.add(new_data)

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(upper_bound=min_dist)
        .to_arrow()
    )
    assert len(res) == 0

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(lower_bound=max_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist >= max_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(upper_bound=max_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist < max_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(lower_bound=min_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist >= min_dist


@pytest.mark.parametrize(
    "multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
)
@@ -651,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
    tbl = await db.create_table("test", df)
    results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
    assert len(results) == 2


@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
    nrows = 1000
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
        }
    )

    @register("test2")
    class TestEmbedding(TextEmbeddingFunction):
        def ndims(self):
            return 4

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> List[np.array]:
            embeddings = []
            for text in texts:
                vec = np.array([float(text) / 1000] * self.ndims())
                embeddings.append(vec)
            return embeddings

    registry = get_registry()
    func = registry.get("test2").create()

    class TestModel(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    tbl = await mem_db_async.create_table("test", data, schema=TestModel)

    funcs = await tbl.embedding_functions()
    assert len(funcs) == 1

    # No FTS or vector index
    # Search for vector -> vector query
    q = [0.1] * 4
    query = await tbl.search(q)
    assert isinstance(query, AsyncVectorQuery)

    # Search for string -> vector query
    query = await tbl.search("0.1")
    assert isinstance(query, AsyncVectorQuery)

    await tbl.create_index("text", config=FTS())

    query = await tbl.search("0.1")
    assert isinstance(query, AsyncHybridQuery)

    data_with_vecs = await tbl.to_arrow()
    data_with_vecs = data_with_vecs.replace_schema_metadata(None)
    tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
    with pytest.raises(
        Exception,
        match=(
            "Cannot perform full text search unless an INVERTED index has "
            "been created"
        ),
    ):
        query = await (await tbl2.search("0.1")).to_arrow()


@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
    nrows, ndims = 1000, 16
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
            "vector": pa.FixedSizeListArray.from_arrays(
                pc.random(nrows * ndims).cast(pa.float32()), ndims
            ),
        }
    )
    table = await mem_db_async.create_table("test", data)
    await table.create_index("text", config=FTS())

    # Validate that specifying fts, vector or hybrid gets the right query.
    q = [0.1] * ndims
    query = await table.search(q, query_type="vector")
    assert isinstance(query, AsyncVectorQuery)

    query = await table.search("0.1", query_type="fts")
    assert isinstance(query, AsyncFTSQuery)

    with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
        await table.search("0.1", query_type="foo")

    with pytest.raises(
        ValueError, match="Column 'vector' has no registered embedding function"
    ) as e:
        await table.search("0.1", query_type="vector")

    assert "No embedding functions are registered for any columns" in exception_output(
        e
    )
@@ -32,15 +32,16 @@ def make_mock_http_handler(handler):
@contextlib.contextmanager
def mock_lancedb_connection(handler):
    with http.server.HTTPServer(
        ("localhost", 8080), make_mock_http_handler(handler)
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = lancedb.connect(
            "db://dev",
            api_key="fake",
            host_override="http://localhost:8080",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
@@ -57,22 +58,24 @@ def mock_lancedb_connection(handler):


@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler):
async def mock_lancedb_connection_async(handler, **client_config):
    with http.server.HTTPServer(
        ("localhost", 8080), make_mock_http_handler(handler)
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = await lancedb.connect_async(
            "db://dev",
            api_key="fake",
            host_override="http://localhost:8080",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
                    "connect_timeout": 1,
                },
                **client_config,
            },
        )

@@ -254,6 +257,9 @@ def test_table_create_indices():
                )
            )
            request.wfile.write(payload.encode())
        elif "/drop/" in request.path:
            request.send_response(200)
            request.end_headers()
        else:
            request.send_response(404)
            request.end_headers()
@@ -265,6 +271,9 @@ def test_table_create_indices():
        table.create_scalar_index("id")
        table.create_fts_index("text")
        table.create_scalar_index("vector")
        table.drop_index("vector_idx")
        table.drop_index("id_idx")
        table.drop_index("text_idx")


@contextlib.contextmanager
@@ -329,6 +338,7 @@ def test_query_sync_empty_query():
        "filter": "true",
        "vector": [],
        "columns": ["id"],
        "prefilter": False,
        "version": None,
    }

@@ -379,8 +389,14 @@ def test_query_sync_maximal():


def test_query_sync_multiple_vectors():
    def handler(_body):
        return pa.table({"id": [1]})
    def handler(body):
        # TODO: we will add the ability to get the server version,
        # so that we can decide how to perform batch queries.
        vectors = body["vector"]
        res = []
        for i, vector in enumerate(vectors):
            res.append({"id": 1, "query_index": i})
        return pa.Table.from_pylist(res)

    with query_test_table(handler) as table:
        results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
@@ -397,6 +413,7 @@ def test_query_sync_fts():
            "columns": [],
        },
        "k": 10,
        "prefilter": True,
        "vector": [],
        "version": None,
    }
@@ -414,6 +431,7 @@ def test_query_sync_fts():
        },
        "k": 42,
        "vector": [],
        "prefilter": True,
        "with_row_id": True,
        "version": None,
    }
@@ -440,6 +458,7 @@ def test_query_sync_hybrid():
        },
        "k": 42,
        "vector": [],
        "prefilter": True,
        "with_row_id": True,
        "version": None,
    }
@@ -522,3 +541,19 @@ def test_create_client():

    with pytest.warns(DeprecationWarning):
        lancedb.connect(**mandatory_args, request_thread_pool=10)


@pytest.mark.asyncio
async def test_pass_through_headers():
    def handler(request):
        assert request.headers["foo"] == "bar"
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(
        handler, extra_headers={"foo": "bar"}
    ) as db:
        table_names = await db.table_names()
        assert table_names == []
@@ -32,8 +32,8 @@ pytest.importorskip("lancedb.fts")
def get_test_table(tmp_path, use_tantivy):
    db = lancedb.connect(tmp_path)
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()
    emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
    meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")()
    meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -405,7 +405,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy):


@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
    os.environ.get("OPENAI_API_KEY") is None
    or os.environ.get("OPENAI_BASE_URL") is not None,
    reason="OPENAI_API_KEY not set",
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_openai_reranker(tmp_path, use_tantivy):
@@ -887,7 +887,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
        text: str
        vector: Vector(10)

    func = MockTextEmbeddingFunction()
    func = MockTextEmbeddingFunction.create()
    texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
    df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})

@@ -934,7 +934,7 @@ def test_create_f16_table(mem_db: DBConnection):


def test_add_with_embedding_function(mem_db: DBConnection):
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()
    emb = EmbeddingFunctionRegistry.get_instance().get("test").create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -1128,7 +1128,7 @@ def test_count_rows(mem_db: DBConnection):

def setup_hybrid_search_table(db: DBConnection, embedding_func):
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
    emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -14,7 +14,7 @@ from lancedb.table import (
    _append_vector_columns,
    _cast_to_target_schema,
    _handle_bad_vectors,
    _into_pyarrow_table,
    _into_pyarrow_reader,
    _sanitize_data,
    _infer_target_schema,
)
@@ -127,7 +127,7 @@ def test_append_vector_columns():
    conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="vector",
        function=MockTextEmbeddingFunction(),
        function=MockTextEmbeddingFunction.create(),
    )
    metadata = registry.get_table_metadata([conf])

@@ -145,19 +145,19 @@ def test_append_vector_columns():
        schema=schema,
    )
    output = _append_vector_columns(
        data,
        data.to_reader(),
        schema,  # metadata passed separate from schema
        metadata=metadata,
    )
    ).read_all()
    assert output.schema == schema
    assert output["vector"].null_count == 0

    # Adds if missing
    data = pa.table({"text": ["hello"]})
    output = _append_vector_columns(
        data,
        data.to_reader(),
        schema.with_metadata(metadata),
    )
    ).read_all()
    assert output.schema == schema
    assert output["vector"].null_count == 0

@@ -170,9 +170,9 @@ def test_append_vector_columns():
        schema=schema,
    )
    output = _append_vector_columns(
        data,
        data.to_reader(),
        schema.with_metadata(metadata),
    )
    ).read_all()
    assert output == data  # No change

    # No provided schema
@@ -182,9 +182,9 @@ def test_append_vector_columns():
        }
    )
    output = _append_vector_columns(
        data,
        data.to_reader(),
        metadata=metadata,
    )
    ).read_all()
    expected_schema = pa.schema(
        {
            "text": pa.string(),
@@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            output = _handle_bad_vectors(
                data,
                data.to_reader(),
                on_bad_vectors=on_bad_vectors,
            )
            ).read_all()
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has variable length vectors. Set "
@@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
        return
    else:
        output = _handle_bad_vectors(
            data,
            data.to_reader(),
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
        )
        ).read_all()

    if on_bad_vectors == "drop":
        expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
@@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            output = _handle_bad_vectors(
                data,
                data.to_reader(),
                on_bad_vectors=on_bad_vectors,
            )
            ).read_all()
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has NaNs. Set "
@@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
        return
    else:
        output = _handle_bad_vectors(
            data,
            data.to_reader(),
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
        )
        ).read_all()

    if on_bad_vectors == "drop":
        expected = pa.array([[3.0, 4.0]])
@@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop():
        [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
    )
    data = pa.table({"vector": vector})
    output = _handle_bad_vectors(data)
    output = _handle_bad_vectors(data.to_reader()).read_all()
    assert output["vector"] == vector


@@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel):
)
def test_into_pyarrow_table(data):
    expected = pa.table({"a": [1], "b": [2]})
    output = _into_pyarrow_table(data())
    output = _into_pyarrow_reader(data()).read_all()
    assert output == expected


@@ -349,7 +349,7 @@ def test_infer_target_schema():
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # Handle large list and use modal size
@@ -370,7 +370,7 @@ def test_infer_target_schema():
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # ignore if not list
@@ -386,7 +386,7 @@ def test_infer_target_schema():
        schema=example,
    )
    expected = example
    output = _infer_target_schema(data)
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected


@@ -434,7 +434,7 @@ def test_sanitize_data(
        conf = EmbeddingFunctionConfig(
            source_column="text",
            vector_column="vector",
            function=MockTextEmbeddingFunction(),
            function=MockTextEmbeddingFunction.create(),
        )
        metadata = registry.get_table_metadata([conf])
    else:
@@ -476,7 +476,7 @@ def test_sanitize_data(
        target_schema=schema,
        metadata=metadata,
        allow_subschema=True,
    )
    ).read_all()

    assert output_data == expected

@@ -519,7 +519,7 @@ def test_cast_to_target_schema():
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    output = _cast_to_target_schema(data, target)
    output = _cast_to_target_schema(data.to_reader(), target)
    expected = pa.table(
        {
            "id": [1],
@@ -550,8 +550,10 @@ def test_cast_to_target_schema():
        }
    )
    with pytest.raises(Exception):
        _cast_to_target_schema(data, target)
        _cast_to_target_schema(data.to_reader(), target)
    output = _cast_to_target_schema(data, target, allow_subschema=True)
    output = _cast_to_target_schema(
        data.to_reader(), target, allow_subschema=True
    ).read_all()
    expected_schema = pa.schema(
        {
            "id": pa.int64(),
@@ -576,3 +578,22 @@ def test_cast_to_target_schema():
        schema=expected_schema,
    )
    assert output == expected


def test_sanitize_data_stream():
    # Make sure we don't collect the whole stream when running sanitize_data
    schema = pa.schema({"a": pa.int32()})

    def stream():
        yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
        raise ValueError("error")

    reader = pa.RecordBatchReader.from_batches(schema, stream())

    output = _sanitize_data(reader)

    first = next(output)
    assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema)

    with pytest.raises(ValueError):
        next(output)
@@ -1,10 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc, time::Duration};

use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::connection::{Connection as LanceConnection, CreateTableMode, LanceFileVersion};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use pyo3::{
    exceptions::{PyRuntimeError, PyValueError},
    pyclass, pyfunction, pymethods, Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
@@ -80,15 +80,13 @@ impl Connection {
        future_into_py(self_.py(), async move { op.execute().await.infer_error() })
    }

    #[pyo3(signature = (name, mode, data, storage_options=None, data_storage_version=None, enable_v2_manifest_paths=None))]
    #[pyo3(signature = (name, mode, data, storage_options=None))]
    pub fn create_table<'a>(
        self_: PyRef<'a, Self>,
        name: String,
        mode: &str,
        data: Bound<'_, PyAny>,
        storage_options: Option<HashMap<String, String>>,
        data_storage_version: Option<String>,
        enable_v2_manifest_paths: Option<bool>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let inner = self_.get_inner()?.clone();

@@ -101,32 +99,19 @@ impl Connection {
            builder = builder.storage_options(storage_options);
        }

        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
        }

        if let Some(data_storage_version) = data_storage_version.as_ref() {
            builder = builder.data_storage_version(
                LanceFileVersion::from_str(data_storage_version)
                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
            );
        }

        future_into_py(self_.py(), async move {
            let table = builder.execute().await.infer_error()?;
            Ok(Table::new(table))
        })
    }

    #[pyo3(signature = (name, mode, schema, storage_options=None, data_storage_version=None, enable_v2_manifest_paths=None))]
    #[pyo3(signature = (name, mode, schema, storage_options=None))]
    pub fn create_empty_table<'a>(
        self_: PyRef<'a, Self>,
        name: String,
        mode: &str,
        schema: Bound<'_, PyAny>,
        storage_options: Option<HashMap<String, String>>,
        data_storage_version: Option<String>,
        enable_v2_manifest_paths: Option<bool>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let inner = self_.get_inner()?.clone();

@@ -140,17 +125,6 @@ impl Connection {
            builder = builder.storage_options(storage_options);
        }

        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
        }

        if let Some(data_storage_version) = data_storage_version.as_ref() {
            builder = builder.data_storage_version(
                LanceFileVersion::from_str(data_storage_version)
                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
            );
        }

        future_into_py(self_.py(), async move {
            let table = builder.execute().await.infer_error()?;
            Ok(Table::new(table))
@@ -196,12 +170,11 @@ impl Connection {
        })
    }

    pub fn drop_db(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
    pub fn drop_all_tables(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.get_inner()?.clone();
        future_into_py(
            self_.py(),
            async move { inner.drop_db().await.infer_error() },
        )
        future_into_py(self_.py(), async move {
            inner.drop_all_tables().await.infer_error()
        })
    }
}

@@ -249,6 +222,7 @@ pub struct PyClientConfig {
    user_agent: String,
    retry_config: Option<PyClientRetryConfig>,
    timeout_config: Option<PyClientTimeoutConfig>,
    extra_headers: Option<HashMap<String, String>>,
}

#[derive(FromPyObject)]
@@ -300,6 +274,7 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
            user_agent: value.user_agent,
            retry_config: value.retry_config.map(Into::into).unwrap_or_default(),
            timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(),
            extra_headers: value.extra_headers.unwrap_or_default(),
        }
    }
}
@@ -7,8 +7,7 @@ use arrow::pyarrow::FromPyArrow;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{
    ExecutableQuery, HasQuery, Query as LanceDbQuery, QueryBase, Select,
    VectorQuery as LanceDbVectorQuery,
    ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
@@ -313,7 +312,8 @@ impl VectorQuery {
    }

    pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> {
        let fts_query = Query::new(self.inner.mut_query().clone()).nearest_to_text(query)?;
        let base_query = self.inner.clone().into_plain();
        let fts_query = Query::new(base_query).nearest_to_text(query)?;
        Ok(HybridQuery {
            inner_vec: self.clone(),
            inner_fts: fts_query,
@@ -411,10 +411,14 @@ impl HybridQuery {
    }

    pub fn get_limit(&mut self) -> Option<u32> {
        self.inner_fts.inner.limit.map(|i| i as u32)
        self.inner_fts
            .inner
            .current_request()
            .limit
            .map(|i| i as u32)
    }

    pub fn get_with_row_id(&mut self) -> bool {
        self.inner_fts.inner.with_row_id
        self.inner_fts.inner.current_request().with_row_id
    }
}
@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.15.1-beta.3"
version = "0.16.1-beta.3"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true
@@ -169,5 +169,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
    cx.export_function("tableAddColumns", JsTable::js_add_columns)?;
    cx.export_function("tableAlterColumns", JsTable::js_alter_columns)?;
    cx.export_function("tableDropColumns", JsTable::js_drop_columns)?;
    cx.export_function("tableDropIndex", JsTable::js_drop_index)?;
    Ok(())
}
@@ -638,4 +638,8 @@ impl JsTable {

        Ok(promise)
    }

    pub(crate) fn js_drop_index(_cx: FunctionContext) -> JsResult<JsPromise> {
        todo!("not implemented")
    }
}
@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.15.1-beta.3"
version = "0.16.1-beta.3"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -19,7 +19,10 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
chrono = { workspace = true }
datafusion-catalog.workspace = true
datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
@@ -33,7 +36,7 @@ lance-table = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
lance-encoding = { workspace = true }
moka = { workspace = true}
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log.workspace = true
@@ -82,7 +85,8 @@ aws-sdk-s3 = { version = "1.38.0" }
aws-sdk-kms = { version = "1.37" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
http-body = "1" # Matching reqwest
datafusion.workspace = true
http-body = "1"  # Matching reqwest


[features]
@@ -98,7 +102,7 @@ sentence-transformers = [
    "dep:candle-core",
    "dep:candle-transformers",
    "dep:candle-nn",
    "dep:tokenizers"
    "dep:tokenizers",
]

# TLS
File diff suppressed because it is too large
133 rust/lancedb/src/database.rs Normal file
@@ -0,0 +1,133 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! The database module defines the `Database` trait and related types.
//!
//! A "database" is a generic concept for something that manages tables and their metadata.
//!
//! We provide a basic implementation of a database that requires no additional infrastructure
//! and is based off listing directories in a filesystem.
//!
//! Users may want to provide their own implementations for a variety of reasons:
//! * Tables may be arranged in a different order on the S3 filesystem
//! * Tables may be managed by some kind of independent application (e.g. some database)
//! * Tables may be managed by a database system (e.g. Postgres)
//! * A custom table implementation (e.g. remote table, etc.) may be used

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::RecordBatchReader;
use lance::dataset::ReadParams;

use crate::error::Result;
use crate::table::{BaseTable, TableDefinition, WriteOptions};

pub mod listing;

pub trait DatabaseOptions {
    fn serialize_into_map(&self, map: &mut HashMap<String, String>);
}

/// A request to list names of tables in the database
#[derive(Clone, Debug, Default)]
pub struct TableNamesRequest {
    /// If present, only return names that come lexicographically after the supplied
    /// value.
    ///
    /// This can be combined with limit to implement pagination by setting this to
    /// the last table name from the previous page.
    pub start_after: Option<String>,
    /// The maximum number of table names to return
    pub limit: Option<u32>,
}

/// A request to open a table
#[derive(Clone, Debug)]
pub struct OpenTableRequest {
    pub name: String,
    pub index_cache_size: Option<u32>,
    pub lance_read_params: Option<ReadParams>,
}

pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableRequest) -> OpenTableRequest + Send>;

/// Describes what happens when creating a table and a table with
/// the same name already exists
pub enum CreateTableMode {
    /// If the table already exists, an error is returned
    Create,
    /// If the table already exists, it is opened. Any provided data is
    /// ignored. The function will be passed an OpenTableBuilder to customize
    /// how the table is opened
    ExistOk(TableBuilderCallback),
    /// If the table already exists, it is overwritten
    Overwrite,
}

impl CreateTableMode {
    pub fn exist_ok(
        callback: impl FnOnce(OpenTableRequest) -> OpenTableRequest + Send + 'static,
    ) -> Self {
        Self::ExistOk(Box::new(callback))
    }
}

impl Default for CreateTableMode {
    fn default() -> Self {
        Self::Create
    }
}

/// The data to start a table or a schema to create an empty table
pub enum CreateTableData {
    /// Creates a table using data, no schema required as it will be obtained from the data
    Data(Box<dyn RecordBatchReader + Send>),
    /// Creates an empty table, the definition / schema must be provided separately
    Empty(TableDefinition),
}

/// A request to create a table
pub struct CreateTableRequest {
    /// The name of the new table
    pub name: String,
    /// Initial data to write to the table, can be None to create an empty table
    pub data: CreateTableData,
    /// The mode to use when creating the table
    pub mode: CreateTableMode,
    /// Options to use when writing data (only used if `data` is not None)
    pub write_options: WriteOptions,
}

impl CreateTableRequest {
    pub fn new(name: String, data: CreateTableData) -> Self {
        Self {
            name,
            data,
            mode: CreateTableMode::default(),
            write_options: WriteOptions::default(),
        }
    }
}

/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
#[async_trait::async_trait]
pub trait Database:
    Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
{
    /// List the names of tables in the database
    async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>>;
    /// Create a table in the database
    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>>;
    /// Open a table in the database
    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
    /// Rename a table in the database
    async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
    /// Drop a table in the database
    async fn drop_table(&self, name: &str) -> Result<()>;
    /// Drop all tables in the database
    async fn drop_all_tables(&self) -> Result<()>;

    fn as_any(&self) -> &dyn std::any::Any;
}
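The `TableNamesRequest` doc comment above describes pagination via `start_after` plus `limit`. As an illustration only (not part of the diff): a minimal sketch of driving that pagination loop against any `Database` implementation, assuming the crate exports these types at `lancedb::database` and `lancedb::error` as the new module declares; the helper name `all_table_names` is hypothetical.

use std::sync::Arc;

use lancedb::database::{Database, TableNamesRequest};
use lancedb::error::Result;

/// Collects every table name by fetching pages of `page_size` names,
/// restarting each request after the last name of the previous page,
/// exactly as the `start_after` doc comment suggests.
async fn all_table_names(db: Arc<dyn Database>, page_size: u32) -> Result<Vec<String>> {
    let mut names = Vec::new();
    let mut start_after: Option<String> = None;
    loop {
        let page = db
            .table_names(TableNamesRequest {
                start_after: start_after.clone(),
                limit: Some(page_size),
            })
            .await?;
        // A short page means there is nothing left to fetch.
        let done = (page.len() as u32) < page_size;
        start_after = page.last().cloned();
        names.extend(page);
        if done {
            break;
        }
    }
    Ok(names)
}

Because names are returned in lexicographic order, the loop never revisits a table even if tables are created concurrently behind the cursor.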
542
rust/lancedb/src/database/listing.rs
Normal file
542
rust/lancedb/src/database/listing.rs
Normal file
@@ -0,0 +1,542 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
//! Provides the `ListingDatabase`, a simple database where tables are folders in a directory
|
||||||
|
|
||||||
|
use std::fs::create_dir_all;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use arrow_array::RecordBatchIterator;
|
||||||
|
use lance::dataset::{ReadParams, WriteMode};
|
||||||
|
use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry, WrappingObjectStore};
|
||||||
|
use lance_encoding::version::LanceFileVersion;
|
||||||
|
use lance_table::io::commit::commit_handler_from_url;
|
||||||
|
use object_store::local::LocalFileSystem;
|
||||||
|
use snafu::{OptionExt, ResultExt};
|
||||||
|
|
||||||
|
use crate::connection::ConnectRequest;
|
||||||
|
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||||
|
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||||
|
use crate::table::NativeTable;
|
||||||
|
use crate::utils::validate_table_name;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
|
||||||
|
OpenTableRequest, TableNamesRequest,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// File extension to indicate a lance table
|
||||||
|
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||||
|
|
||||||
|
pub const OPT_NEW_TABLE_STORAGE_VERSION: &str = "new_table_data_storage_version";
|
||||||
|
pub const OPT_NEW_TABLE_V2_MANIFEST_PATHS: &str = "new_table_enable_v2_manifest_paths";
|
||||||
|
|
||||||
|
/// Controls how new tables should be created
|
||||||
|
#[derive(Clone, Debug, Default)]
|
||||||
|
pub struct NewTableConfig {
|
||||||
|
/// The storage version to use for new tables
|
||||||
|
///
|
||||||
|
/// If unset, then the latest stable version will be used
|
||||||
|
pub data_storage_version: Option<LanceFileVersion>,
|
||||||
|
/// Whether to enable V2 manifest paths for new tables
|
||||||
|
///
|
||||||
|
/// V2 manifest paths are more efficient than V2 manifest paths but are not
|
||||||
|
/// supported by old clients.
|
||||||
|
pub enable_v2_manifest_paths: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Options specific to the listing database
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
pub struct ListingDatabaseOptions {
|
||||||
|
/// Controls what kind of Lance tables will be created by this database
|
||||||
|
pub new_table_config: NewTableConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ListingDatabaseOptions {
|
||||||
|
fn parse_from_map(map: &HashMap<String, String>) -> Result<Self> {
|
||||||
|
let new_table_config = NewTableConfig {
|
||||||
|
data_storage_version: map
|
||||||
|
.get(OPT_NEW_TABLE_STORAGE_VERSION)
|
||||||
|
.map(|s| s.parse())
|
||||||
|
.transpose()?,
|
||||||
|
enable_v2_manifest_paths: map
|
||||||
|
.get(OPT_NEW_TABLE_V2_MANIFEST_PATHS)
|
||||||
|
.map(|s| {
|
||||||
|
s.parse::<bool>().map_err(|_| Error::InvalidInput {
|
||||||
|
message: format!(
|
||||||
|
"enable_v2_manifest_paths must be a boolean, received {}",
|
||||||
|
s
|
||||||
|
),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.transpose()?,
|
||||||
|
};
|
||||||
|
Ok(Self { new_table_config })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseOptions for ListingDatabaseOptions {
|
||||||
|
fn serialize_into_map(&self, map: &mut HashMap<String, String>) {
|
||||||
|
if let Some(storage_version) = &self.new_table_config.data_storage_version {
|
||||||
|
map.insert(
|
||||||
|
OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
|
||||||
|
storage_version.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Some(enable_v2_manifest_paths) = self.new_table_config.enable_v2_manifest_paths {
|
||||||
|
map.insert(
|
||||||
|
OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(),
|
||||||
|
enable_v2_manifest_paths.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A database that stores tables in a flat directory structure
|
||||||
|
///
|
||||||
|
/// Tables are stored as directories in the base path of the object store.
|
||||||
|
///
|
||||||
|
/// It is called a "listing database" because we use a "list directory" operation
|
||||||
|
/// to discover what tables are available. Table names are determined from the directory
|
||||||
|
/// names.
|
||||||
|
///
|
||||||
|
/// For example, given the following directory structure:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// /data
|
||||||
|
/// /table1.lance
|
||||||
|
/// /table2.lance
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// We will have two tables named `table1` and `table2`.
#[derive(Debug)]
pub struct ListingDatabase {
    object_store: ObjectStore,
    query_string: Option<String>,

    pub(crate) uri: String,
    pub(crate) base_path: object_store::path::Path,

    // the object store wrapper to use on write path
    pub(crate) store_wrapper: Option<Arc<dyn WrappingObjectStore>>,

    read_consistency_interval: Option<std::time::Duration>,

    // Storage options to be inherited by tables created from this connection
    storage_options: HashMap<String, String>,

    // Options for tables created by this connection
    new_table_config: NewTableConfig,
}

impl std::fmt::Display for ListingDatabase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ListingDatabase(uri={}, read_consistency_interval={})",
            self.uri,
            match self.read_consistency_interval {
                None => {
                    "None".to_string()
                }
                Some(duration) => {
                    format!("{}s", duration.as_secs_f64())
                }
            }
        )
    }
}

const LANCE_EXTENSION: &str = "lance";
const ENGINE: &str = "engine";
const MIRRORED_STORE: &str = "mirroredStore";

/// A connection to LanceDB
impl ListingDatabase {
    /// Connect to a listing database
    ///
    /// The URI should be a path to a directory where the tables are stored.
    ///
    /// See [`ListingDatabaseOptions`] for options that can be set on the connection (via
    /// `storage_options`).
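    ///
    /// The `engine` and `mirroredStore` query parameters are consumed here rather
    /// than forwarded to lance; `engine` is folded into the URI scheme. An
    /// illustrative (not officially supported) sketch of that rewriting:
    ///
    /// ```ignore
    /// let mut url = url::Url::parse("s3://bucket/db?engine=dynamodb")?;
    /// url.set_query(None);
    /// let rewritten = url.to_string().replacen("s3", "s3+dynamodb", 1);
    /// assert_eq!(rewritten, "s3+dynamodb://bucket/db");
    /// ```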
    pub async fn connect_with_options(request: &ConnectRequest) -> Result<Self> {
        let uri = &request.uri;
        let parse_res = url::Url::parse(uri);

        let options = ListingDatabaseOptions::parse_from_map(&request.storage_options)?;

        // TODO: pass params regardless of OS
        match parse_res {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                Self::open_path(
                    uri,
                    request.read_consistency_interval,
                    options.new_table_config,
                )
                .await
            }
            Ok(mut url) => {
                // iter thru the query params and extract the commit store param
                let mut engine = None;
                let mut mirrored_store = None;
                let mut filtered_querys = vec![];

                // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
                // THE API WILL CHANGE
                for (key, value) in url.query_pairs() {
                    if key == ENGINE {
                        engine = Some(value.to_string());
                    } else if key == MIRRORED_STORE {
                        if cfg!(windows) {
                            return Err(Error::NotSupported {
                                message: "mirrored store is not supported on windows".into(),
                            });
                        }
                        mirrored_store = Some(value.to_string());
                    } else {
                        // to owned so we can modify the url
                        filtered_querys.push((key.to_string(), value.to_string()));
                    }
                }

                // Filter out the commit store query param -- it's a lancedb param
                url.query_pairs_mut().clear();
                url.query_pairs_mut().extend_pairs(filtered_querys);
                // Take a copy of the query string so we can propagate it to lance
                let query_string = url.query().map(|s| s.to_string());
                // clear the query string so we can use the url as the base uri
                // use .set_query(None) instead of .set_query("") because the latter
                // will add a trailing '?' to the url
                url.set_query(None);

                let table_base_uri = if let Some(store) = engine {
                    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
                    WARN_ONCE.call_once(|| {
                        log::warn!("Specifying engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
                    });
                    let old_scheme = url.scheme().to_string();
                    let new_scheme = format!("{}+{}", old_scheme, store);
                    url.to_string().replacen(&old_scheme, &new_scheme, 1)
                } else {
                    url.to_string()
                };

                let plain_uri = url.to_string();

                let registry = Arc::new(ObjectStoreRegistry::default());
                let storage_options = request.storage_options.clone();
                let os_params = ObjectStoreParams {
                    storage_options: Some(storage_options.clone()),
                    ..Default::default()
                };
                let (object_store, base_path) =
                    ObjectStore::from_uri_and_params(registry, &plain_uri, &os_params).await?;
                if object_store.is_local() {
                    Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
                }

                let write_store_wrapper = match mirrored_store {
                    Some(path) => {
                        let mirrored_store = Arc::new(LocalFileSystem::new_with_prefix(path)?);
                        let wrapper = MirroringObjectStoreWrapper::new(mirrored_store);
                        Some(Arc::new(wrapper) as Arc<dyn WrappingObjectStore>)
                    }
                    None => None,
                };

                Ok(Self {
                    uri: table_base_uri,
                    query_string,
                    base_path,
                    object_store,
                    store_wrapper: write_store_wrapper,
                    read_consistency_interval: request.read_consistency_interval,
                    storage_options,
                    new_table_config: options.new_table_config,
                })
            }
            Err(_) => {
                Self::open_path(
                    uri,
                    request.read_consistency_interval,
                    options.new_table_config,
                )
                .await
            }
        }
    }

    async fn open_path(
        path: &str,
        read_consistency_interval: Option<std::time::Duration>,
        new_table_config: NewTableConfig,
    ) -> Result<Self> {
        let (object_store, base_path) = ObjectStore::from_uri(path).await?;
        if object_store.is_local() {
            Self::try_create_dir(path).context(CreateDirSnafu { path })?;
        }

        Ok(Self {
            uri: path.to_string(),
            query_string: None,
            base_path,
            object_store,
            store_wrapper: None,
            read_consistency_interval,
            storage_options: HashMap::new(),
            new_table_config,
        })
    }

    /// Try to create a local directory to store the lancedb dataset
    fn try_create_dir(path: &str) -> core::result::Result<(), std::io::Error> {
        let path = Path::new(path);
        if !path.try_exists()? {
            create_dir_all(path)?;
        }
        Ok(())
    }

    /// Get the URI of a table in the database.
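    ///
    /// The result is `<connection uri>/<name>.lance`, with the connection's
    /// query string (if any) re-appended. A sketch for a database opened at
    /// `/data` (hypothetical table name):
    ///
    /// ```ignore
    /// assert_eq!(db.table_uri("t1")?, "/data/t1.lance");
    /// ```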
    fn table_uri(&self, name: &str) -> Result<String> {
        validate_table_name(name)?;

        let path = Path::new(&self.uri);
        let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));

        let mut uri = table_uri
            .as_path()
            .to_str()
            .context(InvalidTableNameSnafu {
                name,
                reason: "Name is not valid URL",
            })?
            .to_string();

        // If there is a query string set on the connection, propagate it to lance
        if let Some(query) = self.query_string.as_ref() {
            uri.push('?');
            uri.push_str(query.as_str());
        }

        Ok(uri)
    }
}

#[async_trait::async_trait]
impl Database for ListingDatabase {
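    /// List tables by scanning the base directory for `.lance` entries.
    ///
    /// Names are returned sorted; `start_after` is exclusive and `limit`
    /// truncates the tail. A paging sketch (hypothetical request construction,
    /// assuming tables `a`, `b`, and `c` exist):
    ///
    /// ```ignore
    /// let request = TableNamesRequest { start_after: Some("a".into()), limit: Some(1) };
    /// assert_eq!(db.table_names(request).await?, vec!["b".to_string()]);
    /// ```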
    async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
        let mut f = self
            .object_store
            .read_dir(self.base_path.clone())
            .await?
            .iter()
            .map(Path::new)
            .filter(|path| {
                let is_lance = path
                    .extension()
                    .and_then(|e| e.to_str())
                    .map(|e| e == LANCE_EXTENSION);
                is_lance.unwrap_or(false)
            })
            .filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from)))
            .collect::<Vec<String>>();
        f.sort();
        if let Some(start_after) = request.start_after {
            let index = f
                .iter()
                .position(|name| name.as_str() > start_after.as_str())
                .unwrap_or(f.len());
            f.drain(0..index);
        }
        if let Some(limit) = request.limit {
            f.truncate(limit as usize);
        }
        Ok(f)
    }

    async fn create_table(&self, mut request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
        let table_uri = self.table_uri(&request.name)?;
        // Inherit storage options from the connection
        let storage_options = request
            .write_options
            .lance_write_params
            .get_or_insert_with(Default::default)
            .store_params
            .get_or_insert_with(Default::default)
            .storage_options
            .get_or_insert_with(Default::default);
        for (key, value) in self.storage_options.iter() {
            if !storage_options.contains_key(key) {
                storage_options.insert(key.clone(), value.clone());
            }
        }

        let storage_options = storage_options.clone();

        let mut write_params = request.write_options.lance_write_params.unwrap_or_default();

        if let Some(storage_version) = &self.new_table_config.data_storage_version {
            write_params.data_storage_version = Some(*storage_version);
        } else {
            // Allow the user to override the storage version via storage options (backwards compatibility)
            if let Some(data_storage_version) = storage_options.get(OPT_NEW_TABLE_STORAGE_VERSION) {
                write_params.data_storage_version = Some(data_storage_version.parse()?);
            }
        }
        if let Some(enable_v2_manifest_paths) = self.new_table_config.enable_v2_manifest_paths {
            write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
        } else {
            // Allow the user to override the manifest path setting via storage options (backwards compatibility)
            if let Some(enable_v2_manifest_paths) = storage_options
                .get(OPT_NEW_TABLE_V2_MANIFEST_PATHS)
                .map(|s| s.parse::<bool>().unwrap())
            {
                write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
            }
        }

        if matches!(&request.mode, CreateTableMode::Overwrite) {
            write_params.mode = WriteMode::Overwrite;
        }

        let data = match request.data {
            CreateTableData::Data(data) => data,
            CreateTableData::Empty(table_definition) => {
                let schema = table_definition.schema.clone();
                Box::new(RecordBatchIterator::new(vec![], schema))
            }
        };
        let data_schema = data.schema();

        match NativeTable::create(
            &table_uri,
            &request.name,
            data,
            self.store_wrapper.clone(),
            Some(write_params),
            self.read_consistency_interval,
        )
        .await
        {
            Ok(table) => Ok(Arc::new(table)),
            Err(Error::TableAlreadyExists { name }) => match request.mode {
                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
                CreateTableMode::ExistOk(callback) => {
                    let req = OpenTableRequest {
                        name: request.name.clone(),
                        index_cache_size: None,
                        lance_read_params: None,
                    };
                    let req = (callback)(req);
                    let table = self.open_table(req).await?;

                    let table_schema = table.schema().await?;

                    if table_schema != data_schema {
                        return Err(Error::Schema {
                            message: "Provided schema does not match existing table schema"
                                .to_string(),
                        });
                    }

                    Ok(table)
                }
                CreateTableMode::Overwrite => unreachable!(),
            },
            Err(err) => Err(err),
        }
    }
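
    // Mode semantics for `create_table` above: `Create` surfaces
    // `TableAlreadyExists`, `Overwrite` switches the write mode before the
    // attempt, and `ExistOk` falls back to opening the existing table (after
    // the callback adjusts the open request) and fails if the schemas differ.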
    async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
        let table_uri = self.table_uri(&request.name)?;

        // Inherit storage options from the connection
        let storage_options = request
            .lance_read_params
            .get_or_insert_with(Default::default)
            .store_options
            .get_or_insert_with(Default::default)
            .storage_options
            .get_or_insert_with(Default::default);
        for (key, value) in self.storage_options.iter() {
            if !storage_options.contains_key(key) {
                storage_options.insert(key.clone(), value.clone());
            }
        }

        // Some ReadParams are exposed in the OpenTableBuilder, but we also
        // let the user provide their own ReadParams.
        //
        // If we have a user-provided ReadParams, use that.
        // If we don't, then start with the default ReadParams and customize it
        // with the options from the OpenTableBuilder.
        let read_params = request.lance_read_params.unwrap_or_else(|| {
            let mut default_params = ReadParams::default();
            if let Some(index_cache_size) = request.index_cache_size {
                default_params.index_cache_size = index_cache_size as usize;
            }
            default_params
        });

        let native_table = Arc::new(
            NativeTable::open_with_params(
                &table_uri,
                &request.name,
                self.store_wrapper.clone(),
                Some(read_params),
                self.read_consistency_interval,
            )
            .await?,
        );
        Ok(native_table)
    }

    async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {
        Err(Error::NotSupported {
            message: "rename_table is not supported in LanceDB OSS".to_string(),
        })
    }

    async fn drop_table(&self, name: &str) -> Result<()> {
        let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
        let full_path = self.base_path.child(dir_name.clone());
        self.object_store
            .remove_dir_all(full_path.clone())
            .await
            .map_err(|err| match err {
                // this error is not lance::Error::DatasetNotFound,
                // as the method `remove_dir_all` may be used to remove something that is not a dataset
                lance::Error::NotFound { .. } => Error::TableNotFound {
                    name: name.to_owned(),
                },
                _ => Error::from(err),
            })?;

        let object_store_params = ObjectStoreParams {
            storage_options: Some(self.storage_options.clone()),
            ..Default::default()
        };
        let mut uri = self.uri.clone();
        if let Some(query_string) = &self.query_string {
            uri.push_str(&format!("?{}", query_string));
        }
        let commit_handler = commit_handler_from_url(&uri, &Some(object_store_params))
            .await
            .unwrap();
        commit_handler.delete(&full_path).await.unwrap();
        Ok(())
    }

    async fn drop_all_tables(&self) -> Result<()> {
        self.object_store
            .remove_dir_all(self.base_path.clone())
            .await?;
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
@@ -15,6 +15,8 @@ pub enum Error {
     InvalidInput { message: String },
     #[snafu(display("Table '{name}' was not found"))]
     TableNotFound { name: String },
+    #[snafu(display("Index '{name}' was not found"))]
+    IndexNotFound { name: String },
     #[snafu(display("Embedding function '{name}' was not found. : {reason}"))]
     EmbeddingFunctionNotFound { name: String, reason: String },
@@ -8,7 +8,7 @@ use serde::Deserialize;
 use serde_with::skip_serializing_none;
 use vector::IvfFlatIndexBuilder;

-use crate::{table::TableInternal, DistanceType, Error, Result};
+use crate::{table::BaseTable, DistanceType, Error, Result};

 use self::{
     scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
@@ -65,14 +65,14 @@ pub enum Index {
 ///
 /// The methods on this builder are used to specify options common to all indices.
 pub struct IndexBuilder {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) index: Index,
     pub(crate) columns: Vec<String>,
     pub(crate) replace: bool,
 }

 impl IndexBuilder {
-    pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
+    pub(crate) fn new(parent: Arc<dyn BaseTable>, columns: Vec<String>, index: Index) -> Self {
         Self {
             parent,
             index,
Some files were not shown because too many files have changed in this diff