19
19
import abc
20
20
import functools
21
21
from types import ModuleType
22
- from typing import Any , Sequence
22
+ from typing import TYPE_CHECKING , Any , Callable , Sequence , cast
23
+
24
+ if TYPE_CHECKING :
25
+ from . import _imaging
26
+ from ._typing import NumpyArray
23
27
24
28
25
29
class Filter :
26
30
@abc .abstractmethod
27
- def filter (self , image ) :
31
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
28
32
pass
29
33
30
34
@@ -33,7 +37,9 @@ class MultibandFilter(Filter):
33
37
34
38
35
39
class BuiltinFilter (MultibandFilter ):
36
- def filter (self , image ):
40
+ filterargs : tuple [Any , ...]
41
+
42
+ def filter (self , image : _imaging .ImagingCore ) -> _imaging .ImagingCore :
37
43
if image .mode == "P" :
38
44
msg = "cannot filter palette images"
39
45
raise ValueError (msg )
@@ -91,7 +97,7 @@ def __init__(self, size: int, rank: int) -> None:
91
97
self .size = size
92
98
self .rank = rank
93
99
94
- def filter (self , image ) :
100
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
95
101
if image .mode == "P" :
96
102
msg = "cannot filter palette images"
97
103
raise ValueError (msg )
@@ -158,7 +164,7 @@ class ModeFilter(Filter):
158
164
def __init__ (self , size : int = 3 ) -> None :
159
165
self .size = size
160
166
161
- def filter (self , image ) :
167
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
162
168
return image .modefilter (self .size )
163
169
164
170
@@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter):
176
182
def __init__ (self , radius : float | Sequence [float ] = 2 ) -> None :
177
183
self .radius = radius
178
184
179
- def filter (self , image ) :
185
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
180
186
xy = self .radius
181
- if not isinstance (xy , (tuple , list )):
187
+ if isinstance (xy , (int , float )):
182
188
xy = (xy , xy )
183
189
if xy == (0 , 0 ):
184
190
return image .copy ()
@@ -208,9 +214,9 @@ def __init__(self, radius: float | Sequence[float]) -> None:
208
214
raise ValueError (msg )
209
215
self .radius = radius
210
216
211
- def filter (self , image ) :
217
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
212
218
xy = self .radius
213
- if not isinstance (xy , (tuple , list )):
219
+ if isinstance (xy , (int , float )):
214
220
xy = (xy , xy )
215
221
if xy == (0 , 0 ):
216
222
return image .copy ()
@@ -241,7 +247,7 @@ def __init__(
241
247
self .percent = percent
242
248
self .threshold = threshold
243
249
244
- def filter (self , image ) :
250
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
245
251
return image .unsharp_mask (self .radius , self .percent , self .threshold )
246
252
247
253
@@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter):
387
393
name = "Color 3D LUT"
388
394
389
395
def __init__ (
390
- self , size , table , channels : int = 3 , target_mode : str | None = None , ** kwargs
391
- ):
396
+ self ,
397
+ size : int | tuple [int , int , int ],
398
+ table : Sequence [float ] | Sequence [Sequence [int ]] | NumpyArray ,
399
+ channels : int = 3 ,
400
+ target_mode : str | None = None ,
401
+ ** kwargs : bool ,
402
+ ) -> None :
392
403
if channels not in (3 , 4 ):
393
404
msg = "Only 3 or 4 output channels are supported"
394
405
raise ValueError (msg )
@@ -410,15 +421,16 @@ def __init__(
410
421
pass
411
422
412
423
if numpy and isinstance (table , numpy .ndarray ):
424
+ numpy_table : NumpyArray = table
413
425
if copy_table :
414
- table = table .copy ()
426
+ numpy_table = numpy_table .copy ()
415
427
416
- if table .shape in [
428
+ if numpy_table .shape in [
417
429
(items * channels ,),
418
430
(items , channels ),
419
431
(size [2 ], size [1 ], size [0 ], channels ),
420
432
]:
421
- table = table .reshape (items * channels )
433
+ table = numpy_table .reshape (items * channels )
422
434
else :
423
435
wrong_size = True
424
436
@@ -428,15 +440,17 @@ def __init__(
428
440
429
441
# Convert to a flat list
430
442
if table and isinstance (table [0 ], (list , tuple )):
431
- table , raw_table = [], table
443
+ raw_table = cast (Sequence [Sequence [int ]], table )
444
+ flat_table : list [int ] = []
432
445
for pixel in raw_table :
433
446
if len (pixel ) != channels :
434
447
msg = (
435
448
"The elements of the table should "
436
449
f"have a length of { channels } ."
437
450
)
438
451
raise ValueError (msg )
439
- table .extend (pixel )
452
+ flat_table .extend (pixel )
453
+ table = flat_table
440
454
441
455
if wrong_size or len (table ) != items * channels :
442
456
msg = (
@@ -449,23 +463,29 @@ def __init__(
449
463
self .table = table
450
464
451
465
@staticmethod
452
- def _check_size (size : Any ) -> list [ int ]:
466
+ def _check_size (size : Any ) -> tuple [ int , int , int ]:
453
467
try :
454
468
_ , _ , _ = size
455
469
except ValueError as e :
456
470
msg = "Size should be either an integer or a tuple of three integers."
457
471
raise ValueError (msg ) from e
458
472
except TypeError :
459
473
size = (size , size , size )
460
- size = [ int (x ) for x in size ]
474
+ size = tuple ( int (x ) for x in size )
461
475
for size_1d in size :
462
476
if not 2 <= size_1d <= 65 :
463
477
msg = "Size should be in [2, 65] range."
464
478
raise ValueError (msg )
465
479
return size
466
480
467
481
@classmethod
468
- def generate (cls , size , callback , channels = 3 , target_mode = None ):
482
+ def generate (
483
+ cls ,
484
+ size : int | tuple [int , int , int ],
485
+ callback : Callable [[float , float , float ], tuple [float , ...]],
486
+ channels : int = 3 ,
487
+ target_mode : str | None = None ,
488
+ ) -> Color3DLUT :
469
489
"""Generates new LUT using provided callback.
470
490
471
491
:param size: Size of the table. Passed to the constructor.
@@ -482,7 +502,7 @@ def generate(cls, size, callback, channels=3, target_mode=None):
482
502
msg = "Only 3 or 4 output channels are supported"
483
503
raise ValueError (msg )
484
504
485
- table = [0 ] * (size_1d * size_2d * size_3d * channels )
505
+ table : list [ float ] = [0 ] * (size_1d * size_2d * size_3d * channels )
486
506
idx_out = 0
487
507
for b in range (size_3d ):
488
508
for g in range (size_2d ):
@@ -500,7 +520,13 @@ def generate(cls, size, callback, channels=3, target_mode=None):
500
520
_copy_table = False ,
501
521
)
502
522
503
- def transform (self , callback , with_normals = False , channels = None , target_mode = None ):
523
+ def transform (
524
+ self ,
525
+ callback : Callable [..., tuple [float , ...]],
526
+ with_normals : bool = False ,
527
+ channels : int | None = None ,
528
+ target_mode : str | None = None ,
529
+ ) -> Color3DLUT :
504
530
"""Transforms the table values using provided callback and returns
505
531
a new LUT with altered values.
506
532
@@ -564,7 +590,7 @@ def __repr__(self) -> str:
564
590
r .append (f"target_mode={ self .mode } " )
565
591
return "<{}>" .format (" " .join (r ))
566
592
567
- def filter (self , image ) :
593
+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
568
594
from . import Image
569
595
570
596
return image .color_lut_3d (
0 commit comments