mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
@@ -15,13 +15,7 @@ from langchain.llms import OpenAI
|
||||
from langchain.chains import RetrievalQA
|
||||
|
||||
lancedb_image = Image.debian_slim().pip_install(
|
||||
"lancedb",
|
||||
"langchain",
|
||||
"openai",
|
||||
"pandas",
|
||||
"tiktoken",
|
||||
"unstructured",
|
||||
"tabulate"
|
||||
"lancedb", "langchain", "openai", "pandas", "tiktoken", "unstructured", "tabulate"
|
||||
)
|
||||
|
||||
stub = Stub(
|
||||
@@ -34,21 +28,26 @@ docsearch = None
|
||||
docs_path = Path("docs.pkl")
|
||||
db_path = Path("lancedb")
|
||||
|
||||
|
||||
def get_document_title(document):
|
||||
m = str(document.metadata["source"])
|
||||
title = re.findall("pandas.documentation(.*).html", m)
|
||||
if title[0] is not None:
|
||||
return(title[0])
|
||||
return ''
|
||||
return title[0]
|
||||
return ""
|
||||
|
||||
|
||||
def download_docs():
|
||||
pandas_docs = requests.get("https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip")
|
||||
pandas_docs = requests.get(
|
||||
"https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
|
||||
)
|
||||
with open(Path("pandas.documentation.zip"), "wb") as f:
|
||||
f.write(pandas_docs.content)
|
||||
|
||||
file = zipfile.ZipFile(Path("pandas.documentation.zip"))
|
||||
file.extractall(path=Path("pandas_docs"))
|
||||
|
||||
|
||||
def store_docs():
|
||||
docs = []
|
||||
|
||||
@@ -74,6 +73,7 @@ def store_docs():
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def qanda_langchain(query):
|
||||
download_docs()
|
||||
docs = store_docs()
|
||||
@@ -85,14 +85,25 @@ def qanda_langchain(query):
|
||||
documents = text_splitter.split_documents(docs)
|
||||
embeddings = OpenAIEmbeddings()
|
||||
|
||||
db = lancedb.connect(db_path)
|
||||
table = db.create_table("pandas_docs", data=[
|
||||
{"vector": embeddings.embed_query("Hello World"), "text": "Hello World", "id": "1"}
|
||||
], mode="overwrite")
|
||||
db = lancedb.connect(db_path)
|
||||
table = db.create_table(
|
||||
"pandas_docs",
|
||||
data=[
|
||||
{
|
||||
"vector": embeddings.embed_query("Hello World"),
|
||||
"text": "Hello World",
|
||||
"id": "1",
|
||||
}
|
||||
],
|
||||
mode="overwrite",
|
||||
)
|
||||
docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
|
||||
qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever())
|
||||
qa = RetrievalQA.from_chain_type(
|
||||
llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever()
|
||||
)
|
||||
return qa.run(query)
|
||||
|
||||
|
||||
@stub.function()
|
||||
@web_endpoint(method="GET")
|
||||
def web(query: str):
|
||||
@@ -101,6 +112,7 @@ def web(query: str):
|
||||
"answer": answer,
|
||||
}
|
||||
|
||||
|
||||
@stub.function()
|
||||
def cli(query: str):
|
||||
answer = qanda_langchain(query)
|
||||
|
||||
Reference in New Issue
Block a user