From 826fe320bb823ccc7243b300f7685280708ccf86 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Thu, 23 Mar 2023 17:31:24 -0700 Subject: [PATCH] address PR comments --- python/lancedb/context.py | 49 +++++++++++++++++++++++++++++++++------ python/lancedb/db.py | 5 ++++ python/lancedb/table.py | 11 +++++++++ 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 3a2a4c2d..25090195 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -10,11 +10,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import pandas as pd -def contextualize(raw_df): +def contextualize(raw_df: pd.DataFrame) -> Contextualizer: + """Create a Contextualizer object for the given DataFrame. + Used to create context windows. + """ return Contextualizer(raw_df) @@ -26,25 +30,56 @@ class Contextualizer: self._window = None self._raw_df = raw_df - def window(self, window): + def window(self, window: int) -> Contextualizer: + """Set the window size. i.e., how many rows to include in each window. + + Parameters + ---------- + window: int + The window size. + """ self._window = window return self - def stride(self, stride): + def stride(self, stride: int) -> Contextualizer: + """Set the stride. i.e., how many rows to skip between each window. + + Parameters + ---------- + stride: int + The stride. + """ self._stride = stride return self - def groupby(self, groupby): + def groupby(self, groupby: str) -> Contextualizer: + """Set the groupby column. i.e., how to group the rows. + Windows don't cross groups + + Parameters + ---------- + groupby: str + The groupby column. + """ self._groupby = groupby return self - def text_col(self, text_col): + def text_col(self, text_col: str) -> Contextualizer: + """Set the text column used to make the context window. + + Parameters + ---------- + text_col: str + The text column. + """ self._text_col = text_col return self - def to_df(self): + def to_df(self) -> pd.DataFrame: + """Create the context windows and return a DataFrame.""" + def process_group(grp): - # For each video, create the text rolling window + # For each group, create the text rolling window text = grp[self._text_col].values contexts = grp.iloc[: -self._window : self._stride, :].copy() contexts[self._text_col] = [ diff --git a/python/lancedb/db.py b/python/lancedb/db.py index d869575b..cf408e21 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -68,6 +68,11 @@ class LanceDBConnection: schema: pyarrow.Schema; optional The schema of the table. + Note + ---- + The vector index won't be created by default. + To create the index, call the `create_index` method on the table. + Returns ------- A LanceTable object representing the table. diff --git a/python/lancedb/table.py b/python/lancedb/table.py index b5fd6a1f..e4eea8f6 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -60,6 +60,17 @@ class LanceTable: return os.path.join(self._conn.uri, f"{self.name}.lance") def create_index(self, num_partitions=256, num_sub_vectors=96): + """Create an index on the table. + + Parameters + ---------- + num_partitions: int + The number of IVF partitions to use when creating the index. + Default is 256. + num_sub_vectors: int + The number of PQ sub-vectors to use when creating the index. + Default is 96. + """ return self._dataset.create_index( column=VECTOR_COLUMN_NAME, index_type="IVF_PQ",