diff --git a/python/lancedb/context.py b/python/lancedb/context.py index f1d634d1..a90b8f70 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -42,34 +42,38 @@ def contextualize(raw_df: pd.DataFrame) -> Contextualizer: paragraphs, messages, etc. >>> contextualize(data).window(3).stride(1).text_col('token').to_df() - token document_id - 0 The quick brown 1 - 1 quick brown fox 1 - 2 brown fox jumped 1 - 3 fox jumped over 1 - 4 jumped over the 1 - 5 over the lazy 1 - 6 the lazy dog 1 - 7 lazy dog I 1 - 8 dog I love 1 - >>> contextualize(data).window(7).stride(1).text_col('token').to_df() + token document_id + 0 The quick brown 1 + 1 quick brown fox 1 + 2 brown fox jumped 1 + 3 fox jumped over 1 + 4 jumped over the 1 + 5 over the lazy 1 + 6 the lazy dog 1 + 7 lazy dog I 1 + 8 dog I love 1 + 9 I love sandwiches 2 + 10 love sandwiches 2 + >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_df() token document_id 0 The quick brown fox jumped over the 1 1 quick brown fox jumped over the lazy 1 2 brown fox jumped over the lazy dog 1 3 fox jumped over the lazy dog I 1 4 jumped over the lazy dog I love 1 - + 5 over the lazy dog I love sandwiches 1 ``stride`` determines how many rows to skip between each window start. This can be used to reduce the total number of windows generated. >>> contextualize(data).window(4).stride(2).text_col('token').to_df() - token document_id - 0 The quick brown fox 1 - 2 brown fox jumped over 1 - 4 jumped over the lazy 1 - 6 the lazy dog I 1 + token document_id + 0 The quick brown fox 1 + 2 brown fox jumped over 1 + 4 jumped over the lazy 1 + 6 the lazy dog I 1 + 8 dog I love sandwiches 1 + 10 love sandwiches 2 ``groupby`` determines how to group the rows. For example, we would like to have context windows that don't cross document boundaries. In this case, we can @@ -80,6 +84,25 @@ def contextualize(raw_df: pd.DataFrame) -> Contextualizer: 0 The quick brown fox 1 2 brown fox jumped over 1 4 jumped over the lazy 1 + 6 the lazy dog 1 + 9 I love sandwiches 2 + + ``min_window_size`` determines the minimum size of the context windows that are generated + This can be used to trim the last few context windows which have size less than + ``min_window_size``. By default context windows of size 1 are skipped. + + >>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_df() + token document_id + 0 The quick brown fox jumped over 1 + 3 fox jumped over the lazy dog 1 + 6 the lazy dog 1 + 9 I love sandwiches 2 + + >>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_df() + token document_id + 0 The quick brown fox jumped over 1 + 3 fox jumped over the lazy dog 1 + """ return Contextualizer(raw_df) @@ -92,6 +115,7 @@ class Contextualizer: self._groupby = None self._stride = None self._window = None + self._min_window_size = 2 self._raw_df = raw_df def window(self, window: int) -> Contextualizer: @@ -139,6 +163,17 @@ class Contextualizer: self._text_col = text_col return self + def min_window_size(self, min_window_size: int) -> Contextualizer: + """Set the (optional) min_window_size size for the context window. + + Parameters + ---------- + min_window_size: int + The min_window_size. + """ + self._min_window_size = min_window_size + return self + def to_df(self) -> pd.DataFrame: """Create the context windows and return a DataFrame.""" @@ -159,12 +194,19 @@ class Contextualizer: def process_group(grp): # For each group, create the text rolling window + # with values of size >= min_window_size text = grp[self._text_col].values - contexts = grp.iloc[: -self._window : self._stride, :].copy() - contexts[self._text_col] = [ - " ".join(text[start_i : start_i + self._window]) - for start_i in range(0, len(grp) - self._window, self._stride) + contexts = grp.iloc[:: self._stride, :].copy() + windows = [ + " ".join(text[start_i : min(start_i + self._window, len(grp))]) + for start_i in range(0, len(grp), self._stride) + if start_i + self._window <= len(grp) + or len(grp) - start_i >= self._min_window_size ] + # if last few rows dropped + if len(windows) < len(contexts): + contexts = contexts.iloc[: len(windows)] + contexts[self._text_col] = windows return contexts if self._groupby is None: diff --git a/python/tests/test_context.py b/python/tests/test_context.py new file mode 100644 index 00000000..12ba4116 --- /dev/null +++ b/python/tests/test_context.py @@ -0,0 +1,77 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import pytest + +from lancedb.context import contextualize + + +@pytest.fixture +def raw_df() -> pd.DataFrame: + return pd.DataFrame( + { + "token": [ + "The", + "quick", + "brown", + "fox", + "jumped", + "over", + "the", + "lazy", + "dog", + "I", + "love", + "sandwiches", + ], + "document_id": [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2], + } + ) + + +def test_contextualizer(raw_df: pd.DataFrame): + result = ( + contextualize(raw_df) + .window(6) + .stride(3) + .text_col("token") + .groupby("document_id") + .to_df()["token"] + .to_list() + ) + + assert result == [ + "The quick brown fox jumped over", + "fox jumped over the lazy dog", + "the lazy dog", + "I love sandwiches", + ] + + +def test_contextualizer_with_threshold(raw_df: pd.DataFrame): + result = ( + contextualize(raw_df) + .window(6) + .stride(3) + .text_col("token") + .groupby("document_id") + .min_window_size(4) + .to_df()["token"] + .to_list() + ) + + assert result == [ + "The quick brown fox jumped over", + "fox jumped over the lazy dog", + ]