diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index d12f6bff..647780ea 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -32,7 +32,9 @@ jobs:
       run: |
         pip install -e .
         pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
-        pip install pytest pytest-mock
+        pip install pytest pytest-mock black
+    - name: Black
+      run: black --check --diff --no-color --quiet .
     - name: Run tests
       run: pytest -x -v --durations=30 tests
     - name: doctest
diff --git a/docs/src/examples/modal_langchain.py b/docs/src/examples/modal_langchain.py
index 60a11629..f929b6d4 100644
--- a/docs/src/examples/modal_langchain.py
+++ b/docs/src/examples/modal_langchain.py
@@ -15,13 +15,7 @@ from langchain.llms import OpenAI
 from langchain.chains import RetrievalQA
 
 lancedb_image = Image.debian_slim().pip_install(
-    "lancedb",
-    "langchain",
-    "openai",
-    "pandas",
-    "tiktoken",
-    "unstructured",
-    "tabulate"
+    "lancedb", "langchain", "openai", "pandas", "tiktoken", "unstructured", "tabulate"
 )
 
 stub = Stub(
@@ -34,21 +28,26 @@ docsearch = None
 docs_path = Path("docs.pkl")
 db_path = Path("lancedb")
 
+
 def get_document_title(document):
     m = str(document.metadata["source"])
     title = re.findall("pandas.documentation(.*).html", m)
     if title[0] is not None:
-        return(title[0])
-    return ''
+        return title[0]
+    return ""
+
 
 def download_docs():
-    pandas_docs = requests.get("https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip")
+    pandas_docs = requests.get(
+        "https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
+    )
     with open(Path("pandas.documentation.zip"), "wb") as f:
         f.write(pandas_docs.content)
 
     file = zipfile.ZipFile(Path("pandas.documentation.zip"))
     file.extractall(path=Path("pandas_docs"))
 
+
 def store_docs():
     docs = []
 
@@ -74,6 +73,7 @@ def store_docs():
 
     return docs
 
+
 def qanda_langchain(query):
     download_docs()
     docs = store_docs()
@@ -85,14 +85,25 @@ def qanda_langchain(query):
     documents = text_splitter.split_documents(docs)
     embeddings = OpenAIEmbeddings()
 
-    db = lancedb.connect(db_path) 
-    table = db.create_table("pandas_docs", data=[
-        {"vector": embeddings.embed_query("Hello World"), "text": "Hello World", "id": "1"}
-    ], mode="overwrite")
+    db = lancedb.connect(db_path)
+    table = db.create_table(
+        "pandas_docs",
+        data=[
+            {
+                "vector": embeddings.embed_query("Hello World"),
+                "text": "Hello World",
+                "id": "1",
+            }
+        ],
+        mode="overwrite",
+    )
     docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
-    qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever())
+    qa = RetrievalQA.from_chain_type(
+        llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever()
+    )
     return qa.run(query)
 
+
 @stub.function()
 @web_endpoint(method="GET")
 def web(query: str):
@@ -101,6 +112,7 @@ def web(query: str):
         "answer": answer,
     }
 
+
 @stub.function()
 def cli(query: str):
     answer = qanda_langchain(query)
diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py
index 50c42cfc..ea0e2300 100644
--- a/python/lancedb/conftest.py
+++ b/python/lancedb/conftest.py
@@ -6,6 +6,7 @@ import pytest
 # import lancedb so we don't have to in every example
 import lancedb
 
+
 @pytest.fixture(autouse=True)
 def doctest_setup(monkeypatch, tmpdir):
     # disable color for doctests so we don't have to include
@@ -15,6 +16,3 @@ def doctest_setup(monkeypatch, tmpdir):
     monkeypatch.setitem(os.environ, "COLUMNS", "80")
     # Work in a temporary directory
     monkeypatch.chdir(tmpdir)
-
-
-
diff --git a/python/lancedb/context.py b/python/lancedb/context.py
index dee4180c..f1d634d1 100644
--- a/python/lancedb/context.py
+++ b/python/lancedb/context.py
@@ -15,6 +15,7 @@ from __future__ import annotations
 import pandas as pd
 
 from .exceptions import MissingValueError, MissingColumnError
 
+
 def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
     """Create a Contextualizer object for the given DataFrame.
@@ -85,6 +86,7 @@
 
 class Contextualizer:
     """Create context windows from a DataFrame. See [lancedb.context.contextualize][]."""
+
     def __init__(self, raw_df):
         self._text_col = None
         self._groupby = None
@@ -144,12 +146,16 @@
             raise MissingColumnError(self._text_col)
 
         if self._window is None or self._window < 1:
-            raise MissingValueError("The value of window is None or less than 1. Specify the "
-                                    "window size (number of rows to include in each window)")
+            raise MissingValueError(
+                "The value of window is None or less than 1. Specify the "
+                "window size (number of rows to include in each window)"
+            )
 
         if self._stride is None or self._stride < 1:
-            raise MissingValueError("The value of stride is None or less than 1. Specify the "
-                                    "stride (number of rows to skip between each window)")
+            raise MissingValueError(
+                "The value of stride is None or less than 1. Specify the "
+                "stride (number of rows to skip between each window)"
+            )
 
         def process_group(grp):
             # For each group, create the text rolling window
diff --git a/python/lancedb/db.py b/python/lancedb/db.py
index 18c1e9d4..f48f0130 100644
--- a/python/lancedb/db.py
+++ b/python/lancedb/db.py
@@ -33,7 +33,7 @@ class LanceDBConnection:
     ----------
     uri: str or Path
         The root uri of the database.
-    
+
     Examples
     --------
     >>> import lancedb
@@ -79,16 +79,20 @@
         try:
             filesystem, path = fs.FileSystem.from_uri(self.uri)
         except pa.ArrowInvalid:
-            raise NotImplementedError(
-                "Unsupported scheme: " + self.uri
-            )
+            raise NotImplementedError("Unsupported scheme: " + self.uri)
 
         try:
-            paths = filesystem.get_file_info(fs.FileSelector(get_uri_location(self.uri)))
+            paths = filesystem.get_file_info(
+                fs.FileSelector(get_uri_location(self.uri))
+            )
         except FileNotFoundError:
             # It is ok if the file does not exist since it will be created
             paths = []
-        tables = [os.path.splitext(file_info.base_name)[0] for file_info in paths if file_info.extension == 'lance']
+        tables = [
+            os.path.splitext(file_info.base_name)[0]
+            for file_info in paths
+            if file_info.extension == "lance"
+        ]
         return tables
 
     def __len__(self) -> int:
@@ -153,7 +157,7 @@
         vector: [[[1.1,1.2],[0.2,1.8]]]
         lat: [[45.5,40.1]]
         long: [[-122.7,-74.1]]
-    
+
         You can also pass a pandas DataFrame:
 
         >>> import pandas as pd
@@ -175,7 +179,7 @@
         lat: [[45.5,40.1]]
         long: [[-122.7,-74.1]]
 
-        Data is converted to Arrow before being written to disk. For maximum 
+        Data is converted to Arrow before being written to disk. For maximum
         control over how data is saved, either provide the PyArrow schema to
         convert to or else provide a PyArrow table directly.
 
diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py
index 0d37f6e4..03568101 100644
--- a/python/lancedb/embeddings.py
+++ b/python/lancedb/embeddings.py
@@ -33,7 +33,7 @@ def with_embeddings(
     """Add a vector column to a table using the given embedding function.
 
     The new columns will be called "vector".
-    
+
     Parameters
     ----------
     func : Callable
@@ -48,7 +48,7 @@
        Whether to show a progress bar.
     batch_size : int, default 1000
         The number of row values to pass to each call of the embedding function.
-    
+
     Returns
     -------
     pa.Table
diff --git a/python/lancedb/exceptions.py b/python/lancedb/exceptions.py
index 10def337..0dddb13d 100644
--- a/python/lancedb/exceptions.py
+++ b/python/lancedb/exceptions.py
@@ -1,16 +1,22 @@
 """Custom exception handling"""
 
+
 class MissingValueError(ValueError):
     """Exception raised when a required value is missing."""
+
     pass
 
+
 class MissingColumnError(KeyError):
     """
-    Exception raised when a column name specified is not in 
+    Exception raised when a column name specified is not in
     the DataFrame object
     """
+
     def __init__(self, column_name):
         self.column_name = column_name
 
     def __str__(self):
-        return f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
+        return (
+            f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
+        )
diff --git a/python/lancedb/fts.py b/python/lancedb/fts.py
index 8824ec1d..fe3895d7 100644
--- a/python/lancedb/fts.py
+++ b/python/lancedb/fts.py
@@ -68,7 +68,7 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
         The table to index
     fields : List[str]
         List of fields to index
-    
+
     Returns
     -------
     int
diff --git a/python/lancedb/query.py b/python/lancedb/query.py
index b5722779..2ad4b85b 100644
--- a/python/lancedb/query.py
+++ b/python/lancedb/query.py
@@ -120,7 +120,7 @@ class LanceQueryBuilder:
     def nprobes(self, nprobes: int) -> LanceQueryBuilder:
         """Set the number of probes to use.
 
-        Higher values will yield better recall (more likely to find vectors if 
+        Higher values will yield better recall (more likely to find vectors if
         they exist) at the expense of latency.
 
         See discussion in [Querying an ANN Index][../querying-an-ann-index] for
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index 07e212a4..af794470 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -99,7 +99,7 @@ class LanceTable:
     @property
     def schema(self) -> pa.Schema:
         """Return the schema of the table.
-    
+
         Returns
         -------
         pa.Schema
@@ -117,14 +117,14 @@
 
     def checkout(self, version: int):
         """Checkout a version of the table. This is an in-place operation.
-    
+
         This allows viewing previous versions of the table.
-    
+
         Parameters
         ----------
         version : int
             The version to checkout.
-    
+
         Examples
         --------
         >>> import lancedb
@@ -165,7 +165,7 @@
 
     def to_pandas(self) -> pd.DataFrame:
         """Return the table as a pandas DataFrame.
-    
+
         Returns
         -------
         pd.DataFrame
@@ -174,7 +174,7 @@
 
     def to_arrow(self) -> pa.Table:
         """Return the table as a pyarrow Table.
-    
+
         Returns
         -------
         pa.Table"""
@@ -342,4 +342,6 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
         values = values.cast(pa.float32())
     list_size = len(values) / len(data)
     vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
-    return data.set_column(data.column_names.index(vector_column_name), vector_column_name, vec_arr)
+    return data.set_column(
+        data.column_names.index(vector_column_name), vector_column_name, vec_arr
+    )
diff --git a/python/tests/test_db.py b/python/tests/test_db.py
index 1ec39dd9..f0fc5bb4 100644
--- a/python/tests/test_db.py
+++ b/python/tests/test_db.py
@@ -119,4 +119,4 @@ def test_delete_table(tmp_path):
     assert db.table_names() == []
 
     db.create_table("test", data=data)
-    assert db.table_names() == ["test"]
\ No newline at end of file
+    assert db.table_names() == ["test"]
diff --git a/python/tests/test_io.py b/python/tests/test_io.py
index f05d8154..d6b092bb 100644
--- a/python/tests/test_io.py
+++ b/python/tests/test_io.py
@@ -19,6 +19,7 @@ import lancedb
 # You need to setup AWS credentials an a base path to run this test. Example
 # AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
 
+
 @pytest.mark.skipif(
     (os.environ.get("TEST_S3_BASE_URL") is None),
     reason="please setup s3 base url",
diff --git a/python/tests/test_query.py b/python/tests/test_query.py
index d069d85c..675c3772 100644
--- a/python/tests/test_query.py
+++ b/python/tests/test_query.py
@@ -30,12 +30,16 @@ class MockTable:
 
 @pytest.fixture
 def table(tmp_path) -> MockTable:
-    df = pa.table({
-        "vector": pa.array([[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)),
-        "id": pa.array([1, 2]),
-        "str_field": pa.array(["a", "b"]),
-        "float_field": pa.array([1.0, 2.0]),
-    })
+    df = pa.table(
+        {
+            "vector": pa.array(
+                [[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
+            ),
+            "id": pa.array([1, 2]),
+            "str_field": pa.array(["a", "b"]),
+            "float_field": pa.array([1.0, 2.0]),
+        }
+    )
     lance.write_dataset(df, tmp_path)
     return MockTable(tmp_path)
 