Skip to content

Commit e548c5c

Browse files
committed
Added type hints to ImageFilter
1 parent 4b258be commit e548c5c

File tree

6 files changed

+67
-38
lines changed

6 files changed

+67
-38
lines changed

Tests/test_color_lut.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,10 @@ def test_overflow(self) -> None:
354354
class TestColorLut3DFilter:
355355
def test_wrong_args(self) -> None:
356356
with pytest.raises(ValueError, match="should be either an integer"):
357-
ImageFilter.Color3DLUT("small", [1])
357+
ImageFilter.Color3DLUT("small", [1]) # type: ignore[arg-type]
358358

359359
with pytest.raises(ValueError, match="should be either an integer"):
360-
ImageFilter.Color3DLUT((11, 11), [1])
360+
ImageFilter.Color3DLUT((11, 11), [1]) # type: ignore[arg-type]
361361

362362
with pytest.raises(ValueError, match=r"in \[2, 65\] range"):
363363
ImageFilter.Color3DLUT((11, 11, 1), [1])

Tests/test_image_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_builtinfilter_p() -> None:
137137
builtin_filter = ImageFilter.BuiltinFilter()
138138

139139
with pytest.raises(ValueError):
140-
builtin_filter.filter(hopper("P"))
140+
builtin_filter.filter(hopper("P").im)
141141

142142

143143
def test_kernel_not_enough_coefficients() -> None:

Tests/test_numpy.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,21 @@
55

66
import pytest
77

8-
from PIL import Image
8+
from PIL import Image, _typing
99

1010
from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature
1111

1212
if TYPE_CHECKING:
1313
import numpy
14-
import numpy.typing
14+
import numpy.typing as npt
1515
else:
1616
numpy = pytest.importorskip("numpy", reason="NumPy not installed")
1717

1818
TEST_IMAGE_SIZE = (10, 10)
1919

2020

2121
def test_numpy_to_image() -> None:
22-
def to_image(
23-
dtype: numpy.typing.DTypeLike, bands: int = 1, boolean: int = 0
24-
) -> Image.Image:
22+
def to_image(dtype: npt.DTypeLike, bands: int = 1, boolean: int = 0) -> Image.Image:
2523
if bands == 1:
2624
if boolean:
2725
data = [0, 255] * 50
@@ -106,9 +104,7 @@ def test_1d_array() -> None:
106104
assert_image(Image.fromarray(a), "L", (1, 5))
107105

108106

109-
def _test_img_equals_nparray(
110-
img: Image.Image, np_img: numpy.typing.NDArray[Any]
111-
) -> None:
107+
def _test_img_equals_nparray(img: Image.Image, np_img: _typing.NumpyArray) -> None:
112108
assert len(np_img.shape) >= 2
113109
np_size = np_img.shape[1], np_img.shape[0]
114110
assert img.size == np_size
@@ -166,7 +162,7 @@ def test_save_tiff_uint16() -> None:
166162
("HSV", numpy.uint8),
167163
),
168164
)
169-
def test_to_array(mode: str, dtype: numpy.typing.DTypeLike) -> None:
165+
def test_to_array(mode: str, dtype: npt.DTypeLike) -> None:
170166
img = hopper(mode)
171167

172168
# Resize to non-square
@@ -216,7 +212,7 @@ def test_putdata() -> None:
216212
numpy.float64,
217213
),
218214
)
219-
def test_roundtrip_eye(dtype: numpy.typing.DTypeLike) -> None:
215+
def test_roundtrip_eye(dtype: npt.DTypeLike) -> None:
220216
arr = numpy.eye(10, dtype=dtype)
221217
numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr)))
222218

docs/reference/internal_modules.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ Internal Modules
3333
Provides a convenient way to import type hints that are not available
3434
on some Python versions.
3535

36+
.. py:class:: NumpyArray
37+
38+
Typing alias.
39+
3640
.. py:class:: StrOrBytesPath
3741
3842
Typing alias.

src/PIL/ImageFilter.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
import abc
2020
import functools
2121
from types import ModuleType
22-
from typing import Any, Sequence
22+
from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
23+
24+
if TYPE_CHECKING:
25+
from . import _imaging
26+
from ._typing import NumpyArray
2327

2428

2529
class Filter:
2630
@abc.abstractmethod
27-
def filter(self, image):
31+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
2832
pass
2933

3034

@@ -33,7 +37,9 @@ class MultibandFilter(Filter):
3337

3438

3539
class BuiltinFilter(MultibandFilter):
36-
def filter(self, image):
40+
filterargs: tuple[Any, ...]
41+
42+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
3743
if image.mode == "P":
3844
msg = "cannot filter palette images"
3945
raise ValueError(msg)
@@ -91,7 +97,7 @@ def __init__(self, size: int, rank: int) -> None:
9197
self.size = size
9298
self.rank = rank
9399

94-
def filter(self, image):
100+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
95101
if image.mode == "P":
96102
msg = "cannot filter palette images"
97103
raise ValueError(msg)
@@ -158,7 +164,7 @@ class ModeFilter(Filter):
158164
def __init__(self, size: int = 3) -> None:
159165
self.size = size
160166

161-
def filter(self, image):
167+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
162168
return image.modefilter(self.size)
163169

164170

@@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter):
176182
def __init__(self, radius: float | Sequence[float] = 2) -> None:
177183
self.radius = radius
178184

179-
def filter(self, image):
185+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
180186
xy = self.radius
181-
if not isinstance(xy, (tuple, list)):
187+
if isinstance(xy, (int, float)):
182188
xy = (xy, xy)
183189
if xy == (0, 0):
184190
return image.copy()
@@ -208,9 +214,9 @@ def __init__(self, radius: float | Sequence[float]) -> None:
208214
raise ValueError(msg)
209215
self.radius = radius
210216

211-
def filter(self, image):
217+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
212218
xy = self.radius
213-
if not isinstance(xy, (tuple, list)):
219+
if isinstance(xy, (int, float)):
214220
xy = (xy, xy)
215221
if xy == (0, 0):
216222
return image.copy()
@@ -241,7 +247,7 @@ def __init__(
241247
self.percent = percent
242248
self.threshold = threshold
243249

244-
def filter(self, image):
250+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
245251
return image.unsharp_mask(self.radius, self.percent, self.threshold)
246252

247253

@@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter):
387393
name = "Color 3D LUT"
388394

389395
def __init__(
390-
self, size, table, channels: int = 3, target_mode: str | None = None, **kwargs
391-
):
396+
self,
397+
size: int | tuple[int, int, int],
398+
table: Sequence[float] | Sequence[Sequence[int]] | NumpyArray,
399+
channels: int = 3,
400+
target_mode: str | None = None,
401+
**kwargs: bool,
402+
) -> None:
392403
if channels not in (3, 4):
393404
msg = "Only 3 or 4 output channels are supported"
394405
raise ValueError(msg)
@@ -410,15 +421,16 @@ def __init__(
410421
pass
411422

412423
if numpy and isinstance(table, numpy.ndarray):
424+
numpy_table: NumpyArray = table
413425
if copy_table:
414-
table = table.copy()
426+
numpy_table = numpy_table.copy()
415427

416-
if table.shape in [
428+
if numpy_table.shape in [
417429
(items * channels,),
418430
(items, channels),
419431
(size[2], size[1], size[0], channels),
420432
]:
421-
table = table.reshape(items * channels)
433+
table = numpy_table.reshape(items * channels)
422434
else:
423435
wrong_size = True
424436

@@ -428,15 +440,17 @@ def __init__(
428440

429441
# Convert to a flat list
430442
if table and isinstance(table[0], (list, tuple)):
431-
table, raw_table = [], table
443+
raw_table = cast(Sequence[Sequence[int]], table)
444+
flat_table: list[int] = []
432445
for pixel in raw_table:
433446
if len(pixel) != channels:
434447
msg = (
435448
"The elements of the table should "
436449
f"have a length of {channels}."
437450
)
438451
raise ValueError(msg)
439-
table.extend(pixel)
452+
flat_table.extend(pixel)
453+
table = flat_table
440454

441455
if wrong_size or len(table) != items * channels:
442456
msg = (
@@ -449,23 +463,29 @@ def __init__(
449463
self.table = table
450464

451465
@staticmethod
452-
def _check_size(size: Any) -> list[int]:
466+
def _check_size(size: Any) -> tuple[int, int, int]:
453467
try:
454468
_, _, _ = size
455469
except ValueError as e:
456470
msg = "Size should be either an integer or a tuple of three integers."
457471
raise ValueError(msg) from e
458472
except TypeError:
459473
size = (size, size, size)
460-
size = [int(x) for x in size]
474+
size = tuple(int(x) for x in size)
461475
for size_1d in size:
462476
if not 2 <= size_1d <= 65:
463477
msg = "Size should be in [2, 65] range."
464478
raise ValueError(msg)
465479
return size
466480

467481
@classmethod
468-
def generate(cls, size, callback, channels=3, target_mode=None):
482+
def generate(
483+
cls,
484+
size: int | tuple[int, int, int],
485+
callback: Callable[[float, float, float], tuple[float, ...]],
486+
channels: int = 3,
487+
target_mode: str | None = None,
488+
) -> Color3DLUT:
469489
"""Generates new LUT using provided callback.
470490
471491
:param size: Size of the table. Passed to the constructor.
@@ -482,7 +502,7 @@ def generate(cls, size, callback, channels=3, target_mode=None):
482502
msg = "Only 3 or 4 output channels are supported"
483503
raise ValueError(msg)
484504

485-
table = [0] * (size_1d * size_2d * size_3d * channels)
505+
table: list[float] = [0] * (size_1d * size_2d * size_3d * channels)
486506
idx_out = 0
487507
for b in range(size_3d):
488508
for g in range(size_2d):
@@ -500,7 +520,13 @@ def generate(cls, size, callback, channels=3, target_mode=None):
500520
_copy_table=False,
501521
)
502522

503-
def transform(self, callback, with_normals=False, channels=None, target_mode=None):
523+
def transform(
524+
self,
525+
callback: Callable[..., tuple[float, ...]],
526+
with_normals: bool = False,
527+
channels: int | None = None,
528+
target_mode: str | None = None,
529+
) -> Color3DLUT:
504530
"""Transforms the table values using provided callback and returns
505531
a new LUT with altered values.
506532
@@ -564,7 +590,7 @@ def __repr__(self) -> str:
564590
r.append(f"target_mode={self.mode}")
565591
return "<{}>".format(" ".join(r))
566592

567-
def filter(self, image):
593+
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
568594
from . import Image
569595

570596
return image.color_lut_3d(

src/PIL/_typing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22

33
import os
44
import sys
5-
from typing import Protocol, Sequence, TypeVar, Union
5+
from typing import Any, Protocol, Sequence, TypeVar, Union
6+
7+
import numpy.typing as npt
68

79
if sys.version_info >= (3, 10):
810
from typing import TypeGuard
911
else:
1012
try:
1113
from typing_extensions import TypeGuard
1214
except ImportError:
13-
from typing import Any
1415

1516
class TypeGuard: # type: ignore[no-redef]
1617
def __class_getitem__(cls, item: Any) -> type[bool]:
@@ -19,6 +20,8 @@ def __class_getitem__(cls, item: Any) -> type[bool]:
1920

2021
Coords = Union[Sequence[float], Sequence[Sequence[float]]]
2122

23+
NumpyArray = npt.NDArray[Any]
24+
2225

2326
_T_co = TypeVar("_T_co", covariant=True)
2427

0 commit comments

Comments
 (0)