Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-25 14:29:56 +00:00)

Comparing commits: v0.4.14...tuning/dat (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 99d1a06a44 |  |
|  | f23641d703 |  |
|  | e9e0a37ca8 |  |
|  | c37a28abbd |  |
|  | 98c1e635b3 |  |
|  | 9992b927fd |  |

(Author and date columns were not captured by the mirror.)
.github/workflows/docs_test.yml (vendored, 8 changed lines)
@@ -18,7 +18,7 @@ on:
 env:
   # Disable full debug symbol generation to speed up CI build and keep memory down
   # "1" means line tables only, which is useful for panic tracebacks.
-  RUSTFLAGS: "-C debuginfo=1 -C target-cpu=native -C target-feature=+f16c,+avx2,+fma"
+  RUSTFLAGS: "-C debuginfo=1 -C target-cpu=haswell -C target-feature=+f16c,+avx2,+fma"
   RUST_BACKTRACE: "1"

 jobs:
@@ -28,6 +28,8 @@ jobs:
     steps:
     - name: Checkout
       uses: actions/checkout@v4
+    - name: Print CPU capabilities
+      run: cat /proc/cpuinfo
     - name: Install dependecies needed for ubuntu
       run: |
         sudo apt install -y protobuf-compiler libssl-dev
@@ -39,7 +41,7 @@
       cache: "pip"
       cache-dependency-path: "docs/test/requirements.txt"
     - name: Rust cache
-      uses: swatinem/rust-cache@v2
+      uses: swatinem/rust-cache@v2
     - name: Build Python
       working-directory: docs/test
       run:
@@ -64,6 +66,8 @@
     with:
       fetch-depth: 0
       lfs: true
+    - name: Print CPU capabilities
+      run: cat /proc/cpuinfo
     - name: Set up Node
       uses: actions/setup-node@v4
       with:
.github/workflows/node.yml (vendored, 3 changed lines)
@@ -20,7 +20,8 @@ env:
   # "1" means line tables only, which is useful for panic tracebacks.
   #
   # Use native CPU to accelerate tests if possible, especially for f16
-  RUSTFLAGS: "-C debuginfo=1 -C target-cpu=native -C target-feature=+f16c,+avx2,+fma"
+  # target-cpu=haswell fixes failing ci build
+  RUSTFLAGS: "-C debuginfo=1 -C target-cpu=haswell -C target-feature=+f16c,+avx2,+fma"
   RUST_BACKTRACE: "1"

 jobs:
docs/mkdocs.yml (326 changed lines)

@@ -38,178 +38,180 @@ theme:
  custom_dir: overrides

plugins:
  - search
  - autorefs
  - mkdocstrings:
      handlers:
        python:
          paths: [../python]
          options:
            docstring_style: numpy
-           heading_level: 4
+           heading_level: 3
            show_source: true
            show_symbol_type_in_heading: true
            show_signature_annotations: true
+           show_root_heading: true
            members_order: source
          import:
            # for cross references
            - https://arrow.apache.org/docs/objects.inv
            - https://pandas.pydata.org/docs/objects.inv
  - mkdocs-jupyter
  - ultralytics:
      verbose: True
      enabled: True
      default_image: "assets/lancedb_and_lance.png" # Default image for all pages
      add_image: True # Automatically add meta image
      add_keywords: True # Add page keywords in the header tag
      add_share_buttons: True # Add social share buttons
      add_authors: False # Display page authors
      add_desc: False
      add_dates: False

markdown_extensions:
  - admonition
  - footnotes
  - pymdownx.details
  - pymdownx.highlight:
      anchor_linenums: true
      line_spans: __span
      pygments_lang_class: true
  - pymdownx.inlinehilite
  - pymdownx.snippets:
      base_path: ..
      dedent_subsections: true
  - pymdownx.superfences
  - pymdownx.tabbed:
      alternate_style: true
  - md_in_html
  - attr_list

nav:
  - Home:
      - LanceDB: index.md
  - 🏃🏼♂️ Quick start: basic.md
  - 📚 Concepts:
      - Vector search: concepts/vector_search.md
      - Indexing: concepts/index_ivfpq.md
      - Storage: concepts/storage.md
      - Data management: concepts/data_management.md
  - 🔨 Guides:
      - Working with tables: guides/tables.md
      - Building an ANN index: ann_indexes.md
      - Vector Search: search.md
      - Full-text search: fts.md
      - Hybrid search:
          - Overview: hybrid_search/hybrid_search.md
          - Comparing Rerankers: hybrid_search/eval.md
          - Airbnb financial data example: notebooks/hybrid_search.ipynb
      - Filtering: sql.md
      - Versioning & Reproducibility: notebooks/reproducibility.ipynb
      - Configuring Storage: guides/storage.md
+     - Sync -> Async Migration Guide: migration.md
  - 🧬 Managing embeddings:
      - Overview: embeddings/index.md
      - Embedding functions: embeddings/embedding_functions.md
      - Available models: embeddings/default_embedding_functions.md
      - User-defined embedding functions: embeddings/custom_embedding_function.md
      - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
      - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
  - 🔌 Integrations:
      - Tools and data formats: integrations/index.md
      - Pandas and PyArrow: python/pandas_and_pyarrow.md
      - Polars: python/polars_arrow.md
      - DuckDB: python/duckdb.md
      - LangChain 🔗: https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html
      - LangChain JS/TS 🔗: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb
      - LlamaIndex 🦙: https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html
      - Pydantic: python/pydantic.md
      - Voxel51: integrations/voxel51.md
      - PromptTools: integrations/prompttools.md
  - 🎯 Examples:
      - Overview: examples/index.md
      - 🐍 Python:
          - Overview: examples/examples_python.md
          - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
          - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
          - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
          - Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
          - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
          - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
      - 👾 JavaScript:
          - Overview: examples/examples_js.md
          - Serverless Website Chatbot: examples/serverless_website_chatbot.md
          - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
          - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
      - 🦀 Rust:
          - Overview: examples/examples_rust.md
  - 🔧 CLI & Config: cli_config.md
  - 💭 FAQs: faq.md
  - ⚙️ API reference:
      - 🐍 Python: python/python.md
      - 👾 JavaScript: javascript/modules.md
      - 🦀 Rust: https://docs.rs/lancedb/latest/lancedb/
  - ☁️ LanceDB Cloud:
      - Overview: cloud/index.md
      - API reference:
          - 🐍 Python: python/saas-python.md
          - 👾 JavaScript: javascript/saas-modules.md

The compare view also renders two further versions of the nav block:

nav:
  - Quick start: basic.md
  - Concepts:
      - Vector search: concepts/vector_search.md
      - Indexing: concepts/index_ivfpq.md
      - Storage: concepts/storage.md
      - Data management: concepts/data_management.md
  - Guides:
      - Working with tables: guides/tables.md
      - Building an ANN index: ann_indexes.md
      - Vector Search: search.md
      - Full-text search: fts.md
      - Hybrid search:
          - Overview: hybrid_search/hybrid_search.md
          - Comparing Rerankers: hybrid_search/eval.md
          - Airbnb financial data example: notebooks/hybrid_search.ipynb
      - Filtering: sql.md
      - Versioning & Reproducibility: notebooks/reproducibility.ipynb
      - Configuring Storage: guides/storage.md
      - Sync -> Async Migration Guide: migration.md
  - Managing Embeddings:
      - Overview: embeddings/index.md
      - Embedding functions: embeddings/embedding_functions.md
      - Available models: embeddings/default_embedding_functions.md
      - User-defined embedding functions: embeddings/custom_embedding_function.md
      - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
      - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
  - Integrations:
      - Overview: integrations/index.md
      - Pandas and PyArrow: python/pandas_and_pyarrow.md
      - Polars: python/polars_arrow.md
      - DuckDB: python/duckdb.md
      - LangChain 🦜️🔗↗: https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html
      - LangChain.js 🦜️🔗↗: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb
      - LlamaIndex 🦙↗: https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html
      - Pydantic: python/pydantic.md
      - Voxel51: integrations/voxel51.md
      - PromptTools: integrations/prompttools.md
  - Examples:
      - examples/index.md
      - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
      - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
      - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
      - Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
      - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
      - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
      - 👾 JavaScript:
          - Overview: examples/examples_js.md
          - Serverless Website Chatbot: examples/serverless_website_chatbot.md
          - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
      - YouTube Transcript Search (JS): examples/youtube_transcript_bot_with_nodejs.md
      - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
      - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
      - 🦀 Rust:
          - Overview: examples/examples_rust.md
  - 🔧 CLI & Config: cli_config.md
  - 💭 FAQs: faq.md
  - ⚙️ API reference:
      - 🐍 Python: python/python.md
      - 👾 JavaScript: javascript/modules.md
      - 🦀 Rust: https://docs.rs/lancedb/latest/lancedb/
  - ☁️ LanceDB Cloud:
      - Overview: cloud/index.md
      - API reference:
          - 🐍 Python: python/saas-python.md
          - 👾 JavaScript: javascript/saas-modules.md

nav:
  - Quick start: basic.md
  - Concepts:
      - Vector search: concepts/vector_search.md
      - Indexing: concepts/index_ivfpq.md
      - Storage: concepts/storage.md
      - Data management: concepts/data_management.md
  - Guides:
      - Working with tables: guides/tables.md
      - Building an ANN index: ann_indexes.md
      - Vector Search: search.md
      - Full-text search: fts.md
      - Hybrid search:
          - Overview: hybrid_search/hybrid_search.md
          - Comparing Rerankers: hybrid_search/eval.md
          - Airbnb financial data example: notebooks/hybrid_search.ipynb
      - Filtering: sql.md
      - Versioning & Reproducibility: notebooks/reproducibility.ipynb
      - Configuring Storage: guides/storage.md
  - Managing Embeddings:
      - Overview: embeddings/index.md
      - Embedding functions: embeddings/embedding_functions.md
      - Available models: embeddings/default_embedding_functions.md
      - User-defined embedding functions: embeddings/custom_embedding_function.md
      - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
      - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
  - Integrations:
      - Overview: integrations/index.md
      - Pandas and PyArrow: python/pandas_and_pyarrow.md
      - Polars: python/polars_arrow.md
      - DuckDB: python/duckdb.md
      - LangChain 🦜️🔗↗: https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html
      - LangChain.js 🦜️🔗↗: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb
      - LlamaIndex 🦙↗: https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html
      - Pydantic: python/pydantic.md
      - Voxel51: integrations/voxel51.md
      - PromptTools: integrations/prompttools.md
  - Examples:
      - examples/index.md
      - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
      - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
      - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
      - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
      - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
      - YouTube Transcript Search (JS): examples/youtube_transcript_bot_with_nodejs.md
      - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
      - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
  - API reference:
      - Overview: api_reference.md
      - Python: python/python.md
      - Javascript: javascript/modules.md
      - Rust: https://docs.rs/lancedb/latest/lancedb/index.html
  - LanceDB Cloud:
      - Overview: cloud/index.md
      - API reference:
          - 🐍 Python: python/saas-python.md
          - 👾 JavaScript: javascript/saas-modules.md

extra_css:
  - styles/global.css
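The `--8<--` markers that appear throughout the documentation diffs below are `pymdownx.snippets` includes, enabled by the `pymdownx.snippets` configuration above (`base_path: ..`, `dedent_subsections: true`). A minimal sketch of how a snippet region works; the region name and file body here are illustrative, not the actual test file:

```python
# python/python/tests/docs/test_basic.py (illustrative excerpt)
# pymdownx.snippets can include just the region between named markers:

# --8<-- [start:connect]
import lancedb

db = lancedb.connect("data/sample-lancedb")
# --8<-- [end:connect]
```

A Markdown page then embeds that region with `--8<-- "python/python/tests/docs/test_basic.py:connect"`. This is why the doc diffs below replace inline code with includes: the examples live in runnable test files, so the docs cannot silently drift out of date.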
docs/src/basic.md

@@ -48,11 +48,20 @@

=== "Python"

-    ```python
-    import lancedb
-    uri = "data/sample-lancedb"
-    db = lancedb.connect(uri)
-    ```
+    ```python
+    --8<-- "python/python/tests/docs/test_basic.py:imports"
+    --8<-- "python/python/tests/docs/test_basic.py:connect"
+
+    --8<-- "python/python/tests/docs/test_basic.py:connect_async"
+    ```
+
+    !!! note "Asynchronous Python API"
+
+        The asynchronous Python API is new and has some slight differences compared
+        to the synchronous API. Feel free to start using the asynchronous version.
+        Once all features have migrated we will start to move the synchronous API to
+        use the same syntax as the asynchronous API. To help with this migration we
+        have created a [migration guide](migration.md) detailing the differences.

=== "Typescript"

@@ -82,15 +91,14 @@ If you need a reminder of the uri, you can call `db.uri()`.

### Create a table from initial data

If you have data to insert into the table at creation time, you can simultaneously create a
table and insert the data into it. The schema of the data will be used as the schema of the
table.

=== "Python"

-    ```python
-    tbl = db.create_table("my_table",
-                          data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
-                                {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
-    ```
+    ```python
+    --8<-- "python/python/tests/docs/test_basic.py:create_table"
+    --8<-- "python/python/tests/docs/test_basic.py:create_table_async"
+    ```

If the table already exists, LanceDB will raise an error by default.

@@ -100,10 +108,8 @@ table.

You can also pass in a pandas DataFrame directly:

    ```python
-    import pandas as pd
-    df = pd.DataFrame([{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
-                       {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
-    tbl = db.create_table("table_from_df", data=df)
+    --8<-- "python/python/tests/docs/test_basic.py:create_table_pandas"
+    --8<-- "python/python/tests/docs/test_basic.py:create_table_async_pandas"
    ```

=== "Typescript"

@@ -138,15 +144,14 @@ table.

Sometimes you may not have the data to insert into the table at creation time.
In this case, you can create an empty table and specify the schema, so that you can add
data to the table at a later time (as long as it conforms to the schema). This is
similar to a `CREATE TABLE` statement in SQL.

=== "Python"

-    ```python
-    import pyarrow as pa
-    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
-    tbl = db.create_table("empty_table", schema=schema)
-    ```
+    ```python
+    --8<-- "python/python/tests/docs/test_basic.py:create_empty_table"
+    --8<-- "python/python/tests/docs/test_basic.py:create_empty_table_async"
+    ```

=== "Typescript"

@@ -168,7 +173,8 @@ Once created, you can open a table as follows:

=== "Python"

    ```python
-    tbl = db.open_table("my_table")
+    --8<-- "python/python/tests/docs/test_basic.py:open_table"
+    --8<-- "python/python/tests/docs/test_basic.py:open_table_async"
    ```

=== "Typescript"

@@ -188,7 +194,8 @@ If you forget the name of your table, you can always get a listing of all table

=== "Python"

    ```python
-    print(db.table_names())
+    --8<-- "python/python/tests/docs/test_basic.py:table_names"
+    --8<-- "python/python/tests/docs/test_basic.py:table_names_async"
    ```

=== "Javascript"

@@ -210,15 +217,8 @@ After a table has been created, you can always add more data to it as follows:

=== "Python"

    ```python
-    # Option 1: Add a list of dicts to a table
-    data = [{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
-            {"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}]
-    tbl.add(data)
-
-    # Option 2: Add a pandas DataFrame to a table
-    df = pd.DataFrame(data)
-    tbl.add(data)
+    --8<-- "python/python/tests/docs/test_basic.py:add_data"
+    --8<-- "python/python/tests/docs/test_basic.py:add_data_async"
    ```

=== "Typescript"

@@ -240,7 +240,8 @@ Once you've embedded the query, you can find its nearest neighbors as follows:

=== "Python"

    ```python
-    tbl.search([100, 100]).limit(2).to_pandas()
+    --8<-- "python/python/tests/docs/test_basic.py:vector_search"
+    --8<-- "python/python/tests/docs/test_basic.py:vector_search_async"
    ```

This returns a pandas DataFrame with the results.

@@ -274,7 +275,8 @@ LanceDB allows you to create an ANN index on a table as follows:

=== "Python"

    ```py
-    tbl.create_index()
+    --8<-- "python/python/tests/docs/test_basic.py:create_index"
+    --8<-- "python/python/tests/docs/test_basic.py:create_index_async"
    ```

=== "Typescript"

@@ -286,15 +288,15 @@ LanceDB allows you to create an ANN index on a table as follows:

=== "Rust"

    ```rust
-    --8<-- "rust/lancedb/examples/simple.rs:create_index"
+    --8<-- "rust/lancedb/examples/simple.rs:create_index"
    ```

!!! note "Why do I need to create an index manually?"
    LanceDB does not automatically create the ANN index for two reasons. The first is that it's optimized
    for really fast retrievals via a disk-based index, and the second is that data and query workloads can
    be very diverse, so there's no one-size-fits-all index configuration. LanceDB provides many parameters
    to fine-tune index size, query latency and accuracy. See the section on
    [ANN indexes](ann_indexes.md) for more details.

## Delete rows from a table

@@ -305,7 +307,8 @@ This can delete any number of rows that match the filter.

=== "Python"

    ```python
-    tbl.delete('item = "fizz"')
+    --8<-- "python/python/tests/docs/test_basic.py:delete_rows"
+    --8<-- "python/python/tests/docs/test_basic.py:delete_rows_async"
    ```

=== "Typescript"

@@ -322,7 +325,7 @@ This can delete any number of rows that match the filter.

The deletion predicate is a SQL expression that supports the same expressions
as the `where()` clause (`only_if()` in Rust) on a search. They can be as
simple or complex as needed. To see what expressions are supported, see the
[SQL filters](sql.md) section.

=== "Python"

@@ -344,7 +347,8 @@ Use the `drop_table()` method on the database to remove a table.

=== "Python"

    ```python
-    db.drop_table("my_table")
+    --8<-- "python/python/tests/docs/test_basic.py:drop_table"
+    --8<-- "python/python/tests/docs/test_basic.py:drop_table_async"
    ```

This permanently removes the table and is not recoverable, unlike deleting rows.
docs/src/embeddings/default_embedding_functions.md

@@ -19,27 +19,163 @@ Allows you to set parameters when registering a `sentence-transformers` object.

| `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model |

```python
# imports needed for this example (not shown in the diff itself)
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("sentence-transformers").create(device="cpu")
```

??? "Check out available sentence-transformer models here!"
    ```markdown
    - sentence-transformers/all-MiniLM-L12-v2
    - sentence-transformers/paraphrase-mpnet-base-v2
    - sentence-transformers/gtr-t5-base
    - sentence-transformers/LaBSE
    - sentence-transformers/all-MiniLM-L6-v2
    - sentence-transformers/bert-base-nli-max-tokens
    - sentence-transformers/bert-base-nli-mean-tokens
    - sentence-transformers/bert-base-nli-stsb-mean-tokens
    - sentence-transformers/bert-base-wikipedia-sections-mean-tokens
    - sentence-transformers/bert-large-nli-cls-token
    - sentence-transformers/bert-large-nli-max-tokens
    - sentence-transformers/bert-large-nli-mean-tokens
    - sentence-transformers/bert-large-nli-stsb-mean-tokens
    - sentence-transformers/distilbert-base-nli-max-tokens
    - sentence-transformers/distilbert-base-nli-mean-tokens
    - sentence-transformers/distilbert-base-nli-stsb-mean-tokens
    - sentence-transformers/distilroberta-base-msmarco-v1
    - sentence-transformers/distilroberta-base-msmarco-v2
    - sentence-transformers/nli-bert-base-cls-pooling
    - sentence-transformers/nli-bert-base-max-pooling
    - sentence-transformers/nli-bert-base
    - sentence-transformers/nli-bert-large-cls-pooling
    - sentence-transformers/nli-bert-large-max-pooling
    - sentence-transformers/nli-bert-large
    - sentence-transformers/nli-distilbert-base-max-pooling
    - sentence-transformers/nli-distilbert-base
    - sentence-transformers/nli-roberta-base
    - sentence-transformers/nli-roberta-large
    - sentence-transformers/roberta-base-nli-mean-tokens
    - sentence-transformers/roberta-base-nli-stsb-mean-tokens
    - sentence-transformers/roberta-large-nli-mean-tokens
    - sentence-transformers/roberta-large-nli-stsb-mean-tokens
    - sentence-transformers/stsb-bert-base
    - sentence-transformers/stsb-bert-large
    - sentence-transformers/stsb-distilbert-base
    - sentence-transformers/stsb-roberta-base
    - sentence-transformers/stsb-roberta-large
    - sentence-transformers/xlm-r-100langs-bert-base-nli-mean-tokens
    - sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens
    - sentence-transformers/xlm-r-base-en-ko-nli-ststb
    - sentence-transformers/xlm-r-bert-base-nli-mean-tokens
    - sentence-transformers/xlm-r-bert-base-nli-stsb-mean-tokens
    - sentence-transformers/xlm-r-large-en-ko-nli-ststb
    - sentence-transformers/bert-base-nli-cls-token
    - sentence-transformers/all-distilroberta-v1
    - sentence-transformers/multi-qa-MiniLM-L6-dot-v1
    - sentence-transformers/multi-qa-distilbert-cos-v1
    - sentence-transformers/multi-qa-distilbert-dot-v1
    - sentence-transformers/multi-qa-mpnet-base-cos-v1
    - sentence-transformers/multi-qa-mpnet-base-dot-v1
    - sentence-transformers/nli-distilroberta-base-v2
    - sentence-transformers/all-MiniLM-L6-v1
    - sentence-transformers/all-mpnet-base-v1
    - sentence-transformers/all-mpnet-base-v2
    - sentence-transformers/all-roberta-large-v1
    - sentence-transformers/allenai-specter
    - sentence-transformers/average_word_embeddings_glove.6B.300d
    - sentence-transformers/average_word_embeddings_glove.840B.300d
    - sentence-transformers/average_word_embeddings_komninos
    - sentence-transformers/average_word_embeddings_levy_dependency
    - sentence-transformers/clip-ViT-B-32-multilingual-v1
    - sentence-transformers/clip-ViT-B-32
    - sentence-transformers/distilbert-base-nli-stsb-quora-ranking
    - sentence-transformers/distilbert-multilingual-nli-stsb-quora-ranking
    - sentence-transformers/distilroberta-base-paraphrase-v1
    - sentence-transformers/distiluse-base-multilingual-cased-v1
    - sentence-transformers/distiluse-base-multilingual-cased-v2
    - sentence-transformers/distiluse-base-multilingual-cased
    - sentence-transformers/facebook-dpr-ctx_encoder-multiset-base
    - sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base
    - sentence-transformers/facebook-dpr-question_encoder-multiset-base
    - sentence-transformers/facebook-dpr-question_encoder-single-nq-base
    - sentence-transformers/gtr-t5-large
    - sentence-transformers/gtr-t5-xl
    - sentence-transformers/gtr-t5-xxl
    - sentence-transformers/msmarco-MiniLM-L-12-v3
    - sentence-transformers/msmarco-MiniLM-L-6-v3
    - sentence-transformers/msmarco-MiniLM-L12-cos-v5
    - sentence-transformers/msmarco-MiniLM-L6-cos-v5
    - sentence-transformers/msmarco-bert-base-dot-v5
    - sentence-transformers/msmarco-bert-co-condensor
    - sentence-transformers/msmarco-distilbert-base-dot-prod-v3
    - sentence-transformers/msmarco-distilbert-base-tas-b
    - sentence-transformers/msmarco-distilbert-base-v2
    - sentence-transformers/msmarco-distilbert-base-v3
    - sentence-transformers/msmarco-distilbert-base-v4
    - sentence-transformers/msmarco-distilbert-cos-v5
    - sentence-transformers/msmarco-distilbert-dot-v5
    - sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned
    - sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch
    - sentence-transformers/msmarco-distilroberta-base-v2
    - sentence-transformers/msmarco-roberta-base-ance-firstp
    - sentence-transformers/msmarco-roberta-base-v2
    - sentence-transformers/msmarco-roberta-base-v3
    - sentence-transformers/multi-qa-MiniLM-L6-cos-v1
    - sentence-transformers/nli-mpnet-base-v2
    - sentence-transformers/nli-roberta-base-v2
    - sentence-transformers/nq-distilbert-base-v1
    - sentence-transformers/paraphrase-MiniLM-L12-v2
    - sentence-transformers/paraphrase-MiniLM-L3-v2
    - sentence-transformers/paraphrase-MiniLM-L6-v2
    - sentence-transformers/paraphrase-TinyBERT-L6-v2
    - sentence-transformers/paraphrase-albert-base-v2
    - sentence-transformers/paraphrase-albert-small-v2
    - sentence-transformers/paraphrase-distilroberta-base-v1
    - sentence-transformers/paraphrase-distilroberta-base-v2
    - sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
    - sentence-transformers/paraphrase-multilingual-mpnet-base-v2
    - sentence-transformers/paraphrase-xlm-r-multilingual-v1
    - sentence-transformers/quora-distilbert-base
    - sentence-transformers/quora-distilbert-multilingual
    - sentence-transformers/sentence-t5-base
    - sentence-transformers/sentence-t5-large
    - sentence-transformers/sentence-t5-xxl
    - sentence-transformers/sentence-t5-xl
    - sentence-transformers/stsb-distilroberta-base-v2
    - sentence-transformers/stsb-mpnet-base-v2
    - sentence-transformers/stsb-roberta-base-v2
    - sentence-transformers/stsb-xlm-r-multilingual
    - sentence-transformers/xlm-r-distilroberta-base-paraphrase-v1
    - sentence-transformers/clip-ViT-L-14
    - sentence-transformers/clip-ViT-B-16
    - sentence-transformers/use-cmlm-multilingual
    - sentence-transformers/all-MiniLM-L12-v1
    ```

```python
class Words(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

table = db.create_table("words", schema=Words)
table.add(
    [
        {"text": "hello world"},
        {"text": "goodbye world"}
    ]
)

query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```

!!! info
    You can also load many other model architectures from the library, for example models from sources such as BAAI, nomic, salesforce research, etc.
    See this HF hub page for all [supported models](https://huggingface.co/models?library=sentence-transformers).

!!! note "BAAI Embeddings example"
    Here is an example that uses the BAAI embedding model from the HuggingFace Hub [supported models](https://huggingface.co/models?library=sentence-transformers)
    ```python
    db = lancedb.connect("/tmp/db")
    registry = EmbeddingFunctionRegistry.get_instance()
    model = registry.get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")

    class Words(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    table = db.create_table("words", schema=Words)
    table.add(
        [
            {"text": "hello world"},
            {"text": "goodbye world"}
        ]
    )

    query = "greetings"
    actual = table.search(query).limit(1).to_pydantic(Words)[0]
    print(actual.text)
    ```

Visit the sentence-transformers [HuggingFace Hub](https://huggingface.co/sentence-transformers) page for more information on the available models.

### OpenAI embeddings
LanceDB registers the OpenAI embeddings function in the registry by default, as `openai`. Below are the parameters that you can customize when creating the instances:
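(The parameter table itself was not captured by the mirror.) As a minimal usage sketch, following the registry pattern shown above; the model name is an assumed example, not taken from the diff:

```python
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

registry = EmbeddingFunctionRegistry.get_instance()
# "openai" is the registry key described above; name= selects the OpenAI model (assumed example)
func = registry.get("openai").create(name="text-embedding-ada-002")

class Documents(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()
```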
docs/src/eval/bench_fine_tuned_hybrid.py (new file, 150 lines)

@@ -0,0 +1,150 @@
import json
from tqdm import tqdm
import pandas as pd
import os
import requests
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.core.schema import TextNode
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from lancedb.rerankers import CrossEncoderReranker, ColbertReranker, CohereReranker, LinearCombinationReranker
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk, DEFAULT_PROMPT_TMPL
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings.fine_tuner.llm import Openai

import time
import lancedb
import wandb
from pydantic import BaseModel, root_validator
from typing import Optional

TRAIN_DATASET_FPATH = './data/train_dataset.json'
VAL_DATASET_FPATH = './data/val_dataset.json'

with open(TRAIN_DATASET_FPATH, 'r+') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r+') as f:
    val_dataset = json.load(f)


def train_embedding_model(epoch):
    def download_test_files(url):
        # download to cwd
        files = []
        filename = os.path.basename(url)
        if not os.path.exists(filename):
            print(f"Downloading {url} to {filename}")
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
        files.append(filename)
        return files

    def get_dataset(url, name):
        reader = SimpleDirectoryReader(input_files=download_test_files(url))
        docs = reader.load_data()

        parser = SentenceSplitter()
        nodes = parser.get_nodes_from_documents(docs)

        if os.path.exists(name):
            ds = QADataset.load(name)
        else:
            llm = Openai()

            # convert Llama-index TextNode to TextChunk
            chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

            ds = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
            ds.save(name)
        return ds

    train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'
    ds = get_dataset(train_url, "qa_dataset_uber")

    model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5")
    model.finetune(trainset=ds, valset=None, path="model_airbnb", epochs=epoch, log_wandb=True, run_name="lyft_finetune")


def evaluate(
    dataset,
    embed_model,
    reranker=None,
    top_k=5,
    verbose=False,
):
    corpus = dataset['corpus']
    queries = dataset['queries']
    relevant_docs = dataset['relevant_docs']

    vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(
        nodes,
        service_context=service_context,
        show_progress=True,
        storage_context=storage_context,
    )
    tbl = vector_store.connection.open_table(vector_store.table_name)
    tbl.create_fts_index("text", replace=True)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        query_vector = embed_model.get_query_embedding(query)
        try:
            if reranker is None:
                rs = tbl.search(query_vector).limit(top_k).to_pandas()
            else:
                rs = tbl.search((query_vector, query)).rerank(reranker=reranker).limit(top_k).to_pandas()
        except Exception as e:
            print(f'Error with query: {query_id} {e}')
            continue
        retrieved_ids = rs['id'].tolist()[:top_k]
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  # assume 1 relevant doc
        if len(eval_results) == 0:
            print(f"Query: {query}")
            print(f"Expected: {expected_id}")
            print(f"Retrieved: {retrieved_ids}")
        eval_result = {
            'is_hit': is_hit,
            'retrieved': retrieved_ids,
            'expected': expected_id,
            'query': query_id,
        }
        eval_results.append(eval_result)
    return eval_results


if __name__ == '__main__':
    train_embedding_model(4)
    # embed_model = OpenAIEmbedding()  # model="text-embedding-3-small"
    rerankers = {
        "Vector Search": None,
        "Cohere": CohereReranker(),
        "Cross Encoder": CrossEncoderReranker(),
        "Colbert": ColbertReranker(),
        "linear": LinearCombinationReranker(),
    }
    top_ks = [3]
    for top_k in top_ks:
        # for epoch in epochs:
        for name, reranker in rerankers.items():
            # embed_model = HuggingFaceEmbedding("./model_airbnb")
            embed_model = OpenAIEmbedding()
            wandb.init(project="Reranker-based", name=name)
            val_eval_results = evaluate(val_dataset, embed_model, reranker=reranker, top_k=top_k)
            df = pd.DataFrame(val_eval_results)

            hit_rate = df['is_hit'].mean()
            print(f'Hit rate: {hit_rate:.2f}')
            wandb.log({f"openai_base_hit_rate_@{top_k}": hit_rate})
            wandb.finish()
docs/src/eval/test_fine_tune_from_llm.py (new file, 71 lines)

@@ -0,0 +1,71 @@
import os
import json
import lancedb
import pandas as pd

from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.pydantic import LanceModel, Vector
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode
from lancedb.embeddings import get_registry


test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf'
train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'


def download_test_files(url):
    import os
    import requests

    # download to cwd
    files = []
    filename = os.path.basename(url)
    if not os.path.exists(filename):
        print(f"Downloading {url} to {filename}")
        r = requests.get(url)
        with open(filename, 'wb') as f:
            f.write(r.content)
    files.append(filename)
    return files


def get_dataset(url, name):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    if os.path.exists(name):
        ds = QADataset.load(name)
    else:
        llm = Openai()

        # convert Llama-index TextNode to TextChunk
        chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

        ds = QADataset.from_llm(chunks, llm)
        ds.save(name)
    return ds


trainset = get_dataset(test_url, "qa_dataset_1")
valset = get_dataset(train_url, "valset")

model = get_registry().get("sentence-transformers").create()
model.finetune(trainset=trainset, valset=valset, path="model_finetuned_1", epochs=4)

base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(valset, path="val_res")
rs2 = tuned.evaluate(valset, path="val_res")
rs3 = openai.evaluate(valset)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("Base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
docs/src/eval/test_fine_tune_from_responses.py (new file, 119 lines)

@@ -0,0 +1,119 @@
import os
import re
import json
import uuid
import lancedb
import pandas as pd

from tqdm import tqdm
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk, DEFAULT_PROMPT_TMPL
from lancedb.pydantic import LanceModel, Vector
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode
from lancedb.embeddings import get_registry


test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf'
train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'


def download_test_files(url):
    import os
    import requests

    # download to cwd
    files = []
    filename = os.path.basename(url)
    if not os.path.exists(filename):
        print(f"Downloading {url} to {filename}")
        r = requests.get(url)
        with open(filename, 'wb') as f:
            f.write(r.content)
    files.append(filename)
    return files


def get_node(url):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    return nodes


def get_dataset(url, name):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    if os.path.exists(name):
        ds = QADataset.load(name)
    else:
        llm = Openai()

        # convert Llama-index TextNode to TextChunk
        chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

        ds = QADataset.from_llm(chunks, llm)
        ds.save(name)
    return ds


nodes = get_node(train_url)

db = lancedb.connect("~/lancedb/fine-tuning")
model = get_registry().get("openai").create()


class Schema(LanceModel):
    id: str
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()


retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite")
pylist = [{"id": str(node.node_id), "text": node.text} for node in nodes]
retriever.add(pylist)


ds_name = "response_data"
if os.path.exists(ds_name):
    ds = QADataset.load(ds_name)
else:
    # Generate questions
    llm = Openai()
    text_chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

    queries = {}
    relevant_docs = {}
    for chunk in tqdm(text_chunks):
        text = chunk.text
        questions = llm.get_questions(DEFAULT_PROMPT_TMPL.format(context_str=text, num_questions_per_chunk=2))

        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [retriever.search(question).to_pandas()["id"].tolist()[0]]
    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
    ds.save(ds_name)


# Fine-tune model
valset = get_dataset(train_url, "valset")

model = get_registry().get("sentence-transformers").create()
res_base = model.evaluate(valset)

model.finetune(trainset=ds, path="model_finetuned", epochs=4, log_wandb=True)
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
res_tuned = tuned.evaluate(valset)

openai_model = get_registry().get("openai").create()
# res_openai = openai_model.evaluate(valset)

# print(f"openai model results: {pd.DataFrame(res_openai)['is_hit'].mean()}")
print(f"base model results: {pd.DataFrame(res_base)['is_hit'].mean()}")
print(f"tuned model results: {pd.DataFrame(res_tuned)['is_hit'].mean()}")
docs/src/migration.md (new file, 76 lines)

@@ -0,0 +1,76 @@
# Rust-backed Client Migration Guide

In an effort to ensure all clients have the same set of capabilities, we have begun migrating the
Python and Node clients onto a common Rust base library. In Python, this new client is part of
the same lancedb package, exposed as an asynchronous client. Once the asynchronous client has
reached full functionality, we will begin migrating the synchronous library to be a thin wrapper
around the asynchronous client.

This guide describes the differences between the two APIs and will hopefully assist users
who would like to migrate to the new API.

## Closeable Connections

The connection now has a `close` method. You can call this when
you are done with the connection to eagerly free resources. Currently
this is limited to freeing/closing the HTTP connection for remote
connections. In the future we may add caching or other resources to
native connections, so this is probably a good practice even if you
aren't using remote connections.

In addition, the connection can be used as a context manager, which may
be a more convenient way to ensure the connection is closed.

```python
import lancedb

async def my_async_fn():
    with await lancedb.connect_async("my_uri") as db:
        print(await db.table_names())
```

It is not mandatory to call the `close` method. If you do not call it,
then the connection will be closed when the object is garbage collected.

## Closeable Table

The Table now also has a `close` method, similar to the connection. This
can be used to eagerly free the cache used by a Table object. Similar to
the connection, it can be used as a context manager, and it is not mandatory
to call the `close` method.
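A minimal sketch of the context-manager form described above; the connection URI and table name are hypothetical:

```python
import lancedb

async def use_table():
    db = await lancedb.connect_async("my_uri")
    with await db.open_table("my_table") as table:
        print(await table.count_rows())
    # the table's close() has been called here by the context manager
```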
### Changes to Table APIs

- Previously `Table.schema` was a property. Now it is an async method.
- The method `Table.__len__` was removed and `len(table)` will no longer
  work. Use `Table.count_rows` instead (see the sketch after this list).
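A before/after sketch of the two changes above; the object names are placeholders:

```python
# synchronous API
schema = table.schema            # a property
num_rows = len(table)            # Table.__len__

# asynchronous API
schema = await async_table.schema()        # now an async method
num_rows = await async_table.count_rows()  # len() is no longer supported
```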
### Creating Indices

The `Table.create_index` method is now used for creating both vector indices
and scalar indices. It currently requires a column name to be specified (the
column to index). Vector index defaults are now smarter and scale better with
the size of the data.

To specify index configuration details, you will need to specify which kind of
index you are using.
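A sketch of the new call shape, assuming the index configuration classes listed in the API reference later in this compare (`lancedb.index.BTree`, `lancedb.index.IvfPq`); the column names and parameter values are illustrative:

```python
from lancedb.index import BTree, IvfPq

# vector index on an assumed "vector" column
await table.create_index("vector", config=IvfPq(num_partitions=256, num_sub_vectors=16))

# scalar index on an assumed filterable "id" column
await table.create_index("id", config=BTree())
```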
### Querying

The `Table.search` method has been renamed to `AsyncTable.vector_search` for
clarity.
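A sketch of the rename; the query vector and chained calls are placeholders:

```python
# synchronous API
df = table.search([0.1, 0.2]).limit(5).to_pandas()

# asynchronous API
df = await async_table.vector_search([0.1, 0.2]).limit(5).to_pandas()
```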
## Features not yet supported

The following features are not yet supported by the asynchronous API. However,
we plan to support them soon.

- You cannot specify an embedding function when creating or opening a table.
  You must calculate embeddings yourself if using the asynchronous API.
- The merge insert operation is not supported in the asynchronous API.
- Cleanup / compact / optimize indices are not supported in the asynchronous API.
- Add / alter columns is not supported in the asynchronous API.
- The asynchronous API does not yet support full-text search or reranked search.
- Remote connections to LanceDB Cloud are not yet supported.
- The method `Table.head` is not yet supported.
docs/src/python/python.md

@@ -8,17 +8,20 @@ This section contains the API reference for the OSS Python API.

pip install lancedb
```

-## Connection
+The following methods describe the synchronous API client. There
+is also an [asynchronous API client](#connections-asynchronous).
+
+## Connections (Synchronous)

::: lancedb.connect

::: lancedb.db.DBConnection

-## Table
+## Tables (Synchronous)

::: lancedb.table.Table

-## Querying
+## Querying (Synchronous)

::: lancedb.query.Query

@@ -86,4 +89,42 @@ pip install lancedb

::: lancedb.rerankers.cross_encoder.CrossEncoderReranker

::: lancedb.rerankers.openai.OpenaiReranker
+
+## Connections (Asynchronous)
+
+Connections represent a connection to a LanceDB database and
+can be used to create, list, or open tables.
+
+::: lancedb.connect_async
+
+::: lancedb.db.AsyncConnection
+
+## Tables (Asynchronous)
+
+Tables hold your actual data as a collection of records / rows.
+
+::: lancedb.table.AsyncTable
+
+## Indices (Asynchronous)
+
+Indices can be created on a table to speed up queries. This section
+lists the indices that LanceDB supports.
+
+::: lancedb.index.BTree
+
+::: lancedb.index.IvfPq
+
+## Querying (Asynchronous)
+
+Queries allow you to return data from your database. Basic queries can be
+created with the [AsyncTable.query][lancedb.table.AsyncTable.query] method
+to return the entire (typically filtered) table. Vector searches return the
+rows nearest to a query vector and can be created with the
+[AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] method.
+
+::: lancedb.query.AsyncQueryBase
+
+::: lancedb.query.AsyncQuery
+
+::: lancedb.query.AsyncVectorQuery
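A rough sketch of the two query entry points documented in that section; the table name, filter, and vector are assumptions, not part of the diff:

```python
import lancedb

async def query_example():
    db = await lancedb.connect_async("data/sample-lancedb")
    table = await db.open_table("my_table")

    # plain (optionally filtered) scan via AsyncTable.query
    rows = await table.query().where("price > 10").limit(10).to_arrow()

    # nearest neighbors via AsyncTable.vector_search
    nearest = await table.vector_search([3.1, 4.1]).limit(2).to_arrow()
    return rows, nearest
```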
node/package-lock.json (generated, 56 changed lines)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.4.13",
+  "version": "0.4.14",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.4.13",
+      "version": "0.4.14",
       "cpu": [
         "x64",
         "arm64"
       ]
@@ -52,11 +52,11 @@
       "uuid": "^9.0.0"
     },
     "optionalDependencies": {
-      "@lancedb/vectordb-darwin-arm64": "0.4.13",
-      "@lancedb/vectordb-darwin-x64": "0.4.13",
-      "@lancedb/vectordb-linux-arm64-gnu": "0.4.13",
-      "@lancedb/vectordb-linux-x64-gnu": "0.4.13",
-      "@lancedb/vectordb-win32-x64-msvc": "0.4.13"
+      "@lancedb/vectordb-darwin-arm64": "0.4.14",
+      "@lancedb/vectordb-darwin-x64": "0.4.14",
+      "@lancedb/vectordb-linux-arm64-gnu": "0.4.14",
+      "@lancedb/vectordb-linux-x64-gnu": "0.4.14",
+      "@lancedb/vectordb-win32-x64-msvc": "0.4.14"
     },
     "peerDependencies": {
       "@apache-arrow/ts": "^14.0.2",
@@ -334,9 +334,9 @@
       }
     },
     "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.4.13",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.13.tgz",
-      "integrity": "sha512-JfroNCG8yKIU931Y+x8d0Fp8C9DHUSC5j+CjI+e5err7rTWtie4j3JbsXlWAnPFaFEOg0Xk3BWkSikCvhPGJGg==",
+      "version": "0.4.14",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.14.tgz",
+      "integrity": "sha512-fw6mf6UhFf4j2kKdFcw0P+SOiIqmRbt+YQSgDbF4BFU3OUSW0XyfETIj9cUMQbSwPFsofhlGp5BRpCd7W9noew==",
       "cpu": [
         "arm64"
       ],
@@ -345,22 +345,10 @@
         "darwin"
       ]
     },
-    "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.4.13",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.13.tgz",
-      "integrity": "sha512-dG6IMvfpHpnHdbJ0UffzJ7cZfMiC02MjIi6YJzgx+hKz2UNXWNBIfTvvhqli85mZsGRXL1OYDdYv0K1YzNjXlA==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ]
-    },
     "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.4.13",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.13.tgz",
-      "integrity": "sha512-BRR1VzaMviXby7qmLm0axNZM8eUZF3ZqfvnDKdVRpC3LaRueD6pMXHuC2IUKaFkn7xktf+8BlDZb6foFNEj8bQ==",
+      "version": "0.4.14",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.14.tgz",
+      "integrity": "sha512-1+LFI8vU+f/lnGy1s3XCySuV4oj3ZUW03xtmedGBW8nv/Y/jWXP0OYJCRI72eu+dLIdu0tCPsEiu8Hl+o02t9g==",
       "cpu": [
         "arm64"
       ],
@@ -369,22 +357,10 @@
         "linux"
       ]
     },
-    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.4.13",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.13.tgz",
-      "integrity": "sha512-WnekZ7ZMlria+NODZ6aBCljCFQSe2bBNUS9ZpyFl/Y1vHduSQPuBxM6V7vp2QubC0daq/rifgjDob89DF+x3xw==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
     "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.4.13",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.13.tgz",
-      "integrity": "sha512-3NDpMWBL2ksDHXAraXhowiLqQcNWM5bdbeHwze4+InYMD54hyQ2ODNc+4usxp63Nya9biVnFS27yXULqkzIEqQ==",
+      "version": "0.4.14",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.14.tgz",
+      "integrity": "sha512-fpuNMZ4aHSpZC3ztp5a0Wh18N6DpCx5EPWhS7bGA5XulGc0l+sZAJHfHwalx76ys//0Ns1z7cuKJhZpSa4SrdQ==",
       "cpu": [
         "x64"
       ],
@@ -145,34 +145,20 @@ async def connect_async(
        the last check, then the table will be checked for updates. Note: this
        consistency only applies to read operations. Write operations are
        always consistent.
    request_thread_pool: int or ThreadPoolExecutor, optional
        The thread pool to use for making batch requests to the LanceDB Cloud API.
        If an integer, then a ThreadPoolExecutor will be created with that
        number of threads. If None, then a ThreadPoolExecutor will be created
        with the default number of threads. If a ThreadPoolExecutor, then that
        executor will be used for making requests. This is for LanceDB Cloud
        only and is only used when making batch requests (i.e., passing in
        multiple queries to the search method at once).

    Examples
    --------

    For a local directory, provide a path for the database:

    >>> import lancedb
    >>> db = lancedb.connect("~/.lancedb")

    For object storage, use a URI prefix:

    >>> db = lancedb.connect("s3://my-bucket/lancedb")

    Connect to LanceDB Cloud:

    >>> db = lancedb.connect("db://my_database", api_key="ldb_...")
    >>> async def doctest_example():
    ...     # For a local directory, provide a path to the database
    ...     db = await lancedb.connect_async("~/.lancedb")
    ...     # For object storage, use a URI prefix
    ...     db = await lancedb.connect_async("s3://my-bucket/lancedb")

    Returns
    -------
    conn : DBConnection
    conn : AsyncConnection
        A connection to a LanceDB database.
    """
    if read_consistency_interval is not None:
||||
@@ -25,7 +25,6 @@ from overrides import EnforceOverrides, override
from pyarrow import fs

from lancedb.common import data_to_reader, validate_schema
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.utils.events import register_event

from ._lancedb import connect as lancedb_connect
@@ -451,16 +450,17 @@ class LanceDBConnection(DBConnection):
class AsyncConnection(object):
    """An active LanceDB connection

    To obtain a connection you can use the [connect] function.
    To obtain a connection you can use the [connect_async][lancedb.connect_async]
    function.

    This could be a native connection (using lance) or a remote connection (e.g. for
    connecting to LanceDb Cloud)

    Local connections do not currently hold any open resources but they may do so in the
    future (for example, for shared cache or connections to catalog services). Remote
    connections represent an open connection to the remote server. The [close] method
    can be used to release any underlying resources eagerly. The connection can also
    be used as a context manager:
    connections represent an open connection to the remote server. The
    [close][lancedb.db.AsyncConnection.close] method can be used to release any
    underlying resources eagerly. The connection can also be used as a context manager.

    Connections can be shared on multiple threads and are expected to be long lived.
    Connections can also be used as a context manager, however, in many cases a single
@@ -471,10 +471,9 @@ class AsyncConnection(object):
    Examples
    --------

    >>> import asyncio
    >>> import lancedb
    >>> async def my_connect():
    ...     with await lancedb.connect("/tmp/my_dataset") as conn:
    >>> async def doctest_example():
    ...     with await lancedb.connect_async("/tmp/my_dataset") as conn:
    ...         # do something with the connection
    ...         pass
    ...     # conn is closed here
@@ -535,9 +534,8 @@
        exist_ok: Optional[bool] = None,
        on_bad_vectors: Optional[str] = None,
        fill_value: Optional[float] = None,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> AsyncTable:
        """Create a [Table][lancedb.table.Table] in the database.
        """Create an [AsyncTable][lancedb.table.AsyncTable] in the database.

        Parameters
        ----------
@@ -576,7 +574,7 @@

        Returns
        -------
        LanceTable
        AsyncTable
            A reference to the newly created table.

        !!! note
@@ -590,12 +588,14 @@
        Can create with list of tuples or dictionaries:

        >>> import lancedb
        >>> db = lancedb.connect("./.lancedb")
        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        >>> db.create_table("my_table", data)
        LanceTable(connection=..., name="my_table")
        >>> db["my_table"].head()
        >>> async def doctest_example():
        ...     db = await lancedb.connect_async("./.lancedb")
        ...     data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...             {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        ...     my_table = await db.create_table("my_table", data)
        ...     print(await my_table.query().limit(5).to_arrow())
        >>> import asyncio
        >>> asyncio.run(doctest_example())
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
@@ -614,9 +614,11 @@
        ...     "lat": [45.5, 40.1],
        ...     "long": [-122.7, -74.1]
        ... })
        >>> db.create_table("table2", data)
        LanceTable(connection=..., name="table2")
        >>> db["table2"].head()
        >>> async def pandas_example():
        ...     db = await lancedb.connect_async("./.lancedb")
        ...     my_table = await db.create_table("table2", data)
        ...     print(await my_table.query().limit(5).to_arrow())
        >>> asyncio.run(pandas_example())
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
@@ -636,9 +638,11 @@
        ...     pa.field("lat", pa.float32()),
        ...     pa.field("long", pa.float32())
        ... ])
        >>> db.create_table("table3", data, schema = custom_schema)
        LanceTable(connection=..., name="table3")
        >>> db["table3"].head()
        >>> async def with_schema():
        ...     db = await lancedb.connect_async("./.lancedb")
        ...     my_table = await db.create_table("table3", data, schema = custom_schema)
        ...     print(await my_table.query().limit(5).to_arrow())
        >>> asyncio.run(with_schema())
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
@@ -670,9 +674,10 @@
        ...     pa.field("item", pa.utf8()),
        ...     pa.field("price", pa.float32()),
        ... ])
        >>> db.create_table("table4", make_batches(), schema=schema)
        LanceTable(connection=..., name="table4")

        >>> async def iterable_example():
        ...     db = await lancedb.connect_async("./.lancedb")
        ...     await db.create_table("table4", make_batches(), schema=schema)
        >>> asyncio.run(iterable_example())
        """
        if inspect.isclass(schema) and issubclass(schema, LanceModel):
            # convert LanceModel to pyarrow schema
@@ -681,12 +686,6 @@
            schema = schema.to_arrow_schema()

        metadata = None
        if embedding_functions is not None:
            # If we passed in embedding functions explicitly
            # then we'll override any schema metadata that
            # may have been implicitly specified by the LanceModel schema
            registry = EmbeddingFunctionRegistry.get_instance()
            metadata = registry.get_table_metadata(embedding_functions)

        # Defining defaults here and not in function prototype. In the future
        # these defaults will move into rust so better to keep them as None.
@@ -767,11 +766,11 @@
        name: str
            The name of the table.
        """
        raise NotImplementedError
        await self._inner.drop_table(name)

    async def drop_database(self):
        """
        Drop database
        This is the same thing as dropping all the tables
        """
        raise NotImplementedError
        await self._inner.drop_db()
@@ -10,13 +10,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm

import lancedb

from .fine_tuner import QADataset
from .utils import TEXT, retry_with_exponential_backoff

@@ -126,6 +131,22 @@ class EmbeddingFunction(BaseModel, ABC):
    def __hash__(self) -> int:
        return hash(frozenset(vars(self).items()))

    def finetune(self, dataset: QADataset, *args, **kwargs):
        """
        Finetune the embedding function on a dataset
        """
        raise NotImplementedError(
            "Finetuning is not supported for this embedding function"
        )

    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
        """
        Evaluate the embedding function on a dataset
        """
        raise NotImplementedError(
            "Evaluation is not supported for this embedding function"
        )


class EmbeddingFunctionConfig(BaseModel):
    """
@@ -159,3 +180,52 @@ class TextEmbeddingFunction(EmbeddingFunction):
        Generate the embeddings for the given texts
        """
        pass

    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
        """
        Evaluate the embedding function on a dataset. This calculates the hit-rate for
        the top-k retrieved documents for each query in the dataset. Assumes that the
        first relevant document is the expected document.
        Pro - Should work for any embedding model
        Con - Returns only very simple metrics.
        Parameters
        ----------
        dataset: QADataset
            The dataset to evaluate on

        Returns
        -------
        dict
            The evaluation results
        """
        corpus = dataset.corpus
        queries = dataset.queries
        relevant_docs = dataset.relevant_docs
        path = path or os.path.join(os.getcwd(), "eval")
        db = lancedb.connect(path)

        class Schema(lancedb.pydantic.LanceModel):
            id: str
            text: str = self.SourceField()
            vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField()

        retriever = db.create_table("eval", schema=Schema, mode="overwrite")
        pylist = [{"id": str(k), "text": v} for k, v in corpus.items()]
        retriever.add(pylist)

        eval_results = []
        for query_id, query in tqdm(queries.items()):
            retrieved_nodes = retriever.search(query).limit(top_k).to_list()
            retrieved_ids = [node["id"] for node in retrieved_nodes]
            expected_id = relevant_docs[query_id][0]
            is_hit = expected_id in retrieved_ids  # assume 1 relevant doc

            eval_result = {
                "is_hit": is_hit,
                "retrieved": retrieved_ids,
                "expected": expected_id,
                "query": query_id,
            }
            eval_results.append(eval_result)

        return eval_results
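For illustration, a minimal sketch of how this evaluation might be invoked
(assuming a `QADataset` named `ds` built as in the fine-tuner README below;
the model is just the registry default):

```python
from lancedb.embeddings import get_registry

# Evaluate hit-rate @ top_k for a registry model on a QADataset `ds`
model = get_registry().get("sentence-transformers").create()
results = model.evaluate(ds, top_k=5)
hit_rate = sum(r["is_hit"] for r in results) / len(results)
print(f"hit-rate@5: {hit_rate:.3f}")
```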
133
python/python/lancedb/embeddings/fine_tuner/README.md
Normal file
@@ -0,0 +1,133 @@
Fine-tuning workflow for embeddings consists of the following parts:

### QADataset
This class is used for managing the data for fine-tuning. It contains the following builder methods:
```python
from_llm(
    nodes: 'List[TextChunk]',
    llm: BaseLLM,
    qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> "QADataset"
```
Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.

```python
from_responses(docs: List['TextChunk'], queries: Dict[str, str], relevant_docs: Dict[str, List[str]]) -> "QADataset"
```
Create a dataset from queries and responses based on a real-world scenario. Designed to be used for knowledge distillation from a larger LLM to a smaller one.

It also contains the following data attributes:
```
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
```
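As a quick sketch (mirroring the test suite later in this diff), a dataset can
be assembled by hand from chunks, queries, and relevance mappings:

```python
import uuid

from lancedb.embeddings.fine_tuner import QADataset, TextChunk

chunks = ["This is a chunk related to legal docs",
          "This is a chunk related to sports docs"]
text_chunks = [TextChunk.from_chunk(c) for c in chunks]

queries, relevant_docs = {}, {}
for chunk in text_chunks:
    qid = str(uuid.uuid4())
    queries[qid] = "What is this chunk about?"
    relevant_docs[qid] = [chunk.id]  # query id -> relevant doc ids

ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
```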
### TextChunk
This class represents a single chunk of text for fine-tuning. It is designed to standardize the output of various text splitting/pre-processing tools like llama-index and langchain. It contains the following attributes:
```
text: str
id: str
metadata: Dict[str, Any] = {}
```

Builder Methods:

```python
from_llama_index_node(node) -> "TextChunk"
```
Create a text chunk from a llama-index node.

```python
from_langchain_node(node) -> "TextChunk"
```
Create a text chunk from a langchain node.

```python
from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk"
```
Create a text chunk from a string.
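For example, a minimal sketch converting raw strings and llama-index nodes into
chunks (the `nodes` variable here is assumed to come from a llama-index
splitter, as in the end-to-end example below):

```python
from lancedb.embeddings.fine_tuner import TextChunk

# From a plain string: a unique id is generated automatically
chunk = TextChunk.from_chunk("LanceDB is a vector database.")

# From llama-index nodes produced by a splitter
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
```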
### FineTuner
This class is used for fine-tuning embeddings. It is exposed to the user via a high-level function in the base embedding API.
```python
class BaseEmbeddingTuner(ABC):
    """Base Embedding finetuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Goes off and does stuff."""

    def helper(self) -> None:
        """A helper method."""
        pass
```

### Embedding API finetuning implementation
Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique has been implemented in the `BaseEmbedding` class that calculates hit_rate @ `top_k`.
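A sketch of what such an implementation might look like for a hypothetical
embedding function (`MyEmbeddings` and `MyEmbeddingTuner` are illustrative
names; see the sentence-transformers implementation in this diff for a real
one):

```python
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.fine_tuner import QADataset


class MyEmbeddings(TextEmbeddingFunction):
    def generate_embeddings(self, texts):
        ...  # backend-specific embedding call

    def finetune(self, trainset: QADataset, **kwargs):
        # Delegate to a backend-specific tuner built on BaseEmbeddingTuner;
        # `MyEmbeddingTuner` is a hypothetical implementation.
        tuner = MyEmbeddingTuner(trainset=trainset, **kwargs)
        tuner.finetune()
```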
### Fine-tuning workflow
The fine-tuning workflow is as follows:
1. Create a `QADataset` object.
2. Initialize any embedding function using the LanceDB embedding API.
3. Call the `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.

# End-to-End Examples
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.

## Example 1: Fine-tuning from a synthetic dataset
```python
import os

import pandas as pd

from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings import get_registry
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

# 1. Create a QADataset object
url = "uber10k.pdf"
name = "qa_dataset"  # where the generated dataset is cached on disk
reader = SimpleDirectoryReader(input_files=[url])
docs = reader.load_data()

parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)

if os.path.exists(name):
    ds = QADataset.load(name)
else:
    llm = Openai()

    # convert llama-index TextNode to TextChunk
    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

    ds = QADataset.from_llm(chunks, llm)
    ds.save(name)

# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create()

# 3. Fine-tune the model
model.finetune(trainset=ds, path="model_finetuned", epochs=4)

# 4. Evaluate the fine-tuned model
base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(ds, path="val_res")
rs2 = tuned.evaluate(ds, path="val_res")
rs3 = openai.evaluate(ds)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
```
4
python/python/lancedb/embeddings/fine_tuner/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .dataset import QADataset, TextChunk
from .llm import Gemini, Openai

__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]
13
python/python/lancedb/embeddings/fine_tuner/basetuner.py
Normal file
@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod


class BaseEmbeddingTuner(ABC):
    """Base Embedding finetuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Goes off and does stuff."""

    def helper(self) -> None:
        """A helper method."""
        pass
205
python/python/lancedb/embeddings/fine_tuner/dataset.py
Normal file
@@ -0,0 +1,205 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm
from lancedb.utils.general import LOGGER
from .llm import BaseLLM

DEFAULT_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and no prior knowledge, \
generate only questions based on the below query.

You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided."
"""


class QADataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.

    """

    path: Optional[str] = None
    queries: Dict[str, str]  # id -> query
    corpus: Dict[str, str]  # id -> text
    relevant_docs: Dict[str, List[str]]  # query id -> list of retrieved doc ids
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Get query, relevant doc ids."""
        return [
            (query, self.relevant_docs[query_id])
            for query_id, query in self.queries.items()
        ]

    def save(self, path: str, mode: str = "overwrite") -> None:
        """Save to lance dataset"""
        self.path = path
        save_dir = Path(path)
        save_dir.mkdir(parents=True, exist_ok=True)

        # convert to pydict {"id": []}
        queries = {
            "id": list(self.queries.keys()),
            "query": list(self.queries.values()),
        }
        corpus = {
            "id": list(self.corpus.keys()),
            "text": [
                val or " " for val in self.corpus.values()
            ],  # lance saves empty strings as null
        }
        relevant_docs = {
            "query_id": list(self.relevant_docs.keys()),
            "doc_id": list(self.relevant_docs.values()),
        }

        # write to lance
        lance.write_dataset(
            pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(relevant_docs),
            save_dir / "relevant_docs.lance",
            mode=mode,
        )

    @classmethod
    def load(cls, path: str, version: Optional[int] = None) -> "QADataset":
        """Load from .lance data"""
        load_dir = Path(path)
        queries = (
            lance.dataset(load_dir / "queries.lance", version=version)
            .to_table()
            .to_pydict()
        )
        corpus = (
            lance.dataset(load_dir / "corpus.lance", version=version)
            .to_table()
            .to_pydict()
        )
        relevant_docs = (
            lance.dataset(load_dir / "relevant_docs.lance", version=version)
            .to_table()
            .to_pydict()
        )
        return cls(
            path=str(path),
            queries=dict(zip(queries["id"], queries["query"])),
            corpus=dict(zip(corpus["id"], corpus["text"])),
            relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
        )

    def switch_version(self, version: int) -> "QADataset":
        """Load another version of this dataset."""
        if not self.path:
            raise ValueError("Path not set. You need to call save() first.")
        return self.load(self.path, version=version)

    # generate queries as a convenience function
    @classmethod
    def from_llm(
        cls,
        nodes: "List[TextChunk]",
        llm: BaseLLM,
        qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
        num_questions_per_chunk: int = 2,
    ) -> "QADataset":
        """Generate examples given a set of nodes."""
        node_dict = {node.id: node.text for node in nodes}

        queries = {}
        relevant_docs = {}
        for node_id, text in tqdm(node_dict.items()):
            query = qa_generate_prompt_tmpl.format(
                context_str=text, num_questions_per_chunk=num_questions_per_chunk
            )
            response = llm.chat_completion(query)

            result = str(response).strip().split("\n")
            questions = [
                re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
            ]
            questions = [question for question in questions if len(question) > 0]
            for question in questions:
                question_id = str(uuid.uuid4())
                queries[question_id] = question
                relevant_docs[question_id] = [node_id]

        return QADataset(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)

    @classmethod
    def from_responses(
        cls,
        docs: List["TextChunk"],
        queries: Dict[str, str],
        relevant_docs: Dict[str, List[str]],
    ) -> "QADataset":
        """Create a QADataset from a list of TextChunks and a list of questions."""
        node_dict = {node.id: node.text for node in docs}
        return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)

    def versions(self) -> List[int]:
        """Get the versions of the dataset."""
        # TODO: tidy this up
        data_paths = self._get_data_file_paths()
        return lance.dataset(data_paths[0]).versions()

    def _get_data_file_paths(self) -> Tuple[Path, Path, Path]:
        """Get the paths of the underlying lance datasets."""
        queries = Path(self.path) / "queries.lance"
        corpus = Path(self.path) / "corpus.lance"
        relevant_docs = Path(self.path) / "relevant_docs.lance"

        return queries, corpus, relevant_docs


class TextChunk(BaseModel):
    """Simple text chunk for generating questions."""

    text: str
    id: str
    metadata: Dict[str, Any] = {}

    @classmethod
    def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
        """Create a TextChunk from a chunk of text."""
        # generate a unique id
        return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)

    @classmethod
    def from_llama_index_node(cls, node):
        """Convert a llama-index node to a text chunk."""
        return cls(text=node.text, id=node.node_id, metadata=node.metadata)

    @classmethod
    def from_langchain_node(cls, node):
        """Convert a langchain node to a text chunk."""
        raise NotImplementedError("Not implemented yet.")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return self.dict()

    def __str__(self) -> str:
        return self.text

    def __repr__(self) -> str:
        return f"TextChunk(text={self.text}, id={self.id}, \
            metadata={self.metadata})"
85
python/python/lancedb/embeddings/fine_tuner/llm.py
Normal file
@@ -0,0 +1,85 @@
import os
import re
from functools import cached_property
from typing import List, Optional

from pydantic import BaseModel

from ...util import attempt_import_or_raise
from ..utils import api_key_not_found_help


class BaseLLM(BaseModel):
    """
    TODO:
    Base class for language models used to generate fine-tuning data. This
    class is loosely designed right now, and will be updated as the usage
    gets clearer.
    """

    model_name: str
    model_kwargs: dict = {}

    @cached_property
    def _client(self):
        """
        Get the client for the language model
        """
        raise NotImplementedError

    def chat_completion(self, prompt: str, **kwargs):
        """
        Get the chat completion for the given prompt
        """
        raise NotImplementedError


class Openai(BaseLLM):
    model_name: str = "gpt-3.5-turbo"
    kwargs: dict = {}
    api_key: Optional[str] = None

    @cached_property
    def _client(self):
        """
        Get the client for the language model
        """
        openai = attempt_import_or_raise("openai")

        if not os.environ.get("OPENAI_API_KEY"):
            api_key_not_found_help("openai")
        return openai.OpenAI()

    def chat_completion(self, prompt: str) -> str:
        """
        Get the chat completion for the given prompt
        """

        # TODO: this is legacy openai api replace with completions
        completion = self._client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            **self.kwargs,
        )

        text = completion.choices[0].message.content

        return text

    def get_questions(self, prompt: str) -> List[str]:
        """
        Parse a list of questions from the chat completion for the given prompt
        """
        response = self.chat_completion(prompt)
        result = str(response).strip().split("\n")
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
        questions = [question for question in questions if len(question) > 0]
        return questions


class Gemini(BaseLLM):
    pass
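For illustration, a minimal sketch of using this LLM wrapper directly (assumes
`OPENAI_API_KEY` is set in the environment):

```python
from lancedb.embeddings.fine_tuner.llm import Openai

llm = Openai()  # defaults to "gpt-3.5-turbo"
questions = llm.get_questions(
    "Write 2 quiz questions about vector databases, one per line."
)
print(questions)  # list of question strings, with any numbering stripped
```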
@@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
    # convert_to_numpy: bool = True  # Hardcoding this as numpy can be ingested directly

    source_instruction: str = "represent the document for retrieval"
    query_instruction: str = (
        "represent the document for retrieving the most similar documents"
    )
    query_instruction: (
        str
    ) = "represent the document for retrieving the most similar documents"

    @weak_lru(maxsize=1)
    def ndims(self):
@@ -10,12 +10,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import Any, List, Optional, Union

import numpy as np

from lancedb.embeddings.fine_tuner import QADataset
from lancedb.utils.general import LOGGER

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .fine_tuner.basetuner import BaseEmbeddingTuner
from .registry import register
from .utils import weak_lru
@@ -80,3 +84,151 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
            "sentence_transformers", "sentence-transformers"
        )
        return sentence_transformers.SentenceTransformer(self.name, device=self.device)

    def finetune(self, trainset: QADataset, *args, **kwargs):
        """
        Finetune the Sentence Transformers model

        Parameters
        ----------
        trainset: QADataset
            The dataset to use for finetuning
        """
        tuner = SentenceTransformersTuner(
            model=self.embedding_model,
            trainset=trainset,
            **kwargs,
        )
        tuner.finetune()


class SentenceTransformersTuner(BaseEmbeddingTuner):
    """Sentence Transformers Embedding Finetuning Engine."""

    def __init__(
        self,
        model: Any,
        trainset: QADataset,
        valset: Optional[QADataset] = None,
        path: Optional[str] = "~/.lancedb/embeddings/models",
        batch_size: int = 8,
        epochs: int = 1,
        show_progress: bool = True,
        eval_steps: int = 50,
        max_input_per_doc: int = -1,
        loss: Optional[Any] = None,
        evaluator: Optional[Any] = None,
        run_name: Optional[str] = None,
        log_wandb: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        model: str
            The model to use for finetuning.
        trainset: QADataset
            The training dataset.
        valset: Optional[QADataset]
            The validation dataset.
        path: Optional[str]
            The path to save the model.
        batch_size: int, default=8
            The batch size.
        epochs: int, default=1
            The number of epochs.
        show_progress: bool, default=True
            Whether to show progress.
        eval_steps: int, default=50
            The number of steps to evaluate.
        max_input_per_doc: int, default=-1
            The maximum number of inputs per document.
            If -1, use all documents.
        """
        from sentence_transformers import InputExample, losses
        from sentence_transformers.evaluation import InformationRetrievalEvaluator
        from torch.utils.data import DataLoader

        self.model = model
        self.trainset = trainset
        self.valset = valset
        self.path = path
        self.batch_size = batch_size
        self.epochs = epochs
        self.show_progress = show_progress
        self.eval_steps = eval_steps
        self.max_input_per_doc = max_input_per_doc
        self.evaluator = None
        self.run_name = run_name
        self.log_wandb = log_wandb

        if self.max_input_per_doc < -1:
            raise ValueError("max_input_per_doc must be -1 or greater than 0.")

        examples: Any = []
        for query_id, query in self.trainset.queries.items():
            relevant = self.trainset.relevant_docs[query_id]
            if max_input_per_doc != -1:
                # cap the number of training pairs generated per query
                relevant = relevant[:max_input_per_doc]
            for node_id in relevant:
                text = self.trainset.corpus[node_id]
                example = InputExample(texts=[query, text])
                examples.append(example)

        self.examples = examples

        self.loader: DataLoader = DataLoader(examples, batch_size=batch_size)

        if self.valset is not None:
            eval_engine = evaluator or InformationRetrievalEvaluator
            self.evaluator = eval_engine(
                valset.queries, valset.corpus, valset.relevant_docs
            )
        else:
            self.evaluator = evaluator

        # define loss
        self.loss = loss or losses.MultipleNegativesRankingLoss(self.model)
        self.warmup_steps = int(len(self.loader) * epochs * 0.1)

    def finetune(self) -> None:
        """Finetune the Sentence Transformers model."""
        self.model.fit(
            train_objectives=[(self.loader, self.loss)],
            epochs=self.epochs,
            warmup_steps=self.warmup_steps,
            output_path=self.path,
            show_progress_bar=self.show_progress,
            evaluator=self.evaluator,
            evaluation_steps=self.eval_steps,
            callback=self._wandb_callback if self.log_wandb else None,
        )

        self.helper()

    def helper(self) -> None:
        """A helper method."""
        LOGGER.info("Finetuning complete.")
        LOGGER.info(f"Model saved to {self.path}.")
        LOGGER.info("You can now use the model as follows:")
        LOGGER.info(
            f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')"  # noqa
        )

    def _wandb_callback(self, score, epoch, steps):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "wandb is not installed. Please install it using `pip install wandb`"
            )
        run = wandb.run or wandb.init(
            project="sbert_lancedb_finetune", name=self.run_name
        )
        run.log({"epoch": epoch, "steps": steps, "score": score})
@@ -1033,7 +1033,7 @@ class AsyncQueryBase(object):
        Construct an AsyncQueryBase

        This method is not intended to be called directly. Instead, use the
        [Table.query][] method to create a query.
        [AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
        """
        self._inner = inner
@@ -1041,7 +1041,10 @@ class AsyncQueryBase(object):
        """
        Only return rows matching the given predicate

        The predicate should be supplied as an SQL query string. For example:
        The predicate should be supplied as an SQL query string.

        Examples
        --------

        >>> predicate = "x > 10"
        >>> predicate = "y > 0 AND y < 100"
@@ -1112,7 +1115,8 @@ class AsyncQueryBase(object):
        Execute the query and collect the results into an Apache Arrow Table.

        This method will collect all results into memory before returning. If
        you expect a large number of results, you may want to use [to_batches][]
        you expect a large number of results, you may want to use
        [to_batches][lancedb.query.AsyncQueryBase.to_batches]
        """
        batch_iter = await self.to_batches()
        return pa.Table.from_batches(
@@ -1123,12 +1127,13 @@ class AsyncQueryBase(object):
        """
        Execute the query and collect the results into a pandas DataFrame.

        This method will collect all results into memory before returning. If
        you expect a large number of results, you may want to use [to_batches][]
        and convert each batch to pandas separately.
        This method will collect all results into memory before returning. If you
        expect a large number of results, you may want to use
        [to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
        pandas separately.

        Example
        -------
        Examples
        --------

        >>> import asyncio
        >>> from lancedb import connect_async
@@ -1148,7 +1153,7 @@ class AsyncQuery(AsyncQueryBase):
        Construct an AsyncQuery

        This method is not intended to be called directly. Instead, use the
        [Table.query][] method to create a query.
        [AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
        """
        super().__init__(inner)
        self._inner = inner
@@ -1189,8 +1194,8 @@
        If there is only one vector column (a column whose data type is a
        fixed size list of floats) then the column does not need to be specified.
        If there is more than one vector column you must use
        [AsyncVectorQuery::column][] to specify which column you would like to
        compare with.
        [AsyncVectorQuery.column][lancedb.query.AsyncVectorQuery.column] to specify
        which column you would like to compare with.

        If no index has been created on the vector column then a vector query
        will perform a distance comparison between the query vector and every
@@ -1221,8 +1226,10 @@ class AsyncVectorQuery(AsyncQueryBase):
        Construct an AsyncVectorQuery

        This method is not intended to be called directly. Instead, create
        a query first with [Table.query][] and then use [AsyncQuery.nearest_to][]
        to convert to a vector query.
        a query first with [AsyncTable.query][lancedb.table.AsyncTable.query] and then
        use [AsyncQuery.nearest_to][lancedb.query.AsyncQuery.nearest_to] to convert to
        a vector query. Or you can use
        [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search]
        """
        super().__init__(inner)
        self._inner = inner
@@ -1232,7 +1239,7 @@
        Set the vector column to query

        This controls which column is compared to the query vector supplied in
        the call to [Query.nearest_to][].
        the call to [AsyncQuery.nearest_to][lancedb.query.AsyncQuery.nearest_to].

        This parameter must be specified if the table has more than one column
        whose data type is a fixed-size-list of floats.

@@ -14,7 +14,7 @@ class CrossEncoderReranker(Reranker):

    Parameters
    ----------
    model : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
    model_name : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
        The name of the cross encoder model to use. See the sentence transformers
        documentation for a list of available models.
    column : str, default "text"
@@ -1893,8 +1893,8 @@ class AsyncTable:
    An AsyncTable object is expected to be long lived and reused for multiple
    operations. AsyncTable objects will cache a certain amount of index data in memory.
    This cache will be freed when the Table is garbage collected. To eagerly free the
    cache you can call the [close][AsyncTable.close] method. Once the AsyncTable is
    closed, it cannot be used for any further operations.
    cache you can call the [close][lancedb.AsyncTable.close] method. Once the
    AsyncTable is closed, it cannot be used for any further operations.

    An AsyncTable can also be used as a context manager, and will automatically close
    when the context is exited. Closing a table is optional. If you do not close the
@@ -1903,13 +1903,17 @@
    Examples
    --------

    Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
    Create using [AsyncConnection.create_table][lancedb.AsyncConnection.create_table]
    (more examples in that method's documentation).

    >>> import lancedb
    >>> db = lancedb.connect("./.lancedb")
    >>> table = db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}])
    >>> table.head()
    >>> async def create_a_table():
    ...     db = await lancedb.connect_async("./.lancedb")
    ...     data = [{"vector": [1.1, 1.2], "b": 2}]
    ...     table = await db.create_table("my_table", data=data)
    ...     print(await table.query().limit(5).to_arrow())
    >>> import asyncio
    >>> asyncio.run(create_a_table())
    pyarrow.Table
    vector: fixed_size_list<item: float>[2]
      child 0, item: float
@@ -1918,25 +1922,37 @@ class AsyncTable:
    vector: [[[1.1,1.2]]]
    b: [[2]]

    Can append new data with [Table.add()][lancedb.table.Table.add].
    Can append new data with [AsyncTable.add()][lancedb.table.AsyncTable.add].

    >>> table.add([{"vector": [0.5, 1.3], "b": 4}])
    >>> async def add_to_table():
    ...     db = await lancedb.connect_async("./.lancedb")
    ...     table = await db.open_table("my_table")
    ...     await table.add([{"vector": [0.5, 1.3], "b": 4}])
    >>> asyncio.run(add_to_table())

    Can query the table with [Table.search][lancedb.table.Table.search].
    Can query the table with
    [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search].

    >>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
    >>> async def search_table_for_vector():
    ...     db = await lancedb.connect_async("./.lancedb")
    ...     table = await db.open_table("my_table")
    ...     results = (
    ...         await table.vector_search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
    ...     )
    ...     print(results)
    >>> asyncio.run(search_table_for_vector())
       b      vector  _distance
    0  4  [0.5, 1.3]       0.82
    1  2  [1.1, 1.2]       1.13

    Search queries are much faster when an index is created. See
    [Table.create_index][lancedb.table.Table.create_index].
    [AsyncTable.create_index][lancedb.table.AsyncTable.create_index].
    """

    def __init__(self, table: LanceDBTable):
        """Create a new Table object.
        """Create a new AsyncTable object.

        You should not create Table objects directly.
        You should not create AsyncTable objects directly.

        Use [AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
        [AsyncConnection.open_table][lancedb.AsyncConnection.open_table] to obtain
@@ -1988,6 +2004,14 @@ class AsyncTable:
        return await self._inner.count_rows(filter)

    def query(self) -> AsyncQuery:
        """
        Returns an [AsyncQuery][lancedb.query.AsyncQuery] that can be used
        to search the table.

        Use methods on the returned query to control query behavior. The query
        can be executed with methods like [to_arrow][lancedb.query.AsyncQuery.to_arrow],
        [to_pandas][lancedb.query.AsyncQuery.to_pandas] and more.
        """
        return AsyncQuery(self._inner.query())

    async def to_pandas(self) -> "pd.DataFrame":
@@ -2024,20 +2048,8 @@ class AsyncTable:

        Parameters
        ----------
        index: Index
            The index to create.

            LanceDb supports multiple types of indices. See the static methods on
            the Index class for more details.
        column: str, default None
        column: str
            The column to index.

            When building a scalar index this must be set.

            When building a vector index, this is optional. The default will look
            for any columns of type fixed-size-list with floating point values. If
            there is only one column of this type then it will be used. Otherwise
            an error will be returned.
        replace: bool, default True
            Whether to replace the existing index

@@ -2046,6 +2058,10 @@ class AsyncTable:
            that index is out of date.

            The default is True
        config: Union[IvfPq, BTree], default None
            For advanced configuration you can specify the type of index you would
            like to create. You can also specify index-specific parameters when
            creating an index object.
        """
        index = None
        if config is not None:
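For illustration, advanced index configuration might look like this (a sketch;
the `lancedb.index` import path for `IvfPq` is an assumption based on the
`config: Union[IvfPq, BTree]` parameter above):

```python
from lancedb.index import IvfPq  # assumed import path

# Build an IVF-PQ vector index with explicit parameters
await table.create_index(
    "vector",
    config=IvfPq(num_partitions=256, num_sub_vectors=16),
)
```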
@@ -2167,7 +2183,8 @@ class AsyncTable:
        Search the table with a given query vector.
        This is a convenience method for preparing a vector query and
        is the same thing as calling `nearestTo` on the builder returned
        by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
        by `query`. See [nearest_to][lancedb.query.AsyncQuery.nearest_to] for more
        details.
        """
        return self.query().nearest_to(query_vector)

@@ -2233,7 +2250,7 @@ class AsyncTable:
           x      vector
        0  3  [5.0, 6.0]
        """
        raise NotImplementedError
        return await self._inner.delete(where)

    async def update(
        self,
@@ -2289,102 +2306,6 @@ class AsyncTable:

        return await self._inner.update(updates_sql, where)

    async def cleanup_old_versions(
        self,
        older_than: Optional[timedelta] = None,
        *,
        delete_unverified: bool = False,
    ) -> CleanupStats:
        """
        Clean up old versions of the table, freeing disk space.

        Note: This function is not available in LanceDb Cloud (since LanceDb
        Cloud manages cleanup for you automatically)

        Parameters
        ----------
        older_than: timedelta, default None
            The minimum age of the version to delete. If None, then this defaults
            to two weeks.
        delete_unverified: bool, default False
            Because they may be part of an in-progress transaction, files newer
            than 7 days old are not deleted by default. If you are sure that
            there are no in-progress transactions, then you can set this to True
            to delete all files older than `older_than`.

        Returns
        -------
        CleanupStats
            The stats of the cleanup operation, including how many bytes were
            freed.
        """
        raise NotImplementedError

    async def compact_files(self, *args, **kwargs):
        """
        Run the compaction process on the table.

        Note: This function is not available in LanceDb Cloud (since LanceDb
        Cloud manages compaction for you automatically)

        This can be run after making several small appends to optimize the table
        for faster reads.

        Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
        For most cases, the default should be fine.
        """
        raise NotImplementedError

    async def add_columns(self, transforms: Dict[str, str]):
        """
        Add new columns with defined values.

        This is not yet available in LanceDB Cloud.

        Parameters
        ----------
        transforms: Dict[str, str]
            A map of column name to a SQL expression to use to calculate the
            value of the new column. These expressions will be evaluated for
            each row in the table, and can reference existing columns.
        """
        raise NotImplementedError

    async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
        """
        Alter column names and nullability.

        This is not yet available in LanceDB Cloud.

        alterations : Iterable[Dict[str, Any]]
            A sequence of dictionaries, each with the following keys:
            - "path": str
                The column path to alter. For a top-level column, this is the name.
                For a nested column, this is the dot-separated path, e.g. "a.b.c".
            - "name": str, optional
                The new name of the column. If not specified, the column name is
                not changed.
            - "nullable": bool, optional
                Whether the column should be nullable. If not specified, the column
                nullability is not changed. Only non-nullable columns can be changed
                to nullable. Currently, you cannot change a nullable column to
                non-nullable.
        """
        raise NotImplementedError

    async def drop_columns(self, columns: Iterable[str]):
        """
        Drop columns from the table.

        This is not yet available in LanceDB Cloud.

        Parameters
        ----------
        columns : Iterable[str]
            The names of the columns to drop.
        """
        raise NotImplementedError

    async def version(self) -> int:
        """
        Retrieve the version of the table

162
python/python/tests/docs/test_basic.py
Normal file
@@ -0,0 +1,162 @@
import shutil

# --8<-- [start:imports]
import lancedb
import pandas as pd
import pyarrow as pa

# --8<-- [end:imports]
import pytest
from numpy.random import randint, random

shutil.rmtree("data/sample-lancedb", ignore_errors=True)


def test_quickstart():
    # --8<-- [start:connect]
    uri = "data/sample-lancedb"
    db = lancedb.connect(uri)
    # --8<-- [end:connect]

    # --8<-- [start:create_table]
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]

    # Synchronous client
    tbl = db.create_table("my_table", data=data)
    # --8<-- [end:create_table]

    # --8<-- [start:create_table_pandas]
    df = pd.DataFrame(
        [
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ]
    )
    # Synchronous client
    tbl = db.create_table("table_from_df", data=df)
    # --8<-- [end:create_table_pandas]

    # --8<-- [start:create_empty_table]
    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
    # Synchronous client
    tbl = db.create_table("empty_table", schema=schema)
    # --8<-- [end:create_empty_table]
    # --8<-- [start:open_table]
    # Synchronous client
    tbl = db.open_table("my_table")
    # --8<-- [end:open_table]
    # --8<-- [start:table_names]
    # Synchronous client
    print(db.table_names())
    # --8<-- [end:table_names]
    # Synchronous client
    # --8<-- [start:add_data]
    # Option 1: Add a list of dicts to a table
    data = [
        {"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
        {"vector": [9.5, 56.2], "item": "buzz", "price": 200.0},
    ]
    tbl.add(data)

    # Option 2: Add a pandas DataFrame to a table
    df = pd.DataFrame(data)
    tbl.add(df)
    # --8<-- [end:add_data]
    # --8<-- [start:vector_search]
    # Synchronous client
    tbl.search([100, 100]).limit(2).to_pandas()
    # --8<-- [end:vector_search]
    tbl.add(
        [
            {"vector": random(2), "item": "autogen", "price": randint(100)}
            for _ in range(1000)
        ]
    )
    # --8<-- [start:create_index]
    # Synchronous client
    tbl.create_index(num_sub_vectors=1)
    # --8<-- [end:create_index]
    # --8<-- [start:delete_rows]
    # Synchronous client
    tbl.delete('item = "fizz"')
    # --8<-- [end:delete_rows]
    # --8<-- [start:drop_table]
    # Synchronous client
    db.drop_table("my_table")
    # --8<-- [end:drop_table]


@pytest.mark.asyncio
async def test_quickstart_async():
    # --8<-- [start:connect_async]
    # LanceDb offers both a synchronous and an asynchronous client. There are still a
    # few operations that are only supported by the synchronous client (e.g. embedding
    # functions, full text search) but both APIs should soon be equivalent

    # In this guide we will give examples of both clients. In other guides we will
    # typically only provide examples with one client or the other.
    uri = "data/sample-lancedb"
    async_db = await lancedb.connect_async(uri)
    # --8<-- [end:connect_async]

    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]

    # --8<-- [start:create_table_async]
    # Asynchronous client
    async_tbl = await async_db.create_table("my_table2", data=data)
    # --8<-- [end:create_table_async]

    df = pd.DataFrame(
        [
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ]
    )

    # --8<-- [start:create_table_async_pandas]
    # Asynchronous client
    async_tbl = await async_db.create_table("table_from_df2", df)
    # --8<-- [end:create_table_async_pandas]

    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
    # --8<-- [start:create_empty_table_async]
    # Asynchronous client
    async_tbl = await async_db.create_table("empty_table2", schema=schema)
    # --8<-- [end:create_empty_table_async]
    # --8<-- [start:open_table_async]
    # Asynchronous client
    async_tbl = await async_db.open_table("my_table2")
    # --8<-- [end:open_table_async]
    # --8<-- [start:table_names_async]
    # Asynchronous client
    print(await async_db.table_names())
    # --8<-- [end:table_names_async]
    # --8<-- [start:add_data_async]
    # Asynchronous client
    await async_tbl.add(data)
    # --8<-- [end:add_data_async]
    # Add sufficient data for training
    data = [{"vector": [x, x], "item": "filler", "price": x * x} for x in range(1000)]
    await async_tbl.add(data)
    # --8<-- [start:vector_search_async]
    # Asynchronous client
    await async_tbl.vector_search([100, 100]).limit(2).to_pandas()
    # --8<-- [end:vector_search_async]
    # --8<-- [start:create_index_async]
    # Asynchronous client (must specify column to index)
    await async_tbl.create_index("vector")
    # --8<-- [end:create_index_async]
    # --8<-- [start:delete_rows_async]
    # Asynchronous client
    await async_tbl.delete('item = "fizz"')
    # --8<-- [end:delete_rows_async]
    # --8<-- [start:drop_table_async]
    # Asynchronous client
    await async_db.drop_table("my_table2")
    # --8<-- [end:drop_table_async]
45
python/python/tests/test_embedding_fine_tuning.py
Normal file
@@ -0,0 +1,45 @@
import uuid

import pytest
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner import QADataset, TextChunk
from tqdm import tqdm


@pytest.mark.slow
def test_finetuning_sentence_transformers(tmp_path):
    queries = {}
    relevant_docs = {}
    chunks = [
        "This is a chunk related to legal docs",
        "This is another chunk related to financial docs",
        "This is a chunk related to sports docs",
        "This is another chunk related to fashion docs",
    ]
    text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks]
    for chunk in tqdm(text_chunks):
        questions = [
            "What is this chunk about?",
            "What is the main topic of this chunk?",
        ]
        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [chunk.id]
    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)

    assert len(ds.queries) == 8
    assert len(ds.corpus) == 4

    model = get_registry().get("sentence-transformers").create()
    model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1)
    model = (
        get_registry().get("sentence-transformers").create(name=str(tmp_path / "model"))
    )
    res = model.evaluate(ds)
    assert res is not None


def test_text_chunk():
    # TODO
    pass
@@ -137,6 +137,21 @@ impl Connection {
            Ok(Table::new(table))
        })
    }

    pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<&PyAny> {
        let inner = self_.get_inner()?.clone();
        future_into_py(self_.py(), async move {
            inner.drop_table(name).await.infer_error()
        })
    }

    pub fn drop_db(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
        let inner = self_.get_inner()?.clone();
        future_into_py(
            self_.py(),
            async move { inner.drop_db().await.infer_error() },
        )
    }
}

#[pyfunction]
@@ -80,6 +80,13 @@ impl Table {
        })
    }

    pub fn delete<'a>(self_: PyRef<'a, Self>, condition: String) -> PyResult<&'a PyAny> {
        let inner = self_.inner_ref()?.clone();
        future_into_py(self_.py(), async move {
            inner.delete(&condition).await.infer_error()
        })
    }

    pub fn update<'a>(
        self_: PyRef<'a, Self>,
        updates: &PyDict,