1
0
mirror of https://github.com/openSUSE/osc.git synced 2025-08-23 14:48:53 +02:00

Fix output.safe_write() in connection with NamedTemporaryFile

Fixes: AttributeError: '_io.BufferedRandom' object has no attribute 'buffer'
This commit is contained in:
2024-04-16 16:50:10 +02:00
parent d2503fbf49
commit d1111e23a1
2 changed files with 31 additions and 4 deletions

View File

@@ -143,9 +143,15 @@ def safe_write(file: TextIO, text: Union[str, bytes], *, add_newline: bool = Fal
""" """
text = sanitize_text(text) text = sanitize_text(text)
if isinstance(text, bytes): if isinstance(text, bytes):
file.buffer.write(text) if hasattr(file, "buffer"):
if add_newline: file.buffer.write(text)
file.buffer.write(os.linesep.encode("utf-8")) if add_newline:
file.buffer.write(os.linesep.encode("utf-8"))
else:
# file has no "buffer" attribute, let's try to write the bytes directly
file.write(text)
if add_newline:
file.write(os.linesep.encode("utf-8"))
else: else:
file.write(text) file.write(text)
if add_newline: if add_newline:
@@ -174,7 +180,8 @@ def run_pager(message: Union[bytes, str], tmp_suffix: str = ""):
safe_write(sys.stdout, message) safe_write(sys.stdout, message)
return return
with tempfile.NamedTemporaryFile(suffix=tmp_suffix) as tmpfile: mode = "w+b" if isinstance(message, bytes) else "w+"
with tempfile.NamedTemporaryFile(mode=mode, suffix=tmp_suffix) as tmpfile:
safe_write(tmpfile, message) safe_write(tmpfile, message)
tmpfile.flush() tmpfile.flush()

View File

@@ -1,10 +1,12 @@
import contextlib import contextlib
import io import io
import tempfile
import unittest import unittest
import osc.conf import osc.conf
from osc.output import KeyValueTable from osc.output import KeyValueTable
from osc.output import print_msg from osc.output import print_msg
from osc.output import safe_write
from osc.output import sanitize_text from osc.output import sanitize_text
from osc.output import tty from osc.output import tty
@@ -238,5 +240,23 @@ class TestSanitization(unittest.TestCase):
self.assertEqual(sanitized, b"0;this is the window title") self.assertEqual(sanitized, b"0;this is the window title")
class TestSafeWrite(unittest.TestCase):
def test_string_to_file(self):
with tempfile.NamedTemporaryFile(mode="w+") as f:
safe_write(f, "string")
def test_bytes_to_file(self):
with tempfile.NamedTemporaryFile(mode="wb+") as f:
safe_write(f, b"bytes")
def test_string_to_stringio(self):
with io.StringIO() as f:
safe_write(f, "string")
def test_bytes_to_bytesio(self):
with io.BytesIO() as f:
safe_write(f, b"bytes")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()