mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
4
.github/workflows/python.yml
vendored
4
.github/workflows/python.yml
vendored
@@ -32,7 +32,9 @@ jobs:
|
||||
run: |
|
||||
pip install -e .
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock
|
||||
pip install pytest pytest-mock black
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
|
||||
@@ -15,13 +15,7 @@ from langchain.llms import OpenAI
|
||||
from langchain.chains import RetrievalQA
|
||||
|
||||
lancedb_image = Image.debian_slim().pip_install(
|
||||
"lancedb",
|
||||
"langchain",
|
||||
"openai",
|
||||
"pandas",
|
||||
"tiktoken",
|
||||
"unstructured",
|
||||
"tabulate"
|
||||
"lancedb", "langchain", "openai", "pandas", "tiktoken", "unstructured", "tabulate"
|
||||
)
|
||||
|
||||
stub = Stub(
|
||||
@@ -34,21 +28,26 @@ docsearch = None
|
||||
docs_path = Path("docs.pkl")
|
||||
db_path = Path("lancedb")
|
||||
|
||||
|
||||
def get_document_title(document):
|
||||
m = str(document.metadata["source"])
|
||||
title = re.findall("pandas.documentation(.*).html", m)
|
||||
if title[0] is not None:
|
||||
return(title[0])
|
||||
return ''
|
||||
return title[0]
|
||||
return ""
|
||||
|
||||
|
||||
def download_docs():
|
||||
pandas_docs = requests.get("https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip")
|
||||
pandas_docs = requests.get(
|
||||
"https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
|
||||
)
|
||||
with open(Path("pandas.documentation.zip"), "wb") as f:
|
||||
f.write(pandas_docs.content)
|
||||
|
||||
file = zipfile.ZipFile(Path("pandas.documentation.zip"))
|
||||
file.extractall(path=Path("pandas_docs"))
|
||||
|
||||
|
||||
def store_docs():
|
||||
docs = []
|
||||
|
||||
@@ -74,6 +73,7 @@ def store_docs():
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def qanda_langchain(query):
|
||||
download_docs()
|
||||
docs = store_docs()
|
||||
@@ -85,14 +85,25 @@ def qanda_langchain(query):
|
||||
documents = text_splitter.split_documents(docs)
|
||||
embeddings = OpenAIEmbeddings()
|
||||
|
||||
db = lancedb.connect(db_path)
|
||||
table = db.create_table("pandas_docs", data=[
|
||||
{"vector": embeddings.embed_query("Hello World"), "text": "Hello World", "id": "1"}
|
||||
], mode="overwrite")
|
||||
db = lancedb.connect(db_path)
|
||||
table = db.create_table(
|
||||
"pandas_docs",
|
||||
data=[
|
||||
{
|
||||
"vector": embeddings.embed_query("Hello World"),
|
||||
"text": "Hello World",
|
||||
"id": "1",
|
||||
}
|
||||
],
|
||||
mode="overwrite",
|
||||
)
|
||||
docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
|
||||
qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever())
|
||||
qa = RetrievalQA.from_chain_type(
|
||||
llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever()
|
||||
)
|
||||
return qa.run(query)
|
||||
|
||||
|
||||
@stub.function()
|
||||
@web_endpoint(method="GET")
|
||||
def web(query: str):
|
||||
@@ -101,6 +112,7 @@ def web(query: str):
|
||||
"answer": answer,
|
||||
}
|
||||
|
||||
|
||||
@stub.function()
|
||||
def cli(query: str):
|
||||
answer = qanda_langchain(query)
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
# import lancedb so we don't have to in every example
|
||||
import lancedb
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def doctest_setup(monkeypatch, tmpdir):
|
||||
# disable color for doctests so we don't have to include
|
||||
@@ -15,6 +16,3 @@ def doctest_setup(monkeypatch, tmpdir):
|
||||
monkeypatch.setitem(os.environ, "COLUMNS", "80")
|
||||
# Work in a temporary directory
|
||||
monkeypatch.chdir(tmpdir)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
import pandas as pd
|
||||
from .exceptions import MissingValueError, MissingColumnError
|
||||
|
||||
|
||||
def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
|
||||
"""Create a Contextualizer object for the given DataFrame.
|
||||
|
||||
@@ -85,6 +86,7 @@ def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
|
||||
|
||||
class Contextualizer:
|
||||
"""Create context windows from a DataFrame. See [lancedb.context.contextualize][]."""
|
||||
|
||||
def __init__(self, raw_df):
|
||||
self._text_col = None
|
||||
self._groupby = None
|
||||
@@ -144,12 +146,16 @@ class Contextualizer:
|
||||
raise MissingColumnError(self._text_col)
|
||||
|
||||
if self._window is None or self._window < 1:
|
||||
raise MissingValueError("The value of window is None or less than 1. Specify the "
|
||||
"window size (number of rows to include in each window)")
|
||||
raise MissingValueError(
|
||||
"The value of window is None or less than 1. Specify the "
|
||||
"window size (number of rows to include in each window)"
|
||||
)
|
||||
|
||||
if self._stride is None or self._stride < 1:
|
||||
raise MissingValueError("The value of stride is None or less than 1. Specify the "
|
||||
"stride (number of rows to skip between each window)")
|
||||
raise MissingValueError(
|
||||
"The value of stride is None or less than 1. Specify the "
|
||||
"stride (number of rows to skip between each window)"
|
||||
)
|
||||
|
||||
def process_group(grp):
|
||||
# For each group, create the text rolling window
|
||||
|
||||
@@ -33,7 +33,7 @@ class LanceDBConnection:
|
||||
----------
|
||||
uri: str or Path
|
||||
The root uri of the database.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -79,16 +79,20 @@ class LanceDBConnection:
|
||||
try:
|
||||
filesystem, path = fs.FileSystem.from_uri(self.uri)
|
||||
except pa.ArrowInvalid:
|
||||
raise NotImplementedError(
|
||||
"Unsupported scheme: " + self.uri
|
||||
)
|
||||
raise NotImplementedError("Unsupported scheme: " + self.uri)
|
||||
|
||||
try:
|
||||
paths = filesystem.get_file_info(fs.FileSelector(get_uri_location(self.uri)))
|
||||
paths = filesystem.get_file_info(
|
||||
fs.FileSelector(get_uri_location(self.uri))
|
||||
)
|
||||
except FileNotFoundError:
|
||||
# It is ok if the file does not exist since it will be created
|
||||
paths = []
|
||||
tables = [os.path.splitext(file_info.base_name)[0] for file_info in paths if file_info.extension == 'lance']
|
||||
tables = [
|
||||
os.path.splitext(file_info.base_name)[0]
|
||||
for file_info in paths
|
||||
if file_info.extension == "lance"
|
||||
]
|
||||
return tables
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -153,7 +157,7 @@ class LanceDBConnection:
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
|
||||
You can also pass a pandas DataFrame:
|
||||
|
||||
>>> import pandas as pd
|
||||
@@ -175,7 +179,7 @@ class LanceDBConnection:
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
Data is converted to Arrow before being written to disk. For maximum
|
||||
Data is converted to Arrow before being written to disk. For maximum
|
||||
control over how data is saved, either provide the PyArrow schema to
|
||||
convert to or else provide a PyArrow table directly.
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def with_embeddings(
|
||||
"""Add a vector column to a table using the given embedding function.
|
||||
|
||||
The new columns will be called "vector".
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
@@ -48,7 +48,7 @@ def with_embeddings(
|
||||
Whether to show a progress bar.
|
||||
batch_size : int, default 1000
|
||||
The number of row values to pass to each call of the embedding function.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
"""Custom exception handling"""
|
||||
|
||||
|
||||
class MissingValueError(ValueError):
|
||||
"""Exception raised when a required value is missing."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MissingColumnError(KeyError):
|
||||
"""
|
||||
Exception raised when a column name specified is not in
|
||||
Exception raised when a column name specified is not in
|
||||
the DataFrame object
|
||||
"""
|
||||
|
||||
def __init__(self, column_name):
|
||||
self.column_name = column_name
|
||||
|
||||
def __str__(self):
|
||||
return f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
|
||||
return (
|
||||
f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
|
||||
)
|
||||
|
||||
@@ -68,7 +68,7 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
|
||||
The table to index
|
||||
fields : List[str]
|
||||
List of fields to index
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
|
||||
@@ -120,7 +120,7 @@ class LanceQueryBuilder:
|
||||
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
|
||||
"""Set the number of probes to use.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
See discussion in [Querying an ANN Index][../querying-an-ann-index] for
|
||||
|
||||
@@ -99,7 +99,7 @@ class LanceTable:
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""Return the schema of the table.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Schema
|
||||
@@ -117,14 +117,14 @@ class LanceTable:
|
||||
|
||||
def checkout(self, version: int):
|
||||
"""Checkout a version of the table. This is an in-place operation.
|
||||
|
||||
|
||||
This allows viewing previous versions of the table.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int
|
||||
The version to checkout.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -165,7 +165,7 @@ class LanceTable:
|
||||
|
||||
def to_pandas(self) -> pd.DataFrame:
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
@@ -174,7 +174,7 @@ class LanceTable:
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table"""
|
||||
@@ -342,4 +342,6 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
|
||||
values = values.cast(pa.float32())
|
||||
list_size = len(values) / len(data)
|
||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
|
||||
return data.set_column(data.column_names.index(vector_column_name), vector_column_name, vec_arr)
|
||||
return data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
@@ -119,4 +119,4 @@ def test_delete_table(tmp_path):
|
||||
assert db.table_names() == []
|
||||
|
||||
db.create_table("test", data=data)
|
||||
assert db.table_names() == ["test"]
|
||||
assert db.table_names() == ["test"]
|
||||
|
||||
@@ -19,6 +19,7 @@ import lancedb
|
||||
# You need to setup AWS credentials an a base path to run this test. Example
|
||||
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
(os.environ.get("TEST_S3_BASE_URL") is None),
|
||||
reason="please setup s3 base url",
|
||||
|
||||
@@ -30,12 +30,16 @@ class MockTable:
|
||||
|
||||
@pytest.fixture
|
||||
def table(tmp_path) -> MockTable:
|
||||
df = pa.table({
|
||||
"vector": pa.array([[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)),
|
||||
"id": pa.array([1, 2]),
|
||||
"str_field": pa.array(["a", "b"]),
|
||||
"float_field": pa.array([1.0, 2.0]),
|
||||
})
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
[[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
|
||||
),
|
||||
"id": pa.array([1, 2]),
|
||||
"str_field": pa.array(["a", "b"]),
|
||||
"float_field": pa.array([1.0, 2.0]),
|
||||
}
|
||||
)
|
||||
lance.write_dataset(df, tmp_path)
|
||||
return MockTable(tmp_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user