add black to python CI (#178)

Closes #48
This commit is contained in:
Tevin Wang
2023-06-12 11:22:34 -07:00
committed by GitHub
parent 7bad676f30
commit 9b83ce3d2a
13 changed files with 86 additions and 51 deletions

View File

@@ -15,13 +15,7 @@ from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
lancedb_image = Image.debian_slim().pip_install(
"lancedb",
"langchain",
"openai",
"pandas",
"tiktoken",
"unstructured",
"tabulate"
"lancedb", "langchain", "openai", "pandas", "tiktoken", "unstructured", "tabulate"
)
stub = Stub(
@@ -34,21 +28,26 @@ docsearch = None
docs_path = Path("docs.pkl")
db_path = Path("lancedb")
def get_document_title(document):
m = str(document.metadata["source"])
title = re.findall("pandas.documentation(.*).html", m)
if title[0] is not None:
return(title[0])
return ''
return title[0]
return ""
def download_docs():
pandas_docs = requests.get("https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip")
pandas_docs = requests.get(
"https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
)
with open(Path("pandas.documentation.zip"), "wb") as f:
f.write(pandas_docs.content)
file = zipfile.ZipFile(Path("pandas.documentation.zip"))
file.extractall(path=Path("pandas_docs"))
def store_docs():
docs = []
@@ -74,6 +73,7 @@ def store_docs():
return docs
def qanda_langchain(query):
download_docs()
docs = store_docs()
@@ -85,14 +85,25 @@ def qanda_langchain(query):
documents = text_splitter.split_documents(docs)
embeddings = OpenAIEmbeddings()
db = lancedb.connect(db_path)
table = db.create_table("pandas_docs", data=[
{"vector": embeddings.embed_query("Hello World"), "text": "Hello World", "id": "1"}
], mode="overwrite")
db = lancedb.connect(db_path)
table = db.create_table(
"pandas_docs",
data=[
{
"vector": embeddings.embed_query("Hello World"),
"text": "Hello World",
"id": "1",
}
],
mode="overwrite",
)
docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever())
qa = RetrievalQA.from_chain_type(
llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever()
)
return qa.run(query)
@stub.function()
@web_endpoint(method="GET")
def web(query: str):
@@ -101,6 +112,7 @@ def web(query: str):
"answer": answer,
}
@stub.function()
def cli(query: str):
answer = qanda_langchain(query)