Skip to content

Commit a05a61b

Browse files
committed
Use closefd argument
1 parent 341969d commit a05a61b

File tree

1 file changed

+44
-23
lines changed

1 file changed

+44
-23
lines changed

src/xopen/__init__.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
try:
6767
import zstandard # type: ignore
6868
except ImportError:
69-
zstandard = None
69+
zstandard = None # type: ignore
7070

7171
try:
7272
import fcntl
@@ -171,6 +171,7 @@ def __init__( # noqa: C901
171171
compresslevel: Optional[int] = None,
172172
threads: Optional[int] = None,
173173
program_settings: _ProgramSettings = _ProgramSettings(("gzip", "--no-name")),
174+
closefd: bool = True,
174175
):
175176
"""
176177
mode -- one of 'w', 'wb', 'a', 'ab'
@@ -185,10 +186,13 @@ def __init__( # noqa: C901
185186
self._program_args = list(program_settings.program_args)
186187
self._allowed_exit_code = program_settings.allowed_exit_code
187188
self._allowed_exit_message = program_settings.allowed_exit_message
189+
self.closefd = closefd
188190
if mode not in ("r", "rb", "w", "wb", "a", "ab"):
189191
raise ValueError(
190192
f"Mode is '{mode}', but it must be 'r', 'rb', 'w', 'wb', 'a', or 'ab'"
191193
)
194+
if "b" not in mode:
195+
mode += "b"
192196
if hasattr(filename, "read") or hasattr(filename, "write"):
193197
file: BinaryIO = filename # type: ignore
194198
filepath: FilePath = ""
@@ -262,7 +266,8 @@ def __init__( # noqa: C901
262266
close_fds=close_fds,
263267
) # type: ignore
264268
except OSError:
265-
file.close()
269+
if self.closefd:
270+
file.close()
266271
raise
267272
assert self.process.stdin is not None
268273
self._file = self.process.stdin # type: ignore
@@ -334,6 +339,8 @@ def close(self) -> None:
334339
if self.in_thread:
335340
self.in_thread.join()
336341
self._file.close()
342+
if self.closefd:
343+
self.fileobj.close()
337344
else:
338345
self._file.close()
339346
self.process.wait()
@@ -433,7 +440,7 @@ def _open_stdin_or_out(mode: str) -> BinaryIO:
433440
return open(std.fileno(), mode=mode, closefd=False) # type: ignore
434441

435442

436-
def _open_bz2(fileobj: BinaryIO, mode: str, threads: Optional[int]):
443+
def _open_bz2(fileobj: BinaryIO, mode: str, threads: Optional[int], closefd: bool):
437444
assert "b" in mode
438445
if threads != 0:
439446
try:
@@ -443,6 +450,7 @@ def _open_bz2(fileobj: BinaryIO, mode: str, threads: Optional[int]):
443450
mode,
444451
threads=threads,
445452
program_settings=_PROGRAM_SETTINGS["pbzip2"],
453+
closefd=closefd,
446454
)
447455
except OSError:
448456
pass # We try without threads.
@@ -455,6 +463,7 @@ def _open_xz(
455463
mode: str,
456464
compresslevel: Optional[int],
457465
threads: Optional[int],
466+
closefd: bool,
458467
):
459468
assert "b" in mode
460469
if compresslevel is None:
@@ -464,7 +473,7 @@ def _open_xz(
464473
try:
465474
# xz can compress using multiple cores.
466475
return _PipedCompressionProgram(
467-
fileobj, mode, compresslevel, threads, _PROGRAM_SETTINGS["xz"]
476+
fileobj, mode, compresslevel, threads, _PROGRAM_SETTINGS["xz"], closefd
468477
)
469478
except OSError:
470479
pass # We try without threads.
@@ -481,6 +490,7 @@ def _open_zst( # noqa: C901
481490
mode: str,
482491
compresslevel: Optional[int],
483492
threads: Optional[int],
493+
closefd: bool,
484494
):
485495
assert "b" in mode
486496
assert compresslevel != 0
@@ -490,7 +500,12 @@ def _open_zst( # noqa: C901
490500
try:
491501
# zstd can compress using multiple cores
492502
return _PipedCompressionProgram(
493-
fileobj, mode, compresslevel, threads, _PROGRAM_SETTINGS["zstd"]
503+
fileobj,
504+
mode,
505+
compresslevel,
506+
threads,
507+
_PROGRAM_SETTINGS["zstd"],
508+
closefd,
494509
)
495510
except OSError:
496511
if zstandard is None:
@@ -511,7 +526,7 @@ def _open_zst( # noqa: C901
511526
return f
512527

513528

514-
def _open_gz(fileobj: BinaryIO, mode: str, compresslevel, threads, **text_mode_kwargs):
529+
def _open_gz(fileobj: BinaryIO, mode: str, compresslevel, threads, closefd):
515530
"""
516531
Open a gzip file. The ISA-L library is preferred when applicable because
517532
it is the fastest. Then zlib-ng which is not as fast, but supports all
@@ -553,18 +568,21 @@ def _open_gz(fileobj: BinaryIO, mode: str, compresslevel, threads, **text_mode_k
553568
for program in ("pigz", "gzip"):
554569
try:
555570
return _PipedCompressionProgram(
556-
fileobj, mode, compresslevel, threads, _PROGRAM_SETTINGS[program]
571+
fileobj,
572+
mode,
573+
compresslevel,
574+
threads,
575+
_PROGRAM_SETTINGS[program],
576+
closefd,
557577
)
558578
except OSError:
559579
pass # We try without threads.
560580
return _open_reproducible_gzip(
561-
fileobj,
562-
mode=mode,
563-
compresslevel=compresslevel,
581+
fileobj, mode=mode, compresslevel=compresslevel, closefd=closefd
564582
)
565583

566584

567-
def _open_reproducible_gzip(fileobj, mode: str, compresslevel: int):
585+
def _open_reproducible_gzip(fileobj, mode: str, compresslevel: int, closefd):
568586
"""
569587
Open a gzip file for writing (without external processes)
570588
that has neither mtime nor the file name in the header
@@ -595,7 +613,8 @@ def _open_reproducible_gzip(fileobj, mode: str, compresslevel: int):
595613
# When (I)GzipFile is created with a fileobj instead of a filename,
596614
# the passed file object is not closed when (I)GzipFile.close()
597615
# is called. This forces it to be closed.
598-
gzip_file.myfileobj = fileobj
616+
if closefd:
617+
gzip_file.myfileobj = fileobj
599618
return gzip_file
600619

601620

@@ -647,23 +666,23 @@ def _detect_format_from_extension(filename: Union[str, bytes]) -> Optional[str]:
647666

648667
def _file_or_path_to_name_and_binary_stream(
649668
file_or_path: FileOrPath, binary_mode: str
650-
) -> Tuple[str, BinaryIO]:
669+
) -> Tuple[str, BinaryIO, bool]:
651670
if binary_mode not in ("rb", "wb", "ab"):
652671
raise AssertionError()
653672
if file_or_path == "-":
654-
return "", _open_stdin_or_out(binary_mode)
673+
return "", _open_stdin_or_out(binary_mode), False
655674
if isinstance(file_or_path, (str, bytes, os.PathLike)):
656675
filepath = os.fspath(file_or_path)
657676
if isinstance(filepath, bytes):
658677
filepath = filepath.decode()
659-
return filepath, open(os.fspath(file_or_path), binary_mode) # type: ignore
678+
return filepath, open(os.fspath(file_or_path), binary_mode), True # type: ignore
660679
if isinstance(file_or_path, (io.BufferedReader, io.BufferedWriter)):
661-
return file_or_path.name, file_or_path
680+
return file_or_path.name, file_or_path, False
662681
if isinstance(file_or_path, io.TextIOWrapper):
663-
return file_or_path.name, file_or_path.buffer
682+
return file_or_path.name, file_or_path.buffer, False
664683
if isinstance(file_or_path, io.IOBase) and not hasattr(file_or_path, "encoding"):
665684
# Text files have encoding attributes. This file is binary:
666-
return "", file_or_path
685+
return "", file_or_path, False
667686
else:
668687
raise TypeError(
669688
f"Unsupported type for {file_or_path}, "
@@ -761,7 +780,9 @@ def xopen( # noqa: C901 # The function is complex, but readable.
761780
if mode not in ("rt", "rb", "wt", "wb", "at", "ab"):
762781
raise ValueError("Mode '{}' not supported".format(mode))
763782
binary_mode = mode[0] + "b"
764-
filepath, fileobj = _file_or_path_to_name_and_binary_stream(filename, binary_mode)
783+
filepath, fileobj, closefd = _file_or_path_to_name_and_binary_stream(
784+
filename, binary_mode
785+
)
765786

766787
if format not in (None, "gz", "xz", "bz2", "zst"):
767788
raise ValueError(
@@ -773,13 +794,13 @@ def xopen( # noqa: C901 # The function is complex, but readable.
773794
detected_format = _detect_format_from_content(fileobj)
774795

775796
if detected_format == "gz":
776-
opened_file = _open_gz(fileobj, binary_mode, compresslevel, threads)
797+
opened_file = _open_gz(fileobj, binary_mode, compresslevel, threads, closefd)
777798
elif detected_format == "xz":
778-
opened_file = _open_xz(fileobj, binary_mode, compresslevel, threads)
799+
opened_file = _open_xz(fileobj, binary_mode, compresslevel, threads, closefd)
779800
elif detected_format == "bz2":
780-
opened_file = _open_bz2(fileobj, binary_mode, threads)
801+
opened_file = _open_bz2(fileobj, binary_mode, threads, closefd)
781802
elif detected_format == "zst":
782-
opened_file = _open_zst(fileobj, binary_mode, compresslevel, threads)
803+
opened_file = _open_zst(fileobj, binary_mode, compresslevel, threads, closefd)
783804
else:
784805
opened_file = fileobj
785806

0 commit comments

Comments
 (0)