From 9b8472850e55ecf411e46b7d0e82af0c4b6de592 Mon Sep 17 00:00:00 2001 From: Sayandip Dutta Date: Sat, 14 Sep 2024 01:02:59 +0530 Subject: [PATCH] fix: unterminated string literal on table update (#1573) resolves #1429 (python) ```python - return f"'{value}'" + return f'"{value}"' ``` --------- Co-authored-by: Will Jones --- python/python/lancedb/util.py | 1 + python/python/tests/test_util.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index b1371a8a..6392e40e 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -219,6 +219,7 @@ def value_to_sql(value): @value_to_sql.register(str) def _(value: str): + value = value.replace("'", "''") return f"'{value}'" diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index fa7e75f0..2681505f 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -15,7 +15,8 @@ import os import pathlib import pytest -from lancedb.util import get_uri_scheme, join_uri +import lancedb +from lancedb.util import get_uri_scheme, join_uri, value_to_sql def test_normalize_uri(): @@ -84,3 +85,29 @@ def test_local_join_uri_windows(): assert joined == str(pathlib.Path(base) / "table.lance") joined = join_uri(pathlib.Path(base), "table.lance") assert joined == pathlib.Path(base) / "table.lance" + + +def test_value_to_sql_string(tmp_path): + # Make sure we can convert Python string literals to SQL strings, even if + # they contain characters meaningful in SQL, such as ' and \. + values = ["anthony's", 'a "test" string', "anthony's \"favorite color\" wasn't red"] + expected_values = [ + "'anthony''s'", + "'a \"test\" string'", + "'anthony''s \"favorite color\" wasn''t red'", + ] + + for value, expected in zip(values, expected_values): + assert value_to_sql(value) == expected + + # Also test we can roundtrip those strings through update. + # This validates the query parser understands the strings we + # are creating. + db = lancedb.connect(tmp_path) + table = db.create_table( + "test", + [{"search": value, "replace": "something"} for value in values], + ) + for value in values: + table.update(where=f"search = {value_to_sql(value)}", values={"replace": value}) + assert table.to_pandas().query("search == @value")["replace"].item() == value