mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
fix(python): uses PIL incorrectly and may raise AttributeError (#2954)
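Importing `PIL` alone does not guarantee that the `Image` submodule is loaded. In a clean environment where no other code has already imported `PIL.Image`, the `PIL` package has no `Image` attribute, so accessing `PIL.Image` raises an `AttributeError`. This change imports the `PIL.Image` submodule explicitly at each use site instead of relying on the top-level `PIL` import.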
This commit is contained in:
@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
Convert image inputs to PIL Images.
|
||||
"""
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
requests = attempt_import_or_raise("requests", "requests")
|
||||
images = self.sanitize_input(images)
|
||||
pil_images = []
|
||||
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
if image.startswith(("http://", "https://")):
|
||||
response = requests.get(image, timeout=10)
|
||||
response.raise_for_status()
|
||||
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
|
||||
pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
|
||||
else:
|
||||
with PIL.Image.open(image) as im:
|
||||
with PIL_Image.open(image) as im:
|
||||
pil_images.append(im.copy())
|
||||
elif isinstance(image, bytes):
|
||||
pil_images.append(PIL.Image.open(io.BytesIO(image)))
|
||||
pil_images.append(PIL_Image.open(io.BytesIO(image)))
|
||||
else:
|
||||
# Assume it's a PIL Image; will raise if invalid
|
||||
pil_images.append(image)
|
||||
|
||||
@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
if isinstance(inputs, list):
|
||||
inputs = inputs
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(inputs, PIL.Image.Image):
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(inputs, PIL_Image.Image):
|
||||
inputs = [inputs]
|
||||
return inputs
|
||||
|
||||
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
elif isinstance(image, (str, Path)):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if parsed.scheme == "file":
|
||||
pil_image = PIL.Image.open(parsed.path)
|
||||
pil_image = PIL_Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
buffered = io.BytesIO()
|
||||
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
image_bytes = buffered.getvalue()
|
||||
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
if isinstance(image, PIL_Image.Image):
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_bytes = buffered.getvalue()
|
||||
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
elif isinstance(query, (Path, bytes)):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
if isinstance(query, PIL_Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError(
|
||||
|
||||
@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(query, PIL_Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return PIL_Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL_Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
return PIL_Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
return PIL_Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
return PIL_Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
|
||||
@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(query, PIL_Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("SigLIP supports str or PIL Image as query")
|
||||
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
|
||||
return image_features.cpu().detach().numpy().squeeze()
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(image, PIL_Image.Image):
|
||||
return image.convert("RGB") if image.mode != "RGB" else image
|
||||
elif isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
|
||||
return PIL_Image.open(io.BytesIO(image)).convert("RGB")
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path).convert("RGB")
|
||||
return PIL_Image.open(parsed.path).convert("RGB")
|
||||
elif parsed.scheme == "":
|
||||
path = image if os.name == "nt" else parsed.path
|
||||
return PIL.Image.open(path).convert("RGB")
|
||||
return PIL_Image.open(path).convert("RGB")
|
||||
elif parsed.scheme.startswith("http"):
|
||||
image_bytes = url_retrieve(image)
|
||||
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||
return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
else:
|
||||
|
||||
@@ -64,7 +64,7 @@ def is_video_path(path: Path) -> bool:
|
||||
|
||||
|
||||
def transform_input(input_data: Union[str, bytes, Path]):
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(input_data, str):
|
||||
if is_valid_url(input_data):
|
||||
if is_video_url(input_data):
|
||||
@@ -73,7 +73,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
||||
content = {"type": "image_url", "image_url": input_data}
|
||||
else:
|
||||
content = {"type": "text", "text": input_data}
|
||||
elif isinstance(input_data, PIL.Image.Image):
|
||||
elif isinstance(input_data, PIL_Image.Image):
|
||||
buffered = BytesIO()
|
||||
input_data.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
@@ -82,7 +82,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
elif isinstance(input_data, bytes):
|
||||
img = PIL.Image.open(BytesIO(input_data))
|
||||
img = PIL_Image.open(BytesIO(input_data))
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
@@ -101,7 +101,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
||||
"video_base64": video_str,
|
||||
}
|
||||
else:
|
||||
img = PIL.Image.open(input_data)
|
||||
img = PIL_Image.open(input_data)
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
@@ -119,8 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
|
||||
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||
if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
|
||||
inputs = [inputs]
|
||||
elif isinstance(inputs, list):
|
||||
pass # Already a list, use as-is
|
||||
@@ -133,7 +133,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
|
||||
f"Input type {type(inputs)} not allowed with multimodal model."
|
||||
)
|
||||
|
||||
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
|
||||
if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
|
||||
raise ValueError("Each input should be either str, bytes, Path or Image.")
|
||||
|
||||
return [transform_input(i) for i in inputs]
|
||||
|
||||
Reference in New Issue
Block a user