mirror of
https://github.com/openSUSE/osc.git
synced 2025-08-23 14:48:53 +02:00
Fix output.safe_write() in connection with NamedTemporaryFile
Fixes: AttributeError: '_io.BufferedRandom' object has no attribute 'buffer'
This commit is contained in:
@@ -143,9 +143,15 @@ def safe_write(file: TextIO, text: Union[str, bytes], *, add_newline: bool = Fal
|
|||||||
"""
|
"""
|
||||||
text = sanitize_text(text)
|
text = sanitize_text(text)
|
||||||
if isinstance(text, bytes):
|
if isinstance(text, bytes):
|
||||||
file.buffer.write(text)
|
if hasattr(file, "buffer"):
|
||||||
if add_newline:
|
file.buffer.write(text)
|
||||||
file.buffer.write(os.linesep.encode("utf-8"))
|
if add_newline:
|
||||||
|
file.buffer.write(os.linesep.encode("utf-8"))
|
||||||
|
else:
|
||||||
|
# file has no "buffer" attribute, let's try to write the bytes directly
|
||||||
|
file.write(text)
|
||||||
|
if add_newline:
|
||||||
|
file.write(os.linesep.encode("utf-8"))
|
||||||
else:
|
else:
|
||||||
file.write(text)
|
file.write(text)
|
||||||
if add_newline:
|
if add_newline:
|
||||||
@@ -174,7 +180,8 @@ def run_pager(message: Union[bytes, str], tmp_suffix: str = ""):
|
|||||||
safe_write(sys.stdout, message)
|
safe_write(sys.stdout, message)
|
||||||
return
|
return
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=tmp_suffix) as tmpfile:
|
mode = "w+b" if isinstance(message, bytes) else "w+"
|
||||||
|
with tempfile.NamedTemporaryFile(mode=mode, suffix=tmp_suffix) as tmpfile:
|
||||||
safe_write(tmpfile, message)
|
safe_write(tmpfile, message)
|
||||||
tmpfile.flush()
|
tmpfile.flush()
|
||||||
|
|
||||||
|
@@ -1,10 +1,12 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import osc.conf
|
import osc.conf
|
||||||
from osc.output import KeyValueTable
|
from osc.output import KeyValueTable
|
||||||
from osc.output import print_msg
|
from osc.output import print_msg
|
||||||
|
from osc.output import safe_write
|
||||||
from osc.output import sanitize_text
|
from osc.output import sanitize_text
|
||||||
from osc.output import tty
|
from osc.output import tty
|
||||||
|
|
||||||
@@ -238,5 +240,23 @@ class TestSanitization(unittest.TestCase):
|
|||||||
self.assertEqual(sanitized, b"0;this is the window title")
|
self.assertEqual(sanitized, b"0;this is the window title")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeWrite(unittest.TestCase):
|
||||||
|
def test_string_to_file(self):
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w+") as f:
|
||||||
|
safe_write(f, "string")
|
||||||
|
|
||||||
|
def test_bytes_to_file(self):
|
||||||
|
with tempfile.NamedTemporaryFile(mode="wb+") as f:
|
||||||
|
safe_write(f, b"bytes")
|
||||||
|
|
||||||
|
def test_string_to_stringio(self):
|
||||||
|
with io.StringIO() as f:
|
||||||
|
safe_write(f, "string")
|
||||||
|
|
||||||
|
def test_bytes_to_bytesio(self):
|
||||||
|
with io.BytesIO() as f:
|
||||||
|
safe_write(f, b"bytes")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Reference in New Issue
Block a user