mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
Compare commits
22 Commits
v0.3.2
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3cf986777 | ||
|
|
c73fcc8898 | ||
|
|
cd9debc3b7 | ||
|
|
26a97ba997 | ||
|
|
ce19fedb08 | ||
|
|
14e8e48de2 | ||
|
|
c30faf6083 | ||
|
|
64a4f025bb | ||
|
|
6dc968e7d3 | ||
|
|
06b5b69f1e | ||
|
|
6bd3a838fc | ||
|
|
f36fea8f20 | ||
|
|
0a30591729 | ||
|
|
0ed39b6146 | ||
|
|
a8c7f80073 | ||
|
|
0293bbe142 | ||
|
|
7372656369 | ||
|
|
d46bc5dd6e | ||
|
|
86efb11572 | ||
|
|
bb01ad5290 | ||
|
|
1b8cda0941 | ||
|
|
bc85a749a3 |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.3.2
|
current_version = 0.3.3
|
||||||
commit = True
|
commit = True
|
||||||
message = Bump version: {current_version} → {new_version}
|
message = Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ exclude = ["python"]
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.8.5", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.8.7", "features" = ["dynamodb"] }
|
||||||
lance-linalg = { "version" = "=0.8.5" }
|
lance-linalg = { "version" = "=0.8.7" }
|
||||||
lance-testing = { "version" = "=0.8.5" }
|
lance-testing = { "version" = "=0.8.7" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "47.0.0", optional = false }
|
arrow = { version = "47.0.0", optional = false }
|
||||||
arrow-array = "47.0"
|
arrow-array = "47.0"
|
||||||
@@ -18,7 +18,7 @@ arrow-schema = "47.0"
|
|||||||
arrow-arith = "47.0"
|
arrow-arith = "47.0"
|
||||||
arrow-cast = "47.0"
|
arrow-cast = "47.0"
|
||||||
chrono = "0.4.23"
|
chrono = "0.4.23"
|
||||||
half = { "version" = "=2.2.1", default-features = false, features = [
|
half = { "version" = "=2.3.1", default-features = false, features = [
|
||||||
"num-traits"
|
"num-traits"
|
||||||
] }
|
] }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
|||||||
26
docs/README.md
Normal file
26
docs/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# LanceDB Documentation
|
||||||
|
|
||||||
|
LanceDB docs are deployed to https://lancedb.github.io/lancedb/.
|
||||||
|
|
||||||
|
Docs is built and deployed automatically by [Github Actions](.github/workflows/docs.yml)
|
||||||
|
whenever a commit is pushed to the `main` branch. So it is possible for the docs to show
|
||||||
|
unreleased features.
|
||||||
|
|
||||||
|
## Building the docs
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
1. Install LanceDB. From LanceDB repo root: `pip install -e python`
|
||||||
|
2. Install dependencies. From LanceDB repo root: `pip install -r docs/requirements.txt`
|
||||||
|
3. Make sure you have node and npm setup
|
||||||
|
4. Make sure protobuf and libssl are installed
|
||||||
|
|
||||||
|
### Building node module and create markdown files
|
||||||
|
|
||||||
|
See [Javascript docs README](docs/src/javascript/README.md)
|
||||||
|
|
||||||
|
### Build docs
|
||||||
|
From LanceDB repo root:
|
||||||
|
|
||||||
|
Run: `PYTHONPATH=. mkdocs build -f docs/mkdocs.yml`
|
||||||
|
|
||||||
|
If successful, you should see a `docs/site` directory that you can verify locally.
|
||||||
@@ -73,12 +73,14 @@ nav:
|
|||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- SQL filters: sql.md
|
- SQL filters: sql.md
|
||||||
- Indexing: ann_indexes.md
|
- Indexing: ann_indexes.md
|
||||||
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
- 🧬 Embeddings:
|
- 🧬 Embeddings:
|
||||||
- embeddings/index.md
|
- embeddings/index.md
|
||||||
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||||
- Available Functions: embeddings/default_embedding_functions.md
|
- Available Functions: embeddings/default_embedding_functions.md
|
||||||
- Create Custom Embedding Functions: embeddings/api.md
|
- Create Custom Embedding Functions: embeddings/api.md
|
||||||
- Example- MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||||
|
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||||
- 🔍 Python full-text search: fts.md
|
- 🔍 Python full-text search: fts.md
|
||||||
- 🔌 Integrations:
|
- 🔌 Integrations:
|
||||||
- integrations/index.md
|
- integrations/index.md
|
||||||
@@ -110,12 +112,14 @@ nav:
|
|||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- SQL filters: sql.md
|
- SQL filters: sql.md
|
||||||
- Indexing: ann_indexes.md
|
- Indexing: ann_indexes.md
|
||||||
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
- Embeddings:
|
- Embeddings:
|
||||||
- embeddings/index.md
|
- embeddings/index.md
|
||||||
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||||
- Available Functions: embeddings/default_embedding_functions.md
|
- Available Functions: embeddings/default_embedding_functions.md
|
||||||
- Create Custom Embedding Functions: embeddings/api.md
|
- Create Custom Embedding Functions: embeddings/api.md
|
||||||
- Example- MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||||
|
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||||
- Python full-text search: fts.md
|
- Python full-text search: fts.md
|
||||||
- Integrations:
|
- Integrations:
|
||||||
- integrations/index.md
|
- integrations/index.md
|
||||||
@@ -146,6 +150,8 @@ nav:
|
|||||||
|
|
||||||
extra_css:
|
extra_css:
|
||||||
- styles/global.css
|
- styles/global.css
|
||||||
|
extra_javascript:
|
||||||
|
- scripts/posthog.js
|
||||||
|
|
||||||
extra:
|
extra:
|
||||||
analytics:
|
analytics:
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
|
|||||||
|
|
||||||
## Documentation Quick Links
|
## Documentation Quick Links
|
||||||
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
|
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
|
||||||
* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
|
* [`Embedding Functions`](embeddings/index.md) - functions for working with embeddings.
|
||||||
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
|
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
|
||||||
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
|
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
|
||||||
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
|
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
|
||||||
|
|||||||
@@ -1,5 +1,13 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "88c1af18",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Example - MultiModal CLIP Embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "c6b5d346-2c2a-4341-a132-00e53543f8d1",
|
"id": "c6b5d346-2c2a-4341-a132-00e53543f8d1",
|
||||||
|
|||||||
604
docs/src/notebooks/multi_lingual_example.ipynb
Normal file
604
docs/src/notebooks/multi_lingual_example.ipynb
Normal file
File diff suppressed because one or more lines are too long
1189
docs/src/notebooks/reproducibility.ipynb
Normal file
1189
docs/src/notebooks/reproducibility.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -26,17 +26,17 @@ pip install lancedb
|
|||||||
|
|
||||||
## Embeddings
|
## Embeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.EmbeddingFunction
|
::: lancedb.embeddings.base.EmbeddingFunction
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.TextEmbeddingFunction
|
::: lancedb.embeddings.base.TextEmbeddingFunction
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
|
::: lancedb.embeddings.sentence_transformers.SentenceTransformerEmbeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.OpenAIEmbeddings
|
::: lancedb.embeddings.openai.OpenAIEmbeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.OpenClipEmbeddings
|
::: lancedb.embeddings.open_clip.OpenClipEmbeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.with_embeddings
|
::: lancedb.embeddings.with_embeddings
|
||||||
|
|
||||||
|
|||||||
4
docs/src/scripts/posthog.js
Normal file
4
docs/src/scripts/posthog.js
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
window.addEventListener("DOMContentLoaded", (event) => {
|
||||||
|
!function(t,e){var o,n,p,r;e.__SV||(window.posthog=e,e._i=[],e.init=function(i,s,a){function g(t,e){var o=e.split(".");2==o.length&&(t=t[o[0]],e=o[1]),t[e]=function(){t.push([e].concat(Array.prototype.slice.call(arguments,0)))}}(p=t.createElement("script")).type="text/javascript",p.async=!0,p.src=s.api_host+"/static/array.js",(r=t.getElementsByTagName("script")[0]).parentNode.insertBefore(p,r);var u=e;for(void 0!==a?u=e[a]=[]:a="posthog",u.people=u.people||[],u.toString=function(t){var e="posthog";return"posthog"!==a&&(e+="."+a),t||(e+=" (stub)"),e},u.people.toString=function(){return u.toString(1)+".people (stub)"},o="capture identify alias people.set people.set_once set_config register register_once unregister opt_out_capturing has_opted_out_capturing opt_in_capturing reset isFeatureEnabled onFeatureFlags getFeatureFlag getFeatureFlagPayload reloadFeatureFlags group updateEarlyAccessFeatureEnrollment getEarlyAccessFeatures getActiveMatchingSurveys getSurveys".split(" "),n=0;n<o.length;n++)g(u,o[n]);e._i.push([i,s,a])},e.__SV=1)}(document,window.posthog||[]);
|
||||||
|
posthog.init('phc_oENDjGgHtmIDrV6puUiFem2RB4JA8gGWulfdulmMdZP',{api_host:'https://app.posthog.com'})
|
||||||
|
});
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
In a recommendation system or search engine, you can find similar products from
|
In a recommendation system or search engine, you can find similar products from
|
||||||
the one you searched.
|
the one you searched.
|
||||||
In LLM and other AI applications,
|
In LLM and other AI applications,
|
||||||
each data point can be [presented by the embeddings generated from some models](embedding.md),
|
each data point can be [presented by the embeddings generated from some models](embeddings/index.md),
|
||||||
it returns the most relevant features.
|
it returns the most relevant features.
|
||||||
|
|
||||||
A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.
|
A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.
|
||||||
|
|||||||
74
node/package-lock.json
generated
74
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"lockfileVersion": 2,
|
"lockfileVersion": 2,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
@@ -53,11 +53,11 @@
|
|||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.3.1",
|
"@lancedb/vectordb-darwin-arm64": "0.3.3",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.3.1",
|
"@lancedb/vectordb-darwin-x64": "0.3.3",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.1",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.3.3",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.1",
|
"@lancedb/vectordb-linux-x64-gnu": "0.3.3",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.1"
|
"@lancedb/vectordb-win32-x64-msvc": "0.3.3"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@apache-arrow/ts": {
|
"node_modules/@apache-arrow/ts": {
|
||||||
@@ -317,9 +317,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.3.tgz",
|
||||||
"integrity": "sha512-h3yUP249xaO3rrRuVC4oRxEm5/9T66CGKiI8OwYCJUOEFrfz/jj+6PK8geMn7IqbPnOY9YRPSEi/Cc3EdFd6Sg==",
|
"integrity": "sha512-nvyj7xNX2/wb/PH5TjyhLR/NQ1jVuoBw2B5UaSg7qf8Tnm5SSXWQ7F25RVKcKwh72fz1qB+CWW24ftZnRzbT/Q==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
@@ -329,9 +329,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.3.tgz",
|
||||||
"integrity": "sha512-SQ32iMMVfvjXgvFGSGdsXcSnVDypR6eE06d7VIXsuKAg6P9e1XUhB4YcsHGeAEEv3gEoUSgsljo92ZvXJcWouQ==",
|
"integrity": "sha512-7CW+nILyPHp6cua0Rl0xaTDWw/vajEn/jCsEjFYgDmE+rtf5Z5Fum41FxR9C2TtIAvUK+nWb5mkYeOLqU6vRvg==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
@@ -341,9 +341,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.3.tgz",
|
||||||
"integrity": "sha512-+jk2nJnaIWTqcOAyix2y+ClLNM5ECIdwyHZp5KjDqOlP6Z7eb5V2Xsah0AFp8nX3BiRRvqj3zR3zi26D7OBnYw==",
|
"integrity": "sha512-MmhwbacKxZPkLwwOqysVY8mUb8lFoyFIPlYhSLV4xS1C8X4HWALljIul1qMl1RYudp9Uc3PsOzRexl+OvCGfUw==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
@@ -353,9 +353,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.3.tgz",
|
||||||
"integrity": "sha512-I42Zf2lH8SUZLLYDDG4kzZ8iPq2wf1cXMh9iKNiLwgl5BnRsZVQ5A5k0uCX7IV7FcnHL/febKOxixXQyoKNAzw==",
|
"integrity": "sha512-OrNlsKi/QPw59Po040oRKn8IuqFEk4upc/4FaFKqVkcmQjjZrMg5Kgy9ZfWIhHdAnWXXggZZIPArpt0X1B0ceA==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
@@ -365,9 +365,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.3.tgz",
|
||||||
"integrity": "sha512-3OBS+fc4kcwhkqIy5b2Nump/iYoAgQd6gmYIJux3LJbMCc4yDcPJdFGVQkWu43JfBh7YOWPfOng2NSCUDBGmoA==",
|
"integrity": "sha512-lIT0A7a6eqX51IfGyhECtpXXgsr//kgbd+HZbcCdPy2GMmNezSch/7V22zExDSpF32hX8WfgcTLYCVWVilggDQ==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
@@ -4869,33 +4869,33 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@lancedb/vectordb-darwin-arm64": {
|
"@lancedb/vectordb-darwin-arm64": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.3.tgz",
|
||||||
"integrity": "sha512-h3yUP249xaO3rrRuVC4oRxEm5/9T66CGKiI8OwYCJUOEFrfz/jj+6PK8geMn7IqbPnOY9YRPSEi/Cc3EdFd6Sg==",
|
"integrity": "sha512-nvyj7xNX2/wb/PH5TjyhLR/NQ1jVuoBw2B5UaSg7qf8Tnm5SSXWQ7F25RVKcKwh72fz1qB+CWW24ftZnRzbT/Q==",
|
||||||
"optional": true
|
"optional": true
|
||||||
},
|
},
|
||||||
"@lancedb/vectordb-darwin-x64": {
|
"@lancedb/vectordb-darwin-x64": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.3.tgz",
|
||||||
"integrity": "sha512-SQ32iMMVfvjXgvFGSGdsXcSnVDypR6eE06d7VIXsuKAg6P9e1XUhB4YcsHGeAEEv3gEoUSgsljo92ZvXJcWouQ==",
|
"integrity": "sha512-7CW+nILyPHp6cua0Rl0xaTDWw/vajEn/jCsEjFYgDmE+rtf5Z5Fum41FxR9C2TtIAvUK+nWb5mkYeOLqU6vRvg==",
|
||||||
"optional": true
|
"optional": true
|
||||||
},
|
},
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
"@lancedb/vectordb-linux-arm64-gnu": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.3.tgz",
|
||||||
"integrity": "sha512-+jk2nJnaIWTqcOAyix2y+ClLNM5ECIdwyHZp5KjDqOlP6Z7eb5V2Xsah0AFp8nX3BiRRvqj3zR3zi26D7OBnYw==",
|
"integrity": "sha512-MmhwbacKxZPkLwwOqysVY8mUb8lFoyFIPlYhSLV4xS1C8X4HWALljIul1qMl1RYudp9Uc3PsOzRexl+OvCGfUw==",
|
||||||
"optional": true
|
"optional": true
|
||||||
},
|
},
|
||||||
"@lancedb/vectordb-linux-x64-gnu": {
|
"@lancedb/vectordb-linux-x64-gnu": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.3.tgz",
|
||||||
"integrity": "sha512-I42Zf2lH8SUZLLYDDG4kzZ8iPq2wf1cXMh9iKNiLwgl5BnRsZVQ5A5k0uCX7IV7FcnHL/febKOxixXQyoKNAzw==",
|
"integrity": "sha512-OrNlsKi/QPw59Po040oRKn8IuqFEk4upc/4FaFKqVkcmQjjZrMg5Kgy9ZfWIhHdAnWXXggZZIPArpt0X1B0ceA==",
|
||||||
"optional": true
|
"optional": true
|
||||||
},
|
},
|
||||||
"@lancedb/vectordb-win32-x64-msvc": {
|
"@lancedb/vectordb-win32-x64-msvc": {
|
||||||
"version": "0.3.1",
|
"version": "0.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.3.tgz",
|
||||||
"integrity": "sha512-3OBS+fc4kcwhkqIy5b2Nump/iYoAgQd6gmYIJux3LJbMCc4yDcPJdFGVQkWu43JfBh7YOWPfOng2NSCUDBGmoA==",
|
"integrity": "sha512-lIT0A7a6eqX51IfGyhECtpXXgsr//kgbd+HZbcCdPy2GMmNezSch/7V22zExDSpF32hX8WfgcTLYCVWVilggDQ==",
|
||||||
"optional": true
|
"optional": true
|
||||||
},
|
},
|
||||||
"@neon-rs/cli": {
|
"@neon-rs/cli": {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.3.2",
|
"version": "0.3.3",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
@@ -81,10 +81,10 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.3.2",
|
"@lancedb/vectordb-darwin-arm64": "0.3.3",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.3.2",
|
"@lancedb/vectordb-darwin-x64": "0.3.3",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.2",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.3.3",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.2",
|
"@lancedb/vectordb-linux-x64-gnu": "0.3.3",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.2"
|
"@lancedb/vectordb-win32-x64-msvc": "0.3.3"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,8 +65,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
|||||||
const mirroredPath = path.join(dir, `${tableName}.lance`)
|
const mirroredPath = path.join(dir, `${tableName}.lance`)
|
||||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||||
if (err != null) throw err
|
if (err != null) throw err
|
||||||
// there should be two dirs
|
// there should be three dirs
|
||||||
assert.equal(files.length, 2)
|
assert.equal(files.length, 3)
|
||||||
assert.isTrue(files[0].isDirectory())
|
assert.isTrue(files[0].isDirectory())
|
||||||
assert.isTrue(files[1].isDirectory())
|
assert.isTrue(files[1].isDirectory())
|
||||||
|
|
||||||
@@ -76,6 +76,12 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
|||||||
assert.isTrue(files[0].name.endsWith('.txn'))
|
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_versions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.manifest'))
|
||||||
|
})
|
||||||
|
|
||||||
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||||
if (err != null) throw err
|
if (err != null) throw err
|
||||||
assert.equal(files.length, 1)
|
assert.equal(files.length, 1)
|
||||||
@@ -88,8 +94,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
|||||||
|
|
||||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||||
if (err != null) throw err
|
if (err != null) throw err
|
||||||
// there should be two dirs
|
// there should be four dirs
|
||||||
assert.equal(files.length, 3)
|
assert.equal(files.length, 4)
|
||||||
assert.isTrue(files[0].isDirectory())
|
assert.isTrue(files[0].isDirectory())
|
||||||
assert.isTrue(files[1].isDirectory())
|
assert.isTrue(files[1].isDirectory())
|
||||||
assert.isTrue(files[2].isDirectory())
|
assert.isTrue(files[2].isDirectory())
|
||||||
@@ -128,12 +134,13 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
|||||||
|
|
||||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||||
if (err != null) throw err
|
if (err != null) throw err
|
||||||
// there should be two dirs
|
// there should be five dirs
|
||||||
assert.equal(files.length, 4)
|
assert.equal(files.length, 5)
|
||||||
assert.isTrue(files[0].isDirectory())
|
assert.isTrue(files[0].isDirectory())
|
||||||
assert.isTrue(files[1].isDirectory())
|
assert.isTrue(files[1].isDirectory())
|
||||||
assert.isTrue(files[2].isDirectory())
|
assert.isTrue(files[2].isDirectory())
|
||||||
assert.isTrue(files[3].isDirectory())
|
assert.isTrue(files[3].isDirectory())
|
||||||
|
assert.isTrue(files[4].isDirectory())
|
||||||
|
|
||||||
// Three TXs now
|
// Three TXs now
|
||||||
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.3.1
|
current_version = 0.3.2
|
||||||
commit = True
|
commit = True
|
||||||
message = [python] Bump version: {current_version} → {new_version}
|
message = [python] Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
@@ -11,16 +11,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||||
from .cohere import CohereEmbeddingFunction
|
from .cohere import CohereEmbeddingFunction
|
||||||
from .functions import (
|
from .open_clip import OpenClipEmbeddings
|
||||||
EmbeddingFunction,
|
from .openai import OpenAIEmbeddings
|
||||||
EmbeddingFunctionConfig,
|
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||||
EmbeddingFunctionRegistry,
|
from .sentence_transformers import SentenceTransformerEmbeddings
|
||||||
OpenAIEmbeddings,
|
|
||||||
OpenClipEmbeddings,
|
|
||||||
SentenceTransformerEmbeddings,
|
|
||||||
TextEmbeddingFunction,
|
|
||||||
register,
|
|
||||||
)
|
|
||||||
from .utils import with_embeddings
|
from .utils import with_embeddings
|
||||||
|
|||||||
138
python/lancedb/embeddings/base.py
Normal file
138
python/lancedb/embeddings/base.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import importlib
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
|
from .utils import TEXT
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunction(BaseModel, ABC):
|
||||||
|
"""
|
||||||
|
An ABC for embedding functions.
|
||||||
|
|
||||||
|
All concrete embedding functions must implement the following:
|
||||||
|
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||||
|
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||||
|
For text data, the two will be the same. For multi-modal data, the source column
|
||||||
|
might be images and the vector column might be text.
|
||||||
|
3. ndims method which returns the number of dimensions of the vector column
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ndims: int = PrivateAttr()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, **kwargs):
|
||||||
|
"""
|
||||||
|
Create an instance of the embedding function
|
||||||
|
"""
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for the source column in the database
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
elif isinstance(texts, pa.Array):
|
||||||
|
texts = texts.to_pylist()
|
||||||
|
elif isinstance(texts, pa.ChunkedArray):
|
||||||
|
texts = texts.combine_chunks().to_pylist()
|
||||||
|
return texts
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def safe_import(cls, module: str, mitigation=None):
|
||||||
|
"""
|
||||||
|
Import the specified module. If the module is not installed,
|
||||||
|
raise an ImportError with a helpful message.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
module : str
|
||||||
|
The name of the module to import
|
||||||
|
mitigation : Optional[str]
|
||||||
|
The package(s) to install to mitigate the error.
|
||||||
|
If not provided then the module name will be used.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.import_module(module)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(f"Please install {mitigation or module}")
|
||||||
|
|
||||||
|
def safe_model_dump(self):
|
||||||
|
from ..pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
if PYDANTIC_VERSION.major < 2:
|
||||||
|
return dict(self)
|
||||||
|
return self.model_dump()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def ndims(self):
|
||||||
|
"""
|
||||||
|
Return the dimensions of the vector column
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def SourceField(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a pydantic Field that can automatically annotate
|
||||||
|
the source column for this embedding function
|
||||||
|
"""
|
||||||
|
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
def VectorField(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a pydantic Field that can automatically annotate
|
||||||
|
the target vector column for this embedding function
|
||||||
|
"""
|
||||||
|
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunctionConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
This model encapsulates the configuration for a embedding function
|
||||||
|
in a lancedb table. It holds the embedding function, the source column,
|
||||||
|
and the vector column
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_column: str
|
||||||
|
source_column: str
|
||||||
|
function: EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbeddingFunction(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
A callable ABC for embedding functions that take text as input
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
|
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||||
|
|
||||||
|
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||||
|
texts = self.sanitize_input(texts)
|
||||||
|
return self.generate_embeddings(texts)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Generate the embeddings for the given texts
|
||||||
|
"""
|
||||||
|
pass
|
||||||
@@ -16,7 +16,8 @@ from typing import ClassVar, List, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .functions import TextEmbeddingFunction, register
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
from .utils import api_key_not_found_help
|
from .utils import api_key_not_found_help
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,578 +0,0 @@
|
|||||||
# Copyright (c) 2023. LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import concurrent.futures
|
|
||||||
import importlib
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import urllib.error
|
|
||||||
import urllib.parse as urlparse
|
|
||||||
import urllib.request
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pyarrow as pa
|
|
||||||
from cachetools import cached
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionRegistry:
|
|
||||||
"""
|
|
||||||
This is a singleton class used to register embedding functions
|
|
||||||
and fetch them by name. It also handles serializing and deserializing.
|
|
||||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
|
||||||
or TextEmbeddingFunction and registering it with the registry.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
|
||||||
>>> @registry.register("my-embedding-function")
|
|
||||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
|
||||||
... def ndims(self) -> int:
|
|
||||||
... return 128
|
|
||||||
...
|
|
||||||
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
|
||||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
|
||||||
...
|
|
||||||
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
|
||||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
|
||||||
...
|
|
||||||
>>> registry.get("my-embedding-function")
|
|
||||||
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls):
|
|
||||||
return __REGISTRY__
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._functions = {}
|
|
||||||
|
|
||||||
def register(self, alias: str = None):
|
|
||||||
"""
|
|
||||||
This creates a decorator that can be used to register
|
|
||||||
an EmbeddingFunction.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
alias : Optional[str]
|
|
||||||
a human friendly name for the embedding function. If not
|
|
||||||
provided, the class name will be used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# This is a decorator for a class that inherits from BaseModel
|
|
||||||
# It adds the class to the registry
|
|
||||||
def decorator(cls):
|
|
||||||
if not issubclass(cls, EmbeddingFunction):
|
|
||||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
|
||||||
if cls.__name__ in self._functions:
|
|
||||||
raise KeyError(f"{cls.__name__} was already registered")
|
|
||||||
key = alias or cls.__name__
|
|
||||||
self._functions[key] = cls
|
|
||||||
cls.__embedding_function_registry_alias__ = alias
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""
|
|
||||||
Reset the registry to its initial state
|
|
||||||
"""
|
|
||||||
self._functions = {}
|
|
||||||
|
|
||||||
def get(self, name: str):
|
|
||||||
"""
|
|
||||||
Fetch an embedding function class by name
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
The name of the embedding function to fetch
|
|
||||||
Either the alias or the class name if no alias was provided
|
|
||||||
during registration
|
|
||||||
"""
|
|
||||||
return self._functions[name]
|
|
||||||
|
|
||||||
def parse_functions(
|
|
||||||
self, metadata: Optional[Dict[bytes, bytes]]
|
|
||||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
|
||||||
"""
|
|
||||||
Parse the metadata from an arrow table and
|
|
||||||
return a mapping of the vector column to the
|
|
||||||
embedding function and source column
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
metadata : Optional[Dict[bytes, bytes]]
|
|
||||||
The metadata from an arrow table. Note that
|
|
||||||
the keys and values are bytes (pyarrow api)
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
functions : dict
|
|
||||||
A mapping of vector column name to embedding function.
|
|
||||||
An empty dict is returned if input is None or does not
|
|
||||||
contain b"embedding_functions".
|
|
||||||
"""
|
|
||||||
if metadata is None or b"embedding_functions" not in metadata:
|
|
||||||
return {}
|
|
||||||
serialized = metadata[b"embedding_functions"]
|
|
||||||
raw_list = json.loads(serialized.decode("utf-8"))
|
|
||||||
return {
|
|
||||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
|
||||||
vector_column=obj["vector_column"],
|
|
||||||
source_column=obj["source_column"],
|
|
||||||
function=self.get(obj["name"])(**obj["model"]),
|
|
||||||
)
|
|
||||||
for obj in raw_list
|
|
||||||
}
|
|
||||||
|
|
||||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
|
||||||
"""
|
|
||||||
Convert the given embedding function and source / vector column configs
|
|
||||||
into a config dictionary that can be serialized into arrow metadata
|
|
||||||
"""
|
|
||||||
func = conf.function
|
|
||||||
name = getattr(
|
|
||||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
|
||||||
)
|
|
||||||
json_data = func.safe_model_dump()
|
|
||||||
return {
|
|
||||||
"name": name,
|
|
||||||
"model": json_data,
|
|
||||||
"source_column": conf.source_column,
|
|
||||||
"vector_column": conf.vector_column,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_table_metadata(self, func_list):
|
|
||||||
"""
|
|
||||||
Convert a list of embedding functions and source / vector configs
|
|
||||||
into a config dictionary that can be serialized into arrow metadata
|
|
||||||
"""
|
|
||||||
if func_list is None or len(func_list) == 0:
|
|
||||||
return None
|
|
||||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
|
||||||
# Note that metadata dictionary values must be bytes
|
|
||||||
# so we need to json dump then utf8 encode
|
|
||||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
|
||||||
return {"embedding_functions": metadata}
|
|
||||||
|
|
||||||
|
|
||||||
# Global instance
|
|
||||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
|
||||||
IMAGES = Union[
|
|
||||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunction(BaseModel, ABC):
|
|
||||||
"""
|
|
||||||
An ABC for embedding functions.
|
|
||||||
|
|
||||||
All concrete embedding functions must implement the following:
|
|
||||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
|
||||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
|
||||||
For text data, the two will be the same. For multi-modal data, the source column
|
|
||||||
might be images and the vector column might be text.
|
|
||||||
3. ndims method which returns the number of dimensions of the vector column
|
|
||||||
"""
|
|
||||||
|
|
||||||
_ndims: int = PrivateAttr()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, **kwargs):
|
|
||||||
"""
|
|
||||||
Create an instance of the embedding function
|
|
||||||
"""
|
|
||||||
return cls(**kwargs)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Compute the embeddings for a given user query
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Compute the embeddings for the source column in the database
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
|
||||||
"""
|
|
||||||
Sanitize the input to the embedding function.
|
|
||||||
"""
|
|
||||||
if isinstance(texts, str):
|
|
||||||
texts = [texts]
|
|
||||||
elif isinstance(texts, pa.Array):
|
|
||||||
texts = texts.to_pylist()
|
|
||||||
elif isinstance(texts, pa.ChunkedArray):
|
|
||||||
texts = texts.combine_chunks().to_pylist()
|
|
||||||
return texts
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def safe_import(cls, module: str, mitigation=None):
|
|
||||||
"""
|
|
||||||
Import the specified module. If the module is not installed,
|
|
||||||
raise an ImportError with a helpful message.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
module : str
|
|
||||||
The name of the module to import
|
|
||||||
mitigation : Optional[str]
|
|
||||||
The package(s) to install to mitigate the error.
|
|
||||||
If not provided then the module name will be used.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return importlib.import_module(module)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(f"Please install {mitigation or module}")
|
|
||||||
|
|
||||||
def safe_model_dump(self):
|
|
||||||
from ..pydantic import PYDANTIC_VERSION
|
|
||||||
|
|
||||||
if PYDANTIC_VERSION.major < 2:
|
|
||||||
return dict(self)
|
|
||||||
return self.model_dump()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def ndims(self):
|
|
||||||
"""
|
|
||||||
Return the dimensions of the vector column
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def SourceField(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Creates a pydantic Field that can automatically annotate
|
|
||||||
the source column for this embedding function
|
|
||||||
"""
|
|
||||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
|
||||||
|
|
||||||
def VectorField(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Creates a pydantic Field that can automatically annotate
|
|
||||||
the target vector column for this embedding function
|
|
||||||
"""
|
|
||||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
This model encapsulates the configuration for a embedding function
|
|
||||||
in a lancedb table. It holds the embedding function, the source column,
|
|
||||||
and the vector column
|
|
||||||
"""
|
|
||||||
|
|
||||||
vector_column: str
|
|
||||||
source_column: str
|
|
||||||
function: EmbeddingFunction
|
|
||||||
|
|
||||||
|
|
||||||
class TextEmbeddingFunction(EmbeddingFunction):
|
|
||||||
"""
|
|
||||||
A callable ABC for embedding functions that take text as input
|
|
||||||
"""
|
|
||||||
|
|
||||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
|
||||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
|
||||||
|
|
||||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
|
||||||
texts = self.sanitize_input(texts)
|
|
||||||
return self.generate_embeddings(texts)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate_embeddings(
|
|
||||||
self, texts: Union[List[str], np.ndarray]
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Generate the embeddings for the given texts
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
|
||||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
|
||||||
|
|
||||||
|
|
||||||
@register("sentence-transformers")
|
|
||||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|
||||||
"""
|
|
||||||
An embedding function that uses the sentence-transformers library
|
|
||||||
|
|
||||||
https://huggingface.co/sentence-transformers
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = "all-MiniLM-L6-v2"
|
|
||||||
device: str = "cpu"
|
|
||||||
normalize: bool = True
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._ndims = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding_model(self):
|
|
||||||
"""
|
|
||||||
Get the sentence-transformers embedding model specified by the
|
|
||||||
name and device. This is cached so that the model is only loaded
|
|
||||||
once per process.
|
|
||||||
"""
|
|
||||||
return self.__class__.get_embedding_model(self.name, self.device)
|
|
||||||
|
|
||||||
def ndims(self):
|
|
||||||
if self._ndims is None:
|
|
||||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
|
||||||
return self._ndims
|
|
||||||
|
|
||||||
def generate_embeddings(
|
|
||||||
self, texts: Union[List[str], np.ndarray]
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Get the embeddings for the given texts
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
texts: list[str] or np.ndarray (of str)
|
|
||||||
The texts to embed
|
|
||||||
"""
|
|
||||||
return self.embedding_model.encode(
|
|
||||||
list(texts),
|
|
||||||
convert_to_numpy=True,
|
|
||||||
normalize_embeddings=self.normalize,
|
|
||||||
).tolist()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@cached(cache={})
|
|
||||||
def get_embedding_model(cls, name, device):
|
|
||||||
"""
|
|
||||||
Get the sentence-transformers embedding model specified by the
|
|
||||||
name and device. This is cached so that the model is only loaded
|
|
||||||
once per process.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
The name of the model to load
|
|
||||||
device : str
|
|
||||||
The device to load the model on
|
|
||||||
|
|
||||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
|
||||||
"""
|
|
||||||
sentence_transformers = cls.safe_import(
|
|
||||||
"sentence_transformers", "sentence-transformers"
|
|
||||||
)
|
|
||||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
@register("openai")
|
|
||||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
|
||||||
"""
|
|
||||||
An embedding function that uses the OpenAI API
|
|
||||||
|
|
||||||
https://platform.openai.com/docs/guides/embeddings
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = "text-embedding-ada-002"
|
|
||||||
|
|
||||||
def ndims(self):
|
|
||||||
# TODO don't hardcode this
|
|
||||||
return 1536
|
|
||||||
|
|
||||||
def generate_embeddings(
|
|
||||||
self, texts: Union[List[str], np.ndarray]
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Get the embeddings for the given texts
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
texts: list[str] or np.ndarray (of str)
|
|
||||||
The texts to embed
|
|
||||||
"""
|
|
||||||
# TODO retry, rate limit, token limit
|
|
||||||
openai = self.safe_import("openai")
|
|
||||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
|
||||||
return [v["embedding"] for v in rs]
|
|
||||||
|
|
||||||
|
|
||||||
@register("open-clip")
|
|
||||||
class OpenClipEmbeddings(EmbeddingFunction):
|
|
||||||
"""
|
|
||||||
An embedding function that uses the OpenClip API
|
|
||||||
For multi-modal text-to-image search
|
|
||||||
|
|
||||||
https://github.com/mlfoundations/open_clip
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = "ViT-B-32"
|
|
||||||
pretrained: str = "laion2b_s34b_b79k"
|
|
||||||
device: str = "cpu"
|
|
||||||
batch_size: int = 64
|
|
||||||
normalize: bool = True
|
|
||||||
_model = PrivateAttr()
|
|
||||||
_preprocess = PrivateAttr()
|
|
||||||
_tokenizer = PrivateAttr()
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
open_clip = self.safe_import("open_clip", "open-clip")
|
|
||||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
|
||||||
self.name, pretrained=self.pretrained
|
|
||||||
)
|
|
||||||
model.to(self.device)
|
|
||||||
self._model, self._preprocess = model, preprocess
|
|
||||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
|
||||||
self._ndims = None
|
|
||||||
|
|
||||||
def ndims(self):
|
|
||||||
if self._ndims is None:
|
|
||||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
|
||||||
return self._ndims
|
|
||||||
|
|
||||||
def compute_query_embeddings(
|
|
||||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
|
||||||
) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Compute the embeddings for a given user query
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
query : Union[str, PIL.Image.Image]
|
|
||||||
The query to embed. A query can be either text or an image.
|
|
||||||
"""
|
|
||||||
if isinstance(query, str):
|
|
||||||
return [self.generate_text_embeddings(query)]
|
|
||||||
else:
|
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
|
||||||
if isinstance(query, PIL.Image.Image):
|
|
||||||
return [self.generate_image_embedding(query)]
|
|
||||||
else:
|
|
||||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
|
||||||
|
|
||||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
|
||||||
torch = self.safe_import("torch")
|
|
||||||
text = self.sanitize_input(text)
|
|
||||||
text = self._tokenizer(text)
|
|
||||||
text.to(self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
text_features = self._model.encode_text(text.to(self.device))
|
|
||||||
if self.normalize:
|
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
|
||||||
return text_features.cpu().numpy().squeeze()
|
|
||||||
|
|
||||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
|
||||||
"""
|
|
||||||
Sanitize the input to the embedding function.
|
|
||||||
"""
|
|
||||||
if isinstance(images, (str, bytes)):
|
|
||||||
images = [images]
|
|
||||||
elif isinstance(images, pa.Array):
|
|
||||||
images = images.to_pylist()
|
|
||||||
elif isinstance(images, pa.ChunkedArray):
|
|
||||||
images = images.combine_chunks().to_pylist()
|
|
||||||
return images
|
|
||||||
|
|
||||||
def compute_source_embeddings(
|
|
||||||
self, images: IMAGES, *args, **kwargs
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Get the embeddings for the given images
|
|
||||||
"""
|
|
||||||
images = self.sanitize_input(images)
|
|
||||||
embeddings = []
|
|
||||||
for i in range(0, len(images), self.batch_size):
|
|
||||||
j = min(i + self.batch_size, len(images))
|
|
||||||
batch = images[i:j]
|
|
||||||
embeddings.extend(self._parallel_get(batch))
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Issue concurrent requests to retrieve the image data
|
|
||||||
"""
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
futures = [
|
|
||||||
executor.submit(self.generate_image_embedding, image)
|
|
||||||
for image in images
|
|
||||||
]
|
|
||||||
return [future.result() for future in tqdm(futures)]
|
|
||||||
|
|
||||||
def generate_image_embedding(
|
|
||||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Generate the embedding for a single image
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
image : Union[str, bytes, PIL.Image.Image]
|
|
||||||
The image to embed. If the image is a str, it is treated as a uri.
|
|
||||||
If the image is bytes, it is treated as the raw image bytes.
|
|
||||||
"""
|
|
||||||
torch = self.safe_import("torch")
|
|
||||||
# TODO handle retry and errors for https
|
|
||||||
image = self._to_pil(image)
|
|
||||||
image = self._preprocess(image).unsqueeze(0)
|
|
||||||
with torch.no_grad():
|
|
||||||
return self._encode_and_normalize_image(image)
|
|
||||||
|
|
||||||
def _to_pil(self, image: Union[str, bytes]):
|
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
|
||||||
if isinstance(image, bytes):
|
|
||||||
return PIL.Image.open(io.BytesIO(image))
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
|
||||||
return image
|
|
||||||
elif isinstance(image, str):
|
|
||||||
parsed = urlparse.urlparse(image)
|
|
||||||
# TODO handle drive letter on windows.
|
|
||||||
if parsed.scheme == "file":
|
|
||||||
return PIL.Image.open(parsed.path)
|
|
||||||
elif parsed.scheme == "":
|
|
||||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
|
||||||
elif parsed.scheme.startswith("http"):
|
|
||||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
|
||||||
|
|
||||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
|
||||||
"""
|
|
||||||
encode a single image tensor and optionally normalize the output
|
|
||||||
"""
|
|
||||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
|
||||||
if self.normalize:
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
|
||||||
return image_features.cpu().numpy().squeeze()
|
|
||||||
|
|
||||||
|
|
||||||
def url_retrieve(url: str):
|
|
||||||
"""
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
url: str
|
|
||||||
URL to download from
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(url) as conn:
|
|
||||||
return conn.read()
|
|
||||||
except (socket.gaierror, urllib.error.URLError) as err:
|
|
||||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
|
||||||
163
python/lancedb/embeddings/open_clip.py
Normal file
163
python/lancedb/embeddings/open_clip.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import urllib.parse as urlparse
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .base import EmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import IMAGES, url_retrieve
|
||||||
|
|
||||||
|
|
||||||
|
@register("open-clip")
|
||||||
|
class OpenClipEmbeddings(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the OpenClip API
|
||||||
|
For multi-modal text-to-image search
|
||||||
|
|
||||||
|
https://github.com/mlfoundations/open_clip
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "ViT-B-32"
|
||||||
|
pretrained: str = "laion2b_s34b_b79k"
|
||||||
|
device: str = "cpu"
|
||||||
|
batch_size: int = 64
|
||||||
|
normalize: bool = True
|
||||||
|
_model = PrivateAttr()
|
||||||
|
_preprocess = PrivateAttr()
|
||||||
|
_tokenizer = PrivateAttr()
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
open_clip = self.safe_import("open_clip", "open-clip")
|
||||||
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||||
|
self.name, pretrained=self.pretrained
|
||||||
|
)
|
||||||
|
model.to(self.device)
|
||||||
|
self._model, self._preprocess = model, preprocess
|
||||||
|
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def compute_query_embeddings(
|
||||||
|
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : Union[str, PIL.Image.Image]
|
||||||
|
The query to embed. A query can be either text or an image.
|
||||||
|
"""
|
||||||
|
if isinstance(query, str):
|
||||||
|
return [self.generate_text_embeddings(query)]
|
||||||
|
else:
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(query, PIL.Image.Image):
|
||||||
|
return [self.generate_image_embedding(query)]
|
||||||
|
else:
|
||||||
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
|
|
||||||
|
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
text = self.sanitize_input(text)
|
||||||
|
text = self._tokenizer(text)
|
||||||
|
text.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
text_features = self._model.encode_text(text.to(self.device))
|
||||||
|
if self.normalize:
|
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
return text_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, bytes)):
|
||||||
|
images = [images]
|
||||||
|
elif isinstance(images, pa.Array):
|
||||||
|
images = images.to_pylist()
|
||||||
|
elif isinstance(images, pa.ChunkedArray):
|
||||||
|
images = images.combine_chunks().to_pylist()
|
||||||
|
return images
|
||||||
|
|
||||||
|
def compute_source_embeddings(
|
||||||
|
self, images: IMAGES, *args, **kwargs
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given images
|
||||||
|
"""
|
||||||
|
images = self.sanitize_input(images)
|
||||||
|
embeddings = []
|
||||||
|
for i in range(0, len(images), self.batch_size):
|
||||||
|
j = min(i + self.batch_size, len(images))
|
||||||
|
batch = images[i:j]
|
||||||
|
embeddings.extend(self._parallel_get(batch))
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Issue concurrent requests to retrieve the image data
|
||||||
|
"""
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(self.generate_image_embedding, image)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
return [future.result() for future in tqdm(futures)]
|
||||||
|
|
||||||
|
def generate_image_embedding(
|
||||||
|
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate the embedding for a single image
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image : Union[str, bytes, PIL.Image.Image]
|
||||||
|
The image to embed. If the image is a str, it is treated as a uri.
|
||||||
|
If the image is bytes, it is treated as the raw image bytes.
|
||||||
|
"""
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
# TODO handle retry and errors for https
|
||||||
|
image = self._to_pil(image)
|
||||||
|
image = self._preprocess(image).unsqueeze(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(image, bytes):
|
||||||
|
return PIL.Image.open(io.BytesIO(image))
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
elif isinstance(image, str):
|
||||||
|
parsed = urlparse.urlparse(image)
|
||||||
|
# TODO handle drive letter on windows.
|
||||||
|
if parsed.scheme == "file":
|
||||||
|
return PIL.Image.open(parsed.path)
|
||||||
|
elif parsed.scheme == "":
|
||||||
|
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||||
|
elif parsed.scheme.startswith("http"):
|
||||||
|
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
|
|
||||||
|
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||||
|
"""
|
||||||
|
encode a single image tensor and optionally normalize the output
|
||||||
|
"""
|
||||||
|
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||||
|
if self.normalize:
|
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
return image_features.cpu().numpy().squeeze()
|
||||||
37
python/lancedb/embeddings/openai.py
Normal file
37
python/lancedb/embeddings/openai.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
|
||||||
|
|
||||||
|
@register("openai")
|
||||||
|
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the OpenAI API
|
||||||
|
|
||||||
|
https://platform.openai.com/docs/guides/embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "text-embedding-ada-002"
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
# TODO don't hardcode this
|
||||||
|
return 1536
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
"""
|
||||||
|
# TODO retry, rate limit, token limit
|
||||||
|
openai = self.safe_import("openai")
|
||||||
|
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||||
|
return [v["embedding"] for v in rs]
|
||||||
186
python/lancedb/embeddings/registry.py
Normal file
186
python/lancedb/embeddings/registry.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .base import EmbeddingFunction, EmbeddingFunctionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunctionRegistry:
    """
    This is a singleton class used to register embedding functions
    and fetch them by name. It also handles serializing and deserializing.
    You can implement your own embedding function by subclassing EmbeddingFunction
    or TextEmbeddingFunction and registering it with the registry.

    NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]

    Examples
    --------
    >>> registry = EmbeddingFunctionRegistry.get_instance()
    >>> @registry.register("my-embedding-function")
    ... class MyEmbeddingFunction(EmbeddingFunction):
    ...     def ndims(self) -> int:
    ...         return 128
    ...
    ...     def compute_query_embeddings(self, query: str, *args, **kwargs):
    ...         return self.compute_source_embeddings(query, *args, **kwargs)
    ...
    ...     def compute_source_embeddings(self, texts, *args, **kwargs):
    ...         return [np.random.rand(self.ndims()) for _ in range(len(texts))]
    ...
    >>> registry.get("my-embedding-function")
    <class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
    """

    @classmethod
    def get_instance(cls):
        """Return the global singleton registry instance."""
        return __REGISTRY__

    def __init__(self):
        # registered name (alias, or class name if no alias) -> EmbeddingFunction subclass
        self._functions = {}

    def register(self, alias: str = None):
        """
        This creates a decorator that can be used to register
        an EmbeddingFunction.

        Parameters
        ----------
        alias : Optional[str]
            a human friendly name for the embedding function. If not
            provided, the class name will be used.
        """

        # This is a decorator for a class that inherits from BaseModel
        # It adds the class to the registry
        def decorator(cls):
            if not issubclass(cls, EmbeddingFunction):
                raise TypeError("Must be a subclass of EmbeddingFunction")
            # BUGFIX: duplicate detection must use the actual registration
            # key (the alias when one is given), not always the class name —
            # otherwise two registrations under the same alias silently
            # overwrite each other.
            key = alias or cls.__name__
            if key in self._functions:
                raise KeyError(f"{key} was already registered")
            self._functions[key] = cls
            cls.__embedding_function_registry_alias__ = alias
            return cls

        return decorator

    def reset(self):
        """
        Reset the registry to its initial state
        """
        self._functions = {}

    def get(self, name: str):
        """
        Fetch an embedding function class by name

        Parameters
        ----------
        name : str
            The name of the embedding function to fetch
            Either the alias or the class name if no alias was provided
            during registration
        """
        return self._functions[name]

    def parse_functions(
        self, metadata: Optional[Dict[bytes, bytes]]
    ) -> Dict[str, "EmbeddingFunctionConfig"]:
        """
        Parse the metadata from an arrow table and
        return a mapping of the vector column to the
        embedding function and source column

        Parameters
        ----------
        metadata : Optional[Dict[bytes, bytes]]
            The metadata from an arrow table. Note that
            the keys and values are bytes (pyarrow api)

        Returns
        -------
        functions : dict
            A mapping of vector column name to embedding function.
            An empty dict is returned if input is None or does not
            contain b"embedding_functions".
        """
        if metadata is None or b"embedding_functions" not in metadata:
            return {}
        serialized = metadata[b"embedding_functions"]
        raw_list = json.loads(serialized.decode("utf-8"))
        return {
            obj["vector_column"]: EmbeddingFunctionConfig(
                vector_column=obj["vector_column"],
                source_column=obj["source_column"],
                function=self.get(obj["name"])(**obj["model"]),
            )
            for obj in raw_list
        }

    def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
        """
        Convert the given embedding function and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata
        """
        func = conf.function
        name = getattr(
            func, "__embedding_function_registry_alias__", func.__class__.__name__
        )
        # BUGFIX: registration without an alias stores None on the class, so
        # the getattr default above never kicks in for registered functions.
        # Fall back to the class name explicitly so the serialized "name"
        # round-trips through parse_functions / get().
        if name is None:
            name = func.__class__.__name__
        json_data = func.safe_model_dump()
        return {
            "name": name,
            "model": json_data,
            "source_column": conf.source_column,
            "vector_column": conf.vector_column,
        }

    def get_table_metadata(self, func_list):
        """
        Convert a list of embedding functions and source / vector configs
        into a config dictionary that can be serialized into arrow metadata
        """
        if func_list is None or len(func_list) == 0:
            return None
        json_data = [self.function_to_metadata(func) for func in func_list]
        # Note that metadata dictionary values must be bytes
        # so we need to json dump then utf8 encode
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()


# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
def register(name):
    """
    Module-level shortcut for registering an embedding function class.

    Equivalent to ``EmbeddingFunctionRegistry.get_instance().register(name)``;
    a plain ``def`` (rather than a lambda assignment) keeps it PEP 8 compliant
    and gives the function a proper name in tracebacks.
    """
    return EmbeddingFunctionRegistry.get_instance().register(name)


def get_registry():
    """
    Utility function to get the global instance of the registry

    Returns
    -------
    EmbeddingFunctionRegistry
        The global registry instance

    Examples
    --------
    from lancedb.embeddings import get_registry

    registry = get_registry()
    openai = registry.get("openai").create()
    """
    return __REGISTRY__.get_instance()
|
||||||
77
python/lancedb/embeddings/sentence_transformers.py
Normal file
77
python/lancedb/embeddings/sentence_transformers.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from cachetools import cached
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
|
||||||
|
|
||||||
|
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the sentence-transformers library

    https://huggingface.co/sentence-transformers
    """

    name: str = "all-MiniLM-L6-v2"
    device: str = "cpu"
    normalize: bool = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Dimensionality is discovered lazily on first use; see ndims().
        self._ndims = None

    @property
    def embedding_model(self):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.
        """
        return type(self).get_embedding_model(self.name, self.device)

    def ndims(self):
        # Probe the model with a dummy input once, then reuse the answer.
        if self._ndims is None:
            probe = self.generate_embeddings("foo")
            self._ndims = len(probe[0])
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        vectors = self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        )
        return vectors.tolist()

    @classmethod
    @cached(cache={})
    def get_embedding_model(cls, name, device):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.

        Parameters
        ----------
        name : str
            The name of the model to load
        device : str
            The device to load the model on

        TODO: use lru_cache instead with a reasonable/configurable maxsize
        """
        module = cls.safe_import(
            "sentence_transformers", "sentence-transformers"
        )
        return module.SentenceTransformer(name, device=device)
|
||||||
@@ -12,8 +12,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import socket
|
||||||
import sys
|
import sys
|
||||||
from typing import Callable, Union
|
import urllib.error
|
||||||
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -24,7 +26,12 @@ from ..util import safe_import_pandas
|
|||||||
from ..utils.general import LOGGER
|
from ..utils.general import LOGGER
|
||||||
|
|
||||||
pd = safe_import_pandas()
|
pd = safe_import_pandas()
|
||||||
|
|
||||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||||
|
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||||
|
IMAGES = Union[
|
||||||
|
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def with_embeddings(
|
def with_embeddings(
|
||||||
@@ -155,6 +162,20 @@ class FunctionWrapper:
|
|||||||
yield from _chunker(arr)
|
yield from _chunker(arr)
|
||||||
|
|
||||||
|
|
||||||
|
def url_retrieve(url: str):
    """
    Download the contents of the given URL and return them as bytes.

    Parameters
    ----------
    url: str
        URL to download from

    Raises
    ------
    ConnectionError
        If the download fails due to a DNS or network error.
    """
    # BUGFIX: the module imports ``urllib.error`` but never ``urllib.request``,
    # so ``urllib.request.urlopen`` is not guaranteed to be bound. Import it
    # locally so this function is self-sufficient.
    import urllib.request

    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError("could not download {} due to {}".format(url, err)) from err
|
||||||
|
|
||||||
|
|
||||||
def api_key_not_found_help(provider):
|
def api_key_not_found_help(provider):
|
||||||
LOGGER.error(f"Could not find API key for {provider}.")
|
LOGGER.error(f"Could not find API key for {provider}.")
|
||||||
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
|
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import inspect
|
|||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import date, datetime
|
||||||
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
|
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -159,6 +160,10 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
|
|||||||
return pa.bool_()
|
return pa.bool_()
|
||||||
elif py_type == bytes:
|
elif py_type == bytes:
|
||||||
return pa.binary()
|
return pa.binary()
|
||||||
|
elif py_type == date:
|
||||||
|
return pa.date32()
|
||||||
|
elif py_type == datetime:
|
||||||
|
return pa.timestamp("us")
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
|
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
|
||||||
)
|
)
|
||||||
@@ -322,7 +327,12 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
for vec, func in vec_and_function:
|
for vec, func in vec_and_function:
|
||||||
for source, field_info in cls.safe_get_fields().items():
|
for source, field_info in cls.safe_get_fields().items():
|
||||||
src_func = get_extras(field_info, "source_column_for")
|
src_func = get_extras(field_info, "source_column_for")
|
||||||
if src_func == func:
|
if src_func is func:
|
||||||
|
# note we can't use == here since the function is a pydantic
|
||||||
|
# model so two instances of the same function are ==, so if you
|
||||||
|
# have multiple vector columns from multiple sources, both will
|
||||||
|
# be mapped to the same source column
|
||||||
|
# GH594
|
||||||
configs.append(
|
configs.append(
|
||||||
EmbeddingFunctionConfig(
|
EmbeddingFunctionConfig(
|
||||||
source_column=source, vector_column=vec, function=func
|
source_column=source, vector_column=vec, function=func
|
||||||
|
|||||||
@@ -151,10 +151,15 @@ class RestfulLanceDBClient:
|
|||||||
return await deserialize(resp)
|
return await deserialize(resp)
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def list_tables(self):
|
async def list_tables(self, limit: int, page_token: str):
|
||||||
"""List all tables in the database."""
|
"""List all tables in the database."""
|
||||||
json = await self.get("/v1/table/", {})
|
try:
|
||||||
return json["tables"]
|
json = await self.get(
|
||||||
|
"/v1/table/", {"limit": limit, "page_token": page_token}
|
||||||
|
)
|
||||||
|
return json["tables"]
|
||||||
|
except StopAsyncIteration:
|
||||||
|
return []
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional
|
from typing import Iterator, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -52,10 +52,27 @@ class RemoteDBConnection(DBConnection):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"RemoveConnect(name={self.db_name})"
|
return f"RemoveConnect(name={self.db_name})"
|
||||||
|
|
||||||
def table_names(self, last_token: str, limit=10) -> Iterator[str]:
    """List the names of all tables in the database.

    Parameters
    ----------
    last_token: str
        The last token to start the new page.
    limit: int
        Maximum number of table names fetched per request (default 10).

    Returns
    -------
    An iterator of table names.
    """
    while True:
        result = self._loop.run_until_complete(
            self._client.list_tables(limit, last_token)
        )
        if len(result) > 0:
            # Use the last name of this page as the token for the next page.
            last_token = result[-1]
        else:
            break
        # BUGFIX: yield each table name; the original ``yield result``
        # emitted the entire page list once per element.
        for item in result:
            yield item
|
||||||
|
|
||||||
def open_table(self, name: str) -> Table:
|
def open_table(self, name: str) -> Table:
|
||||||
"""Open a Lance Table in the database.
|
"""Open a Lance Table in the database.
|
||||||
@@ -122,3 +139,8 @@ class RemoteDBConnection(DBConnection):
|
|||||||
f"/v1/table/{name}/drop/",
|
f"/v1/table/{name}/drop/",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the connection to the database."""
|
||||||
|
self._loop.close()
|
||||||
|
await self._client.close()
|
||||||
|
|||||||
@@ -29,8 +29,7 @@ from lance.dataset import CleanupStats, ReaderLike
|
|||||||
from lance.vector import vec_to_table
|
from lance.vector import vec_to_table
|
||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from .embeddings.functions import EmbeddingFunctionConfig
|
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
from .query import LanceQueryBuilder, Query
|
from .query import LanceQueryBuilder, Query
|
||||||
from .util import fs_from_uri, safe_import_pandas
|
from .util import fs_from_uri, safe_import_pandas
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.3.1"
|
version = "0.3.2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deprecation",
|
"deprecation",
|
||||||
"pylance==0.8.5",
|
"pylance==0.8.7",
|
||||||
"ratelimiter~=1.0",
|
"ratelimiter~=1.0",
|
||||||
"retry>=0.9.2",
|
"retry>=0.9.2",
|
||||||
"tqdm>=4.1.0",
|
"tqdm>=4.1.0",
|
||||||
@@ -52,7 +52,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
|||||||
dev = ["ruff", "pre-commit", "black"]
|
dev = ["ruff", "pre-commit", "black"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip", "cohere"]
|
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
lancedb = "lancedb.cli.cli:cli"
|
lancedb = "lancedb.cli.cli:cli"
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import pytest
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
from lancedb.embeddings import get_registry
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
# These are integration tests for embedding functions.
|
# These are integration tests for embedding functions.
|
||||||
@@ -31,12 +31,15 @@ from lancedb.pydantic import LanceModel, Vector
|
|||||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||||
def test_sentence_transformer(alias, tmp_path):
|
def test_sentence_transformer(alias, tmp_path):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = get_registry()
|
||||||
func = registry.get(alias).create()
|
func = registry.get(alias).create()
|
||||||
|
func2 = registry.get(alias).create()
|
||||||
|
|
||||||
class Words(LanceModel):
|
class Words(LanceModel):
|
||||||
text: str = func.SourceField()
|
text: str = func.SourceField()
|
||||||
|
text2: str = func2.SourceField()
|
||||||
vector: Vector(func.ndims()) = func.VectorField()
|
vector: Vector(func.ndims()) = func.VectorField()
|
||||||
|
vector2: Vector(func2.ndims()) = func2.VectorField()
|
||||||
|
|
||||||
table = db.create_table("words", schema=Words)
|
table = db.create_table("words", schema=Words)
|
||||||
table.add(
|
table.add(
|
||||||
@@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path):
|
|||||||
"foo",
|
"foo",
|
||||||
"bar",
|
"bar",
|
||||||
"baz",
|
"baz",
|
||||||
]
|
],
|
||||||
|
"text2": [
|
||||||
|
"to be or not to be",
|
||||||
|
"that is the question",
|
||||||
|
"for whether tis nobler",
|
||||||
|
"in the mind to suffer",
|
||||||
|
"the slings and arrows",
|
||||||
|
"of outrageous fortune",
|
||||||
|
"or to take arms",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path):
|
|||||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||||
assert actual.text == expected.text
|
assert actual.text == expected.text
|
||||||
assert actual.text == "hello world"
|
assert actual.text == "hello world"
|
||||||
|
assert not np.allclose(actual.vector, actual.vector2)
|
||||||
|
|
||||||
|
actual = (
|
||||||
|
table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
|
||||||
|
)
|
||||||
|
assert actual.text != "hello world"
|
||||||
|
assert not np.allclose(actual.vector, actual.vector2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@@ -69,7 +88,7 @@ def test_openclip(tmp_path):
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = get_registry()
|
||||||
func = registry.get("open-clip").create()
|
func = registry.get("open-clip").create()
|
||||||
|
|
||||||
class Images(LanceModel):
|
class Images(LanceModel):
|
||||||
@@ -131,11 +150,7 @@ def test_openclip(tmp_path):
|
|||||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||||
) # also skip if cohere not installed
|
) # also skip if cohere not installed
|
||||||
def test_cohere_embedding_function():
|
def test_cohere_embedding_function():
|
||||||
cohere = (
|
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
|
||||||
EmbeddingFunctionRegistry.get_instance()
|
|
||||||
.get("cohere")
|
|
||||||
.create(name="embed-multilingual-v2.0")
|
|
||||||
)
|
|
||||||
|
|
||||||
class TextModel(LanceModel):
|
class TextModel(LanceModel):
|
||||||
text: str = cohere.SourceField()
|
text: str = cohere.SourceField()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from datetime import date, datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -40,10 +41,18 @@ def test_pydantic_to_arrow():
|
|||||||
li: List[int]
|
li: List[int]
|
||||||
opt: Optional[str] = None
|
opt: Optional[str] = None
|
||||||
st: StructModel
|
st: StructModel
|
||||||
|
dt: date
|
||||||
|
dtt: datetime
|
||||||
# d: dict
|
# d: dict
|
||||||
|
|
||||||
m = TestModel(
|
m = TestModel(
|
||||||
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
|
id=1,
|
||||||
|
s="hello",
|
||||||
|
vec=[1.0, 2.0, 3.0],
|
||||||
|
li=[2, 3, 4],
|
||||||
|
st=StructModel(a="a", b=1.0),
|
||||||
|
dt=date.today(),
|
||||||
|
dtt=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = pydantic_to_schema(TestModel)
|
schema = pydantic_to_schema(TestModel)
|
||||||
@@ -62,6 +71,8 @@ def test_pydantic_to_arrow():
|
|||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
pa.field("dt", pa.date32(), False),
|
||||||
|
pa.field("dtt", pa.timestamp("us"), False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
@@ -79,10 +90,18 @@ def test_pydantic_to_arrow_py38():
|
|||||||
li: List[int]
|
li: List[int]
|
||||||
opt: Optional[str] = None
|
opt: Optional[str] = None
|
||||||
st: StructModel
|
st: StructModel
|
||||||
|
dt: date
|
||||||
|
dtt: datetime
|
||||||
# d: dict
|
# d: dict
|
||||||
|
|
||||||
m = TestModel(
|
m = TestModel(
|
||||||
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
|
id=1,
|
||||||
|
s="hello",
|
||||||
|
vec=[1.0, 2.0, 3.0],
|
||||||
|
li=[2, 3, 4],
|
||||||
|
st=StructModel(a="a", b=1.0),
|
||||||
|
dt=date.today(),
|
||||||
|
dtt=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = pydantic_to_schema(TestModel)
|
schema = pydantic_to_schema(TestModel)
|
||||||
@@ -101,6 +120,8 @@ def test_pydantic_to_arrow_py38():
|
|||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
pa.field("dt", pa.date32(), False),
|
||||||
|
pa.field("dtt", pa.timestamp("us"), False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb-node"
|
name = "vectordb-node"
|
||||||
version = "0.3.2"
|
version = "0.3.3"
|
||||||
description = "Serverless, low-latency vector database for AI applications"
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb"
|
name = "vectordb"
|
||||||
version = "0.3.2"
|
version = "0.3.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ trait PrimaryOnly {
|
|||||||
|
|
||||||
impl PrimaryOnly for Path {
|
impl PrimaryOnly for Path {
|
||||||
fn primary_only(&self) -> bool {
|
fn primary_only(&self) -> bool {
|
||||||
self.to_string().contains("manifest")
|
self.filename().unwrap_or("") == "_latest.manifest"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,8 +118,10 @@ impl ObjectStore for MirroringObjectStore {
|
|||||||
self.primary.head(location).await
|
self.primary.head(location).await
|
||||||
}
|
}
|
||||||
|
|
||||||
// garbage collection on secondary will happen async from other means
|
|
||||||
async fn delete(&self, location: &Path) -> Result<()> {
|
async fn delete(&self, location: &Path) -> Result<()> {
|
||||||
|
if !location.primary_only() {
|
||||||
|
self.secondary.delete(location).await?;
|
||||||
|
}
|
||||||
self.primary.delete(location).await
|
self.primary.delete(location).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +134,7 @@ impl ObjectStore for MirroringObjectStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
|
async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
|
||||||
if from.primary_only() {
|
if to.primary_only() {
|
||||||
self.primary.copy(from, to).await
|
self.primary.copy(from, to).await
|
||||||
} else {
|
} else {
|
||||||
self.secondary.copy(from, to).await?;
|
self.secondary.copy(from, to).await?;
|
||||||
@@ -142,6 +144,9 @@ impl ObjectStore for MirroringObjectStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
|
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
|
||||||
|
if !to.primary_only() {
|
||||||
|
self.secondary.copy(from, to).await?;
|
||||||
|
}
|
||||||
self.primary.copy_if_not_exists(from, to).await
|
self.primary.copy_if_not_exists(from, to).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -379,7 +384,7 @@ mod test {
|
|||||||
let primary_f = primary_elem.unwrap().unwrap();
|
let primary_f = primary_elem.unwrap().unwrap();
|
||||||
// hit manifest, skip, _versions contains all the manifest and should not exist on secondary
|
// hit manifest, skip, _versions contains all the manifest and should not exist on secondary
|
||||||
let primary_raw_path = primary_f.file_name().to_str().unwrap();
|
let primary_raw_path = primary_f.file_name().to_str().unwrap();
|
||||||
if primary_raw_path.contains("manifest") || primary_raw_path.contains("_versions") {
|
if primary_raw_path.contains("_latest.manifest") {
|
||||||
primary_elem = primary_iter.next();
|
primary_elem = primary_iter.next();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,6 +153,22 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Re-open this table at the latest version of the underlying dataset.
///
/// If the current dataset is already at the latest version, the existing
/// `Arc`-wrapped dataset handle is reused (a cheap reference-count clone);
/// otherwise the latest version is checked out. Returns a new `Table`
/// value — `self` is not modified.
///
/// # Errors
/// Propagates any error from fetching the latest version id or from
/// checking out that version.
pub async fn checkout_latest(&self) -> Result<Self> {
    let latest_version_id = self.dataset.latest_version_id().await?;
    // Skip the checkout round-trip when we already hold the newest version.
    let dataset = if latest_version_id == self.dataset.version().version {
        self.dataset.clone()
    } else {
        Arc::new(self.dataset.checkout_version(latest_version_id).await?)
    };

    Ok(Table {
        name: self.name.clone(),
        uri: self.uri.clone(),
        dataset,
        store_wrapper: self.store_wrapper.clone(),
    })
}
|
||||||
|
|
||||||
fn get_table_name(uri: &str) -> Result<String> {
|
fn get_table_name(uri: &str) -> Result<String> {
|
||||||
let path = Path::new(uri);
|
let path = Path::new(uri);
|
||||||
let name = path
|
let name = path
|
||||||
|
|||||||
Reference in New Issue
Block a user