Skip to content

Added ppm support and fixed attribute types #191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions autodistill/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,36 @@ def load_image(
pil_image = Image.open(BytesIO(response.content))
return np.array(pil_image)
elif os.path.isfile(image):
if return_format == "PIL":
return Image.open(image)
elif return_format == "cv2":
# channels need to be reversed for cv2
return cv2.cvtColor(np.array(Image.open(image)), cv2.COLOR_RGB2BGR)
elif return_format == "numpy":
pil_image = Image.open(image)
return np.array(pil_image)
# Special handling for PPM files
if image.lower().endswith('.ppm'):
if return_format == "PIL":
return Image.open(image)
elif return_format == "cv2":
# For PPM files, we can use cv2.imread directly
cv2_image = cv2.imread(image)
if cv2_image is None:
# Fallback to PIL if cv2 can't read it
return cv2.cvtColor(np.array(Image.open(image)), cv2.COLOR_RGB2BGR)
return cv2_image
elif return_format == "numpy":
# Try with cv2 first for better performance with PPM
cv2_image = cv2.imread(image)
if cv2_image is None:
# Fallback to PIL
pil_image = Image.open(image)
return np.array(pil_image)
# Convert BGR to RGB for numpy array
return cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
else:
# Regular handling for other image formats
if return_format == "PIL":
return Image.open(image)
elif return_format == "cv2":
# channels need to be reversed for cv2
return cv2.cvtColor(np.array(Image.open(image)), cv2.COLOR_RGB2BGR)
elif return_format == "numpy":
pil_image = Image.open(image)
return np.array(pil_image)
else:
raise ValueError(f"{image} is not a valid file path or URI")

Expand All @@ -84,8 +106,8 @@ def split_data(base_dir, split_ratio=0.8, record_confidence=False):
os.rename(
os.path.join(images_dir, file), os.path.join(images_dir, new_file_name)
)

# Convert .png and .jpeg images to .jpg
# Convert .png, .jpeg, and .ppm images to .jpg
for file in os.listdir(images_dir):
if file.endswith(".png"):
img = Image.open(os.path.join(images_dir, file))
Expand All @@ -97,6 +119,11 @@ def split_data(base_dir, split_ratio=0.8, record_confidence=False):
rgb_img = img.convert("RGB")
rgb_img.save(os.path.join(images_dir, file.replace(".jpeg", ".jpg")))
os.remove(os.path.join(images_dir, file))
if file.endswith(".ppm"):
img = Image.open(os.path.join(images_dir, file))
rgb_img = img.convert("RGB")
rgb_img.save(os.path.join(images_dir, file.replace(".ppm", ".jpg")))
os.remove(os.path.join(images_dir, file))

# Get list of all files (removing the image file extension)
all_files = os.listdir(images_dir)
Expand Down
37 changes: 36 additions & 1 deletion test/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from PIL import Image
import cv2
import numpy as np
import os
import tempfile

TEST_IMAGE = "test/data/dog.jpeg"

Expand Down Expand Up @@ -30,4 +32,37 @@

assert isinstance(load_image(url, return_format="PIL"), Image.Image)
assert isinstance(load_image(url, return_format="cv2"), np.ndarray)
assert isinstance(load_image(url, return_format="numpy"), np.ndarray)
assert isinstance(load_image(url, return_format="numpy"), np.ndarray)

# Test PPM support if test_ppm_support is available
try:
from test_ppm_support import create_test_ppm

# Create a test PPM file
TEST_PPM = create_test_ppm()

# Test PPM loading
assert isinstance(load_image(TEST_PPM, return_format="PIL"), Image.Image)
assert isinstance(load_image(TEST_PPM, return_format="cv2"), np.ndarray)
assert isinstance(load_image(TEST_PPM, return_format="numpy"), np.ndarray)

# Clean up
if os.path.exists(TEST_PPM):
os.remove(TEST_PPM)

print("PPM support tests passed!")
except (ImportError, ModuleNotFoundError):
print("Skipping PPM tests - test_ppm_support module not found")

# Test error handling for non-image file
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tmp:
tmp.write("This is not an image.")
tmp_path = tmp.name
try:
try:
load_image(tmp_path)
assert False, "Expected ValueError for non-image file"
except ValueError as e:
assert "not a valid image" in str(e)
finally:
os.remove(tmp_path)
60 changes: 60 additions & 0 deletions test/test_ppm_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import sys
import numpy as np
import cv2
from PIL import Image

# Add parent directory to path so we can import autodistill
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from autodistill.helpers import load_image

def create_test_ppm(filepath="test/data/test.ppm"):
    """Create a small binary (P6) PPM gradient image for testing.

    The image is 100x100 pixels with a maximum channel value of 255. Red
    increases with x, green with y, and blue with x + y.

    Args:
        filepath: Destination path of the PPM file. Missing parent
            directories are created. A bare filename (no directory part)
            is written relative to the current working directory.

    Returns:
        The ``filepath`` argument, unchanged.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually contains one.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    width, height = 100, 100

    # PPM P6 header: magic number, dimensions, maximum channel value.
    header = b"P6\n" + f"{width} {height}\n".encode() + b"255\n"

    # Build the pixel payload in memory so the file is written in one call
    # instead of one tiny write per pixel.
    pixels = bytearray()
    for y in range(height):
        g = (y * 255) // height
        for x in range(width):
            r = (x * 255) // width
            b = ((x + y) * 255) // (width + height)
            pixels.extend((r, g, b))

    with open(filepath, "wb") as f:
        f.write(header)
        f.write(pixels)

    return filepath

def test_ppm_loading():
    """Check that load_image reads a PPM file in every return format.

    A temporary PPM file is generated with create_test_ppm, loaded as a PIL
    image, a cv2 array and a numpy array, and removed again even when one of
    the assertions fails.
    """
    # Create a test PPM file
    test_file = create_test_ppm()

    try:
        # Test loading with various return formats
        pil_image = load_image(test_file, return_format="PIL")
        assert isinstance(pil_image, Image.Image), "Failed to load PPM as PIL Image"
        # PIL opens files lazily and keeps the handle; release it so the
        # file can be removed afterwards on every platform.
        pil_image.close()

        cv2_image = load_image(test_file, return_format="cv2")
        assert isinstance(cv2_image, np.ndarray), "Failed to load PPM as cv2 image"
        assert len(cv2_image.shape) == 3, "PPM should be loaded as a 3-channel image"

        numpy_image = load_image(test_file, return_format="numpy")
        assert isinstance(numpy_image, np.ndarray), "Failed to load PPM as numpy array"
        assert len(numpy_image.shape) == 3, "PPM should be loaded as a 3-channel image"

        print("All PPM loading tests passed!")
    finally:
        # Clean up regardless of the test outcome so no stale fixture is
        # left in the test data directory.
        if os.path.exists(test_file):
            os.remove(test_file)

# Allow running this test module directly as a script, without pytest.
if __name__ == "__main__":
    test_ppm_loading()
Loading