diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 83d10771..dee4180c 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -13,7 +13,7 @@ from __future__ import annotations import pandas as pd - +from .exceptions import MissingValueError, MissingColumnError def contextualize(raw_df: pd.DataFrame) -> Contextualizer: """Create a Contextualizer object for the given DataFrame. @@ -140,6 +140,17 @@ class Contextualizer: def to_df(self) -> pd.DataFrame: """Create the context windows and return a DataFrame.""" + if self._text_col not in self._raw_df.columns.tolist(): + raise MissingColumnError(self._text_col) + + if self._window is None or self._window < 1: + raise MissingValueError("The value of window is None or less than 1. Specify the " + "window size (number of rows to include in each window)") + + if self._stride is None or self._stride < 1: + raise MissingValueError("The value of stride is None or less than 1. Specify the " + "stride (number of rows to skip between each window)") + def process_group(grp): # For each group, create the text rolling window text = grp[self._text_col].values diff --git a/python/lancedb/exceptions.py b/python/lancedb/exceptions.py new file mode 100644 index 00000000..10def337 --- /dev/null +++ b/python/lancedb/exceptions.py @@ -0,0 +1,16 @@ +"""Custom exception handling""" + +class MissingValueError(ValueError): + """Exception raised when a required value is missing.""" + pass + +class MissingColumnError(KeyError): + """ + Exception raised when a column name specified is not in + the DataFrame object + """ + def __init__(self, column_name): + self.column_name = column_name + + def __str__(self): + return f"Error: Column '{self.column_name}' does not exist in the DataFrame object"