diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 04967e21..6f450192 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -33,11 +33,11 @@ jobs: python-version: "3.11" - name: Install ruff run: | - pip install ruff==0.2.2 + pip install ruff==0.5.4 - name: Format check run: ruff format --check . - name: Lint - run: ruff . + run: ruff check . doctest: name: "Doctest" timeout-minutes: 30 diff --git a/docs/src/python/saas-python.md b/docs/src/python/saas-python.md index cc4ffa00..e2ea3047 100644 --- a/docs/src/python/saas-python.md +++ b/docs/src/python/saas-python.md @@ -1,6 +1,6 @@ # Python API Reference (SaaS) -This section contains the API reference for the SaaS Python API. +This section contains the API reference for the LanceDB Cloud Python API. ## Installation diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 9aaf2857..2d72acad 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -163,19 +163,19 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType: TypeError If the type is not supported. """ - if py_type == int: + if py_type is int: return pa.int64() - elif py_type == float: + elif py_type is float: return pa.float64() - elif py_type == str: + elif py_type is str: return pa.utf8() - elif py_type == bool: + elif py_type is bool: return pa.bool_() - elif py_type == bytes: + elif py_type is bytes: return pa.binary() - elif py_type == date: + elif py_type is date: return pa.date32() - elif py_type == datetime: + elif py_type is datetime: tz = get_extras(field, "tz") return pa.timestamp("us", tz=tz) elif getattr(py_type, "__origin__", None) in (list, tuple): @@ -210,17 +210,17 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: ): origin = field.annotation.__origin__ args = field.annotation.__args__ - if origin == list: + if origin is list: child = args[0] return pa.list_(_py_type_to_arrow_type(child, field)) elif origin == Union: - if len(args) == 2 and args[1] == type(None): + if len(args) == 2 and args[1] is type(None): return _py_type_to_arrow_type(args[0], field) elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): args = field.annotation.__args__ if len(args) == 2: for typ in args: - if typ == type(None): + if typ is type(None): continue return _py_type_to_arrow_type(typ, field) elif inspect.isclass(field.annotation): @@ -239,12 +239,12 @@ def is_nullable(field: FieldInfo) -> bool: origin = field.annotation.__origin__ args = field.annotation.__args__ if origin == Union: - if len(args) == 2 and args[1] == type(None): + if len(args) == 2 and args[1] is type(None): return True elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): args = field.annotation.__args__ for typ in args: - if typ == type(None): + if typ is type(None): return True return False diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 85b7e789..68fb0ed4 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -330,6 +330,14 @@ class RemoteTable(Table): result = self._conn._client.query(self._name, query) return result.to_arrow().to_reader() + def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: + """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] + that can be used to create a "merge insert" operation. + + See [`Table.merge_insert`][lancedb.table.Table.merge_insert] for more details. + """ + super().merge_insert(on) + def _do_merge( self, merge: LanceMergeInsertBuilder, @@ -354,9 +362,9 @@ class RemoteTable(Table): params["on"] = merge._on[0] params["when_matched_update_all"] = str(merge._when_matched_update_all).lower() if merge._when_matched_update_all_condition is not None: - params[ - "when_matched_update_all_filt" - ] = merge._when_matched_update_all_condition + params["when_matched_update_all_filt"] = ( + merge._when_matched_update_all_condition + ) params["when_not_matched_insert_all"] = str( merge._when_not_matched_insert_all ).lower() @@ -364,9 +372,9 @@ class RemoteTable(Table): merge._when_not_matched_by_source_delete ).lower() if merge._when_not_matched_by_source_condition is not None: - params[ - "when_not_matched_by_source_delete_filt" - ] = merge._when_not_matched_by_source_condition + params["when_not_matched_by_source_delete_filt"] = ( + merge._when_not_matched_by_source_condition + ) self._conn._client.post( f"/v1/table/{self._name}/merge_insert/",