diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3af5e096..d12f6bff 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -35,6 +35,8 @@ jobs: pip install pytest pytest-mock - name: Run tests run: pytest -x -v --durations=30 tests + - name: doctest + run: pytest --doctest-modules lancedb mac: timeout-minutes: 30 runs-on: "macos-12" diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index d849de4a..71f8e608 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -14,10 +14,24 @@ theme: plugins: - search +- autorefs - mkdocstrings: handlers: python: paths: [../python] + selection: + docstring_style: numpy + rendering: + heading_level: 4 + show_source: false + show_symbol_type_in_heading: true + show_signature_annotations: true + show_root_heading: true + members_order: source + import: + # for cross references + - https://arrow.apache.org/docs/objects.inv + - https://pandas.pydata.org/docs/objects.inv - mkdocs-jupyter markdown_extensions: diff --git a/docs/src/index.md b/docs/src/index.md index a025e96e..384dada3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -68,4 +68,5 @@ LanceDB's core is written in Rust 🦀 and is built using >> db = lancedb.connect("s3://my-bucket/lancedb") + Returns ------- - A connection to a LanceDB database. + conn : LanceDBConnection + A connection to a LanceDB database. 
""" return LanceDBConnection(uri) diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py new file mode 100644 index 00000000..50c42cfc --- /dev/null +++ b/python/lancedb/conftest.py @@ -0,0 +1,20 @@ +import builtins +import os + +import pytest + +# import lancedb so we don't have to in every example +import lancedb + +@pytest.fixture(autouse=True) +def doctest_setup(monkeypatch, tmpdir): + # disable color for doctests so we don't have to include + # escape codes in docstrings + monkeypatch.setitem(os.environ, "NO_COLOR", "1") + # Explicitly set the column width + monkeypatch.setitem(os.environ, "COLUMNS", "80") + # Work in a temporary directory + monkeypatch.chdir(tmpdir) + + + diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 25090195..83d10771 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -17,12 +17,74 @@ import pandas as pd def contextualize(raw_df: pd.DataFrame) -> Contextualizer: """Create a Contextualizer object for the given DataFrame. - Used to create context windows. + + Used to create context windows. Context windows are rolling subsets of text + data. + + The input text column should already be separated into rows that will be the + unit of the window. So to create a context window over tokens, start with + a DataFrame with one token per row. To create a context window over sentences, + start with a DataFrame with one sentence per row. + + Examples + -------- + >>> from lancedb.context import contextualize + >>> import pandas as pd + >>> data = pd.DataFrame({ + ... 'token': ['The', 'quick', 'brown', 'fox', 'jumped', 'over', + ... 'the', 'lazy', 'dog', 'I', 'love', 'sandwiches'], + ... 'document_id': [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2] + ... }) + + ``window`` determines how many rows to include in each window. In our case + this is how many tokens, but depending on the input data, it could be sentences, + paragraphs, messages, etc. 
+ + >>> contextualize(data).window(3).stride(1).text_col('token').to_df() + token document_id + 0 The quick brown 1 + 1 quick brown fox 1 + 2 brown fox jumped 1 + 3 fox jumped over 1 + 4 jumped over the 1 + 5 over the lazy 1 + 6 the lazy dog 1 + 7 lazy dog I 1 + 8 dog I love 1 + >>> contextualize(data).window(7).stride(1).text_col('token').to_df() + token document_id + 0 The quick brown fox jumped over the 1 + 1 quick brown fox jumped over the lazy 1 + 2 brown fox jumped over the lazy dog 1 + 3 fox jumped over the lazy dog I 1 + 4 jumped over the lazy dog I love 1 + + + ``stride`` determines how many rows to skip between each window start. This can + be used to reduce the total number of windows generated. + + >>> contextualize(data).window(4).stride(2).text_col('token').to_df() + token document_id + 0 The quick brown fox 1 + 2 brown fox jumped over 1 + 4 jumped over the lazy 1 + 6 the lazy dog I 1 + + ``groupby`` determines how to group the rows. For example, we would like to have + context windows that don't cross document boundaries. In this case, we can + pass ``document_id`` as the group by. + + >>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_df() + token document_id + 0 The quick brown fox 1 + 2 brown fox jumped over 1 + 4 jumped over the lazy 1 """ return Contextualizer(raw_df) class Contextualizer: + """Create context windows from a DataFrame. See [lancedb.context.contextualize][].""" def __init__(self, raw_df): self._text_col = None self._groupby = None diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 4f380e28..18c1e9d4 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -28,6 +28,31 @@ from .util import get_uri_scheme, get_uri_location class LanceDBConnection: """ A connection to a LanceDB database. + + Parameters + ---------- + uri: str or Path + The root uri of the database. 
+ + Examples + -------- + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}, + ... {"vector": [0.5, 1.3], "b": 4}]) + LanceTable(my_table) + >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}]) + LanceTable(another_table) + >>> db.table_names() + ['another_table', 'my_table'] + >>> len(db) + 2 + >>> db["my_table"] + LanceTable(my_table) + >>> "my_table" in db + True + >>> db.drop_table("my_table") + >>> db.drop_table("another_table") """ def __init__(self, uri: URI): @@ -48,7 +73,8 @@ class LanceDBConnection: Returns ------- - A list of table names. + list of str + A list of table names. """ try: filesystem, path = fs.FileSystem.from_uri(self.uri) @@ -103,7 +129,73 @@ class LanceDBConnection: Returns ------- - A LanceTable object representing the table. + LanceTable + A reference to the newly created table. + + Examples + -------- + + Can create with list of tuples or dictionaries: + + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, + ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] + >>> db.create_table("my_table", data) + LanceTable(my_table) + >>> db["my_table"].head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + lat: double + long: double + ---- + vector: [[[1.1,1.2],[0.2,1.8]]] + lat: [[45.5,40.1]] + long: [[-122.7,-74.1]] + + You can also pass a pandas DataFrame: + + >>> import pandas as pd + >>> data = pd.DataFrame({ + ... "vector": [[1.1, 1.2], [0.2, 1.8]], + ... "lat": [45.5, 40.1], + ... "long": [-122.7, -74.1] + ... }) + >>> db.create_table("table2", data) + LanceTable(table2) + >>> db["table2"].head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + lat: double + long: double + ---- + vector: [[[1.1,1.2],[0.2,1.8]]] + lat: [[45.5,40.1]] + long: [[-122.7,-74.1]] + + Data is converted to Arrow before being written to disk. 
For maximum + control over how data is saved, either provide the PyArrow schema to + convert to or else provide a PyArrow table directly. + + >>> custom_schema = pa.schema([ + ... pa.field("vector", pa.list_(pa.float32(), 2)), + ... pa.field("lat", pa.float32()), + ... pa.field("long", pa.float32()) + ... ]) + >>> db.create_table("table3", data, schema = custom_schema) + LanceTable(table3) + >>> db["table3"].head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + lat: float + long: float + ---- + vector: [[[1.1,1.2],[0.2,1.8]]] + lat: [[45.5,40.1]] + long: [[-122.7,-74.1]] """ if data is not None: tbl = LanceTable.create(self, name, data, schema, mode=mode) diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py index 28d63286..0d37f6e4 100644 --- a/python/lancedb/embeddings.py +++ b/python/lancedb/embeddings.py @@ -29,7 +29,31 @@ def with_embeddings( wrap_api: bool = True, show_progress: bool = False, batch_size: int = 1000, -): +) -> pa.Table: + """Add a vector column to a table using the given embedding function. + + The new column will be called "vector". + + Parameters + ---------- + func : Callable + A function that takes a list of strings and returns a list of vectors. + data : pa.Table or pd.DataFrame + The data to add an embedding column to. + column : str, default "text" + The name of the column to use as input to the embedding function. + wrap_api : bool, default True + Whether to wrap the embedding function in a retry and rate limiter. + show_progress : bool, default False + Whether to show a progress bar. + batch_size : int, default 1000 + The number of row values to pass to each call of the embedding function. + + Returns + ------- + pa.Table + The input table with a new column called "vector" containing the embeddings. 
+ """ func = EmbeddingFunction(func) if wrap_api: func = func.retry().rate_limit() diff --git a/python/lancedb/fts.py b/python/lancedb/fts.py index 259d9d80..8824ec1d 100644 --- a/python/lancedb/fts.py +++ b/python/lancedb/fts.py @@ -68,6 +68,11 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) - The table to index fields : List[str] List of fields to index + + Returns + ------- + int + The number of rows indexed """ # first check the fields exist and are string or large string type for name in fields: diff --git a/python/lancedb/query.py b/python/lancedb/query.py index defe744f..b5722779 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations +from typing import Literal import numpy as np import pandas as pd @@ -22,6 +23,24 @@ from .common import VECTOR_COLUMN_NAME class LanceQueryBuilder: """ A builder for nearest neighbor queries for LanceDB. + + Examples + -------- + >>> import lancedb + >>> data = [{"vector": [1.1, 1.2], "b": 2}, + ... {"vector": [0.5, 1.3], "b": 4}, + ... {"vector": [0.4, 0.4], "b": 6}, + ... {"vector": [0.4, 0.4], "b": 10}] + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data=data) + >>> (table.search([0.4, 0.4]) + ... .metric("cosine") + ... .where("b < 10") + ... .select(["b"]) + ... .limit(2) + ... .to_df()) + b vector score + 0 6 [0.4, 0.4] 0.0 """ def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): @@ -44,7 +63,8 @@ class LanceQueryBuilder: Returns ------- - The LanceQueryBuilder object. + LanceQueryBuilder + The LanceQueryBuilder object. """ self._limit = limit return self @@ -59,7 +79,8 @@ class LanceQueryBuilder: Returns ------- - The LanceQueryBuilder object. + LanceQueryBuilder + The LanceQueryBuilder object. 
""" self._columns = columns return self @@ -74,22 +95,24 @@ class LanceQueryBuilder: Returns ------- - The LanceQueryBuilder object. + LanceQueryBuilder + The LanceQueryBuilder object. """ self._where = where return self - def metric(self, metric: str) -> LanceQueryBuilder: + def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder: """Set the distance metric to use. Parameters ---------- - metric: str - The distance metric to use. By default "l2" is used. + metric: "L2" or "cosine" + The distance metric to use. By default "L2" is used. Returns ------- - The LanceQueryBuilder object. + LanceQueryBuilder + The LanceQueryBuilder object. """ self._metric = metric return self @@ -97,6 +120,12 @@ class LanceQueryBuilder: def nprobes(self, nprobes: int) -> LanceQueryBuilder: """Set the number of probes to use. + Higher values will yield better recall (more likely to find vectors if + they exist) at the expense of latency. + + See discussion in [Querying an ANN Index][querying-an-ann-index] for + tuning advice. + Parameters ---------- nprobes: int @@ -104,13 +133,20 @@ class LanceQueryBuilder: Returns ------- - The LanceQueryBuilder object. + LanceQueryBuilder + The LanceQueryBuilder object. """ self._nprobes = nprobes return self def refine_factor(self, refine_factor: int) -> LanceQueryBuilder: - """Set the refine factor to use. + """Set the refine factor to use, increasing the number of vectors sampled. + + As an example, a refine factor of 2 will sample 2x as many vectors as + requested, re-rank them, and return the top half most relevant results. + + See discussion in [Querying an ANN Index][querying-an-ann-index] for + tuning advice. Parameters ---------- @@ -119,7 +155,8 @@ class LanceQueryBuilder: Returns ------- - The LanceQueryBuilder object. + LanceQueryBuilder + The LanceQueryBuilder object. 
""" self._refine_factor = refine_factor return self diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 57375a9d..07e212a4 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -47,6 +47,40 @@ def _sanitize_data(data, schema): class LanceTable: """ A table in a LanceDB database. + + Examples + -------- + + Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table] + (more examples in that method's documentation). + + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}]) + >>> table.head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + b: int64 + ---- + vector: [[[1.1,1.2]]] + b: [[2]] + + Can append new data with [LanceTable.add][lancedb.table.LanceTable.add]. + + >>> table.add([{"vector": [0.5, 1.3], "b": 4}]) + 2 + + Can query the table with [LanceTable.search][lancedb.table.LanceTable.search]. + + >>> table.search([0.4, 0.4]).select(["b"]).to_df() + b vector score + 0 4 [0.5, 1.3] 0.82 + 1 2 [1.1, 1.2] 1.13 + + Search queries are much faster when an index is created. See + [LanceTable.create_index][lancedb.table.LanceTable.create_index]. + """ def __init__( @@ -64,7 +98,12 @@ class LanceTable: @property def schema(self) -> pa.Schema: - """Return the schema of the table.""" + """Return the schema of the table. + + Returns + ------- + pa.Schema + A PyArrow schema object.""" return self._dataset.schema def list_versions(self): @@ -72,12 +111,39 @@ class LanceTable: return self._dataset.versions() @property - def version(self): + def version(self) -> int: """Get the current version of the table""" return self._dataset.version def checkout(self, version: int): - """Checkout a version of the table""" + """Checkout a version of the table. This is an in-place operation. + + This allows viewing previous versions of the table. + + Parameters + ---------- + version : int + The version to checkout. 
+ + Examples + -------- + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}]) + >>> table.version + 1 + >>> table.to_pandas() + vector type + 0 [1.1, 0.9] vector + >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) + 2 + >>> table.version + 2 + >>> table.checkout(1) + >>> table.to_pandas() + vector type + 0 [1.1, 0.9] vector + """ max_ver = max([v["version"] for v in self._dataset.versions()]) if version < 1 or version > max_ver: raise ValueError(f"Invalid version {version}") @@ -98,11 +164,20 @@ class LanceTable: return self._dataset.head(n) def to_pandas(self) -> pd.DataFrame: - """Return the table as a pandas DataFrame.""" + """Return the table as a pandas DataFrame. + + Returns + ------- + pd.DataFrame + """ return self.to_arrow().to_pandas() def to_arrow(self) -> pa.Table: - """Return the table as a pyarrow Table.""" + """Return the table as a pyarrow Table. + + Returns + ------- + pa.Table""" return self._dataset.to_table() @property @@ -175,7 +250,8 @@ class LanceTable: Returns ------- - The number of vectors added to the table. + int + The number of vectors in the table. """ data = _sanitize_data(data, self.schema) lance.write_dataset(data, self._dataset_uri, mode=mode) @@ -193,10 +269,11 @@ class LanceTable: Returns ------- - A LanceQueryBuilder object representing the query. - Once executed, the query returns selected columns, the vector, - and also the "score" column which is the distance between the query - vector and the returned vector. + LanceQueryBuilder + A query builder object representing the query. + Once executed, the query returns selected columns, the vector, + and also the "score" column which is the distance between the query + vector and the returned vector. 
""" if isinstance(query, str): # fts diff --git a/python/tests/test_query.py b/python/tests/test_query.py index f7f0caa0..d069d85c 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -30,23 +30,13 @@ class MockTable: @pytest.fixture def table(tmp_path) -> MockTable: - df = pd.DataFrame( - { - "vector": [[1, 2], [3, 4]], - "id": [1, 2], - "str_field": ["a", "b"], - "float_field": [1.0, 2.0], - } - ) - schema = pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), list_size=2)), - pa.field("id", pa.int32()), - pa.field("str_field", pa.string()), - pa.field("float_field", pa.float64()), - ] - ) - lance.write_dataset(df, tmp_path, schema) + df = pa.table({ + "vector": pa.array([[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)), + "id": pa.array([1, 2]), + "str_field": pa.array(["a", "b"]), + "float_field": pa.array([1.0, 2.0]), + }) + lance.write_dataset(df, tmp_path) return MockTable(tmp_path) @@ -65,7 +55,7 @@ def test_query_builder_with_filter(table): def test_query_builder_with_metric(table): query = [4, 8] df_default = LanceQueryBuilder(table, query).to_df() - df_l2 = LanceQueryBuilder(table, query).metric("l2").to_df() + df_l2 = LanceQueryBuilder(table, query).metric("L2").to_df() tm.assert_frame_equal(df_default, df_l2) df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df()