add black to python CI (#178)

Closes #48
Author: Tevin Wang
Date: 2023-06-12 11:22:34 -07:00
Committed by: GitHub
Parent: 7bad676f30
Commit: 9b83ce3d2a
13 changed files with 86 additions and 51 deletions


@@ -32,7 +32,9 @@ jobs:
       run: |
         pip install -e .
         pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
-        pip install pytest pytest-mock
+        pip install pytest pytest-mock black
+    - name: Black
+      run: black --check --diff --no-color --quiet .
    - name: Run tests
      run: pytest -x -v --durations=30 tests
    - name: doctest
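
The --check --diff flags make the step fail when any file would be reformatted, printing the would-be changes without modifying anything. Black also exposes a Python API (documented as unstable) that shows the same normalization; a minimal sketch, for illustration only:

# Minimal sketch of what the CI step enforces, via Black's (unstable) Python API.
# format_str returns the canonical form; --check fails when output != input.
import black

src = 'deps = ["lancedb","langchain",   "openai"]'
print(black.format_str(src, mode=black.Mode()))
# deps = ["lancedb", "langchain", "openai"]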


@@ -15,13 +15,7 @@ from langchain.llms import OpenAI
 from langchain.chains import RetrievalQA
 
 lancedb_image = Image.debian_slim().pip_install(
-    "lancedb",
-    "langchain",
-    "openai",
-    "pandas",
-    "tiktoken",
-    "unstructured",
-    "tabulate"
+    "lancedb", "langchain", "openai", "pandas", "tiktoken", "unstructured", "tabulate"
 )
 
 stub = Stub(
@@ -34,21 +28,26 @@ docsearch = None
 docs_path = Path("docs.pkl")
 db_path = Path("lancedb")
 
+
 def get_document_title(document):
     m = str(document.metadata["source"])
     title = re.findall("pandas.documentation(.*).html", m)
     if title[0] is not None:
-        return(title[0])
-    return ''
+        return title[0]
+    return ""
+
 
 def download_docs():
-    pandas_docs = requests.get("https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip")
+    pandas_docs = requests.get(
+        "https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
+    )
     with open(Path("pandas.documentation.zip"), "wb") as f:
         f.write(pandas_docs.content)
     file = zipfile.ZipFile(Path("pandas.documentation.zip"))
     file.extractall(path=Path("pandas_docs"))
+
 
 def store_docs():
     docs = []
@@ -74,6 +73,7 @@ def store_docs():
     return docs
 
+
 def qanda_langchain(query):
     download_docs()
     docs = store_docs()
@@ -85,14 +85,25 @@ def qanda_langchain(query):
     documents = text_splitter.split_documents(docs)
     embeddings = OpenAIEmbeddings()
 
     db = lancedb.connect(db_path)
-    table = db.create_table("pandas_docs", data=[
-        {"vector": embeddings.embed_query("Hello World"), "text": "Hello World", "id": "1"}
-    ], mode="overwrite")
+    table = db.create_table(
+        "pandas_docs",
+        data=[
+            {
+                "vector": embeddings.embed_query("Hello World"),
+                "text": "Hello World",
+                "id": "1",
+            }
+        ],
+        mode="overwrite",
+    )
     docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
-    qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever())
+    qa = RetrievalQA.from_chain_type(
+        llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever()
+    )
     return qa.run(query)
 
 
 @stub.function()
 @web_endpoint(method="GET")
 def web(query: str):
@@ -101,6 +112,7 @@ def web(query: str):
         "answer": answer,
     }
 
+
 @stub.function()
 def cli(query: str):
     answer = qanda_langchain(query)


@@ -6,6 +6,7 @@ import pytest
 # import lancedb so we don't have to in every example
 import lancedb
 
+
 @pytest.fixture(autouse=True)
 def doctest_setup(monkeypatch, tmpdir):
     # disable color for doctests so we don't have to include
@@ -15,6 +16,3 @@ def doctest_setup(monkeypatch, tmpdir):
     monkeypatch.setitem(os.environ, "COLUMNS", "80")
     # Work in a temporary directory
     monkeypatch.chdir(tmpdir)
-
-
-


@@ -15,6 +15,7 @@ from __future__ import annotations
 import pandas as pd
 
 from .exceptions import MissingValueError, MissingColumnError
 
+
 def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
     """Create a Contextualizer object for the given DataFrame.
@@ -85,6 +86,7 @@ def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
 
+
 class Contextualizer:
     """Create context windows from a DataFrame. See [lancedb.context.contextualize][]."""
 
     def __init__(self, raw_df):
         self._text_col = None
         self._groupby = None
@@ -144,12 +146,16 @@ class Contextualizer:
             raise MissingColumnError(self._text_col)
 
         if self._window is None or self._window < 1:
-            raise MissingValueError("The value of window is None or less than 1. Specify the "
-                                    "window size (number of rows to include in each window)")
+            raise MissingValueError(
+                "The value of window is None or less than 1. Specify the "
+                "window size (number of rows to include in each window)"
+            )
         if self._stride is None or self._stride < 1:
-            raise MissingValueError("The value of stride is None or less than 1. Specify the "
-                                    "stride (number of rows to skip between each window)")
+            raise MissingValueError(
+                "The value of stride is None or less than 1. Specify the "
+                "stride (number of rows to skip between each window)"
+            )
 
         def process_group(grp):
             # For each group, create the text rolling window
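
The checks above guard a builder-style API; a hedged usage sketch follows. The text_col/window/stride method names are inferred from the fields set in __init__ and are assumptions here:

# Hedged sketch: builder method names are inferred from the fields above
# (_text_col, _window, _stride) and may differ from the real API.
import pandas as pd
from lancedb.context import contextualize

df = pd.DataFrame({"token": ["the", "quick", "brown", "fox", "jumped"]})
# window=3 rows per context window, stride=2 rows between window starts
windows = contextualize(df).text_col("token").window(3).stride(2).to_df()
# Leaving window or stride unset (or < 1) raises MissingValueError,
# per the validation shown above.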


@@ -33,7 +33,7 @@ class LanceDBConnection:
     ----------
     uri: str or Path
         The root uri of the database.
-
+
     Examples
     --------
     >>> import lancedb
@@ -79,16 +79,20 @@ class LanceDBConnection:
         try:
             filesystem, path = fs.FileSystem.from_uri(self.uri)
         except pa.ArrowInvalid:
-            raise NotImplementedError(
-                "Unsupported scheme: " + self.uri
-            )
+            raise NotImplementedError("Unsupported scheme: " + self.uri)
         try:
-            paths = filesystem.get_file_info(fs.FileSelector(get_uri_location(self.uri)))
+            paths = filesystem.get_file_info(
+                fs.FileSelector(get_uri_location(self.uri))
+            )
         except FileNotFoundError:
             # It is ok if the file does not exist since it will be created
             paths = []
-        tables = [os.path.splitext(file_info.base_name)[0] for file_info in paths if file_info.extension == 'lance']
+        tables = [
+            os.path.splitext(file_info.base_name)[0]
+            for file_info in paths
+            if file_info.extension == "lance"
+        ]
         return tables
 
     def __len__(self) -> int:
@@ -153,7 +157,7 @@ class LanceDBConnection:
         vector: [[[1.1,1.2],[0.2,1.8]]]
         lat: [[45.5,40.1]]
         long: [[-122.7,-74.1]]
-
+
         You can also pass a pandas DataFrame:
 
         >>> import pandas as pd
@@ -175,7 +179,7 @@ class LanceDBConnection:
         lat: [[45.5,40.1]]
         long: [[-122.7,-74.1]]
 
-        Data is converted to Arrow before being written to disk. For maximum
+        Data is converted to Arrow before being written to disk. For maximum
         control over how data is saved, either provide the PyArrow schema to
         convert to or else provide a PyArrow table directly.
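
A hedged sketch of that maximum-control path: build the Arrow table, with an explicit fixed-size vector type, before handing it to create_table. The table name and dimensions below are illustrative:

# Hedged sketch: construct the Arrow table yourself for full control over
# the stored schema; names and dimensions here are illustrative only.
import lancedb
import pyarrow as pa

schema = pa.schema(
    [("vector", pa.list_(pa.float32(), list_size=2)), ("id", pa.int64())]
)
tbl = pa.table({"vector": [[1.1, 1.2], [0.2, 1.8]], "id": [1, 2]}, schema=schema)
db = lancedb.connect("./lancedb")
table = db.create_table("vectors", data=tbl)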


@@ -33,7 +33,7 @@ def with_embeddings(
     """Add a vector column to a table using the given embedding function.
 
     The new column will be called "vector".
-
+
     Parameters
     ----------
     func : Callable
@@ -48,7 +48,7 @@ def with_embeddings(
         Whether to show a progress bar.
     batch_size : int, default 1000
         The number of row values to pass to each call of the embedding function.
-
+
     Returns
     -------
     pa.Table
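
A hedged usage sketch of with_embeddings; the embed function below is a stand-in for a real embedding model, and the exact keyword names are assumptions:

# Hedged sketch: embed is a placeholder for a real embedding function
# (list of strings in, list of vectors out); keyword names are assumed.
import pandas as pd
from lancedb.embeddings import with_embeddings

def embed(batch):
    # toy embedding: one 2-d vector per input string
    return [[float(len(s)), 0.0] for s in batch]

df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
data = with_embeddings(embed, df, column="text", batch_size=1000)
# data is a pa.Table: the original columns plus a new "vector" column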


@@ -1,16 +1,22 @@
 """Custom exception handling"""
 
+
 class MissingValueError(ValueError):
     """Exception raised when a required value is missing."""
+
     pass
 
+
 class MissingColumnError(KeyError):
     """
-    Exception raised when a column name specified is not in 
+    Exception raised when a column name specified is not in
     the DataFrame object
     """
 
     def __init__(self, column_name):
         self.column_name = column_name
 
     def __str__(self):
-        return f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
+        return (
+            f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
+        )
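
For reference, the parenthesized return changes layout only; the message is unchanged:

# Same __str__ output before and after the reformat:
from lancedb.exceptions import MissingColumnError

err = MissingColumnError("text")
print(err)
# Error: Column 'text' does not exist in the DataFrame object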


@@ -68,7 +68,7 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
         The table to index
     fields : List[str]
         List of fields to index
-
+
     Returns
     -------
     int
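
populate_index is plumbing for full-text search over a LanceTable; a hedged sketch of how it is typically reached from the public API. The create_fts_index wrapper is an assumption about this version of lancedb:

# Hedged sketch: assumes a table-level create_fts_index wrapper around
# populate_index; the actual entry point may differ in this version.
import lancedb

db = lancedb.connect("./lancedb")
table = db.open_table("pandas_docs")
table.create_fts_index(["text"])  # builds the tantivy index over "text"
hits = table.search("rolling window").limit(5).to_df()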


@@ -120,7 +120,7 @@ class LanceQueryBuilder:
     def nprobes(self, nprobes: int) -> LanceQueryBuilder:
         """Set the number of probes to use.
 
-        Higher values will yield better recall (more likely to find vectors if
+        Higher values will yield better recall (more likely to find vectors if
         they exist) at the expense of latency.
 
         See discussion in [Querying an ANN Index][../querying-an-ann-index] for
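
A hedged sketch of the recall/latency trade-off on the query builder; the table name and vector values are illustrative:

# Hedged sketch: nprobes controls how many IVF partitions are scanned.
# Few probes = fastest but may miss neighbors; more probes = better recall.
import lancedb

db = lancedb.connect("./lancedb")
table = db.open_table("vectors")
fast = table.search([1.1, 1.2]).nprobes(1).limit(5).to_df()
thorough = table.search([1.1, 1.2]).nprobes(20).limit(5).to_df()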


@@ -99,7 +99,7 @@ class LanceTable:
     @property
     def schema(self) -> pa.Schema:
         """Return the schema of the table.
-
+
         Returns
         -------
         pa.Schema
@@ -117,14 +117,14 @@ class LanceTable:
     def checkout(self, version: int):
         """Checkout a version of the table. This is an in-place operation.
-
+
         This allows viewing previous versions of the table.
-
+
         Parameters
         ----------
         version : int
             The version to checkout.
-
+
         Examples
         --------
         >>> import lancedb
@@ -165,7 +165,7 @@ class LanceTable:
     def to_pandas(self) -> pd.DataFrame:
         """Return the table as a pandas DataFrame.
-
+
         Returns
         -------
         pd.DataFrame
@@ -174,7 +174,7 @@ class LanceTable:
     def to_arrow(self) -> pa.Table:
         """Return the table as a pyarrow Table.
-
+
         Returns
         -------
         pa.Table"""
@@ -342,4 +342,6 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
     values = values.cast(pa.float32())
     list_size = len(values) / len(data)
     vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
-    return data.set_column(data.column_names.index(vector_column_name), vector_column_name, vec_arr)
+    return data.set_column(
+        data.column_names.index(vector_column_name), vector_column_name, vec_arr
+    )
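
A small standalone illustration of the FixedSizeListArray step above: a flat float32 array of length N*dim becomes N fixed-size vectors.

# Standalone illustration of the conversion in _sanitize_vector_column:
import pyarrow as pa

values = pa.array([1.0, 2.0, 3.0, 4.0], type=pa.float32())
vec_arr = pa.FixedSizeListArray.from_arrays(values, 2)  # 4 floats -> 2 vectors of dim 2
print(vec_arr.to_pylist())
# [[1.0, 2.0], [3.0, 4.0]]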


@@ -119,4 +119,4 @@ def test_delete_table(tmp_path):
     assert db.table_names() == []
     db.create_table("test", data=data)
-    assert db.table_names() == ["test"]
\ No newline at end of file
+    assert db.table_names() == ["test"]


@@ -19,6 +19,7 @@ import lancedb
 # You need to set up AWS credentials and a base path to run this test. Example:
 # AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
 
+
 @pytest.mark.skipif(
     (os.environ.get("TEST_S3_BASE_URL") is None),
     reason="please setup s3 base url",


@@ -30,12 +30,16 @@ class MockTable:
 @pytest.fixture
 def table(tmp_path) -> MockTable:
-    df = pa.table({
-        "vector": pa.array([[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)),
-        "id": pa.array([1, 2]),
-        "str_field": pa.array(["a", "b"]),
-        "float_field": pa.array([1.0, 2.0]),
-    })
+    df = pa.table(
+        {
+            "vector": pa.array(
+                [[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
+            ),
+            "id": pa.array([1, 2]),
+            "str_field": pa.array(["a", "b"]),
+            "float_field": pa.array([1.0, 2.0]),
+        }
+    )
     lance.write_dataset(df, tmp_path)
     return MockTable(tmp_path)