diff --git a/py/objstringio.c b/py/objstringio.c index 1a083449ba..b137c3396e 100644 --- a/py/objstringio.c +++ b/py/objstringio.c @@ -77,6 +77,13 @@ STATIC void stringio_copy_on_write(mp_obj_stringio_t *o) { STATIC mp_uint_t stringio_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { (void)errcode; + + if (!mp_obj_is_type(o_in, &mp_type_stringio)) { + + if (!mp_obj_is_type(o_in, &mp_type_bytesio)) { + mp_raise_TypeError(MP_ERROR_TEXT("expecting a StringIO or BytesIO object")); + } + } mp_obj_stringio_t *o = MP_OBJ_TO_PTR(o_in); check_stringio_is_open(o); diff --git a/tests/io/stringio_type_check.py b/tests/io/stringio_type_check.py new file mode 100644 index 0000000000..c26eda5c21 --- /dev/null +++ b/tests/io/stringio_type_check.py @@ -0,0 +1,31 @@ +import sys + +sys.path[0] = ".frozen" # avoid local dir io import + +import io + + +class TestStream(io.StringIO): + def __init_(self, alloc_size): + super().__init__(alloc_size) + + +class ByteStream(io.BytesIO): + def __init_(self, alloc_size): + super().__init__(alloc_size) + + +test_stringio = TestStream(100) + +test_bytesio = ByteStream(100) + +try: + print("hello", file=test_stringio) +except Exception as e: + print(e) + + +try: + print("hello", file=test_bytesio) +except Exception as e: + print(e) diff --git a/tests/io/stringio_type_check.py.exp b/tests/io/stringio_type_check.py.exp new file mode 100644 index 0000000000..d96829dce8 --- /dev/null +++ b/tests/io/stringio_type_check.py.exp @@ -0,0 +1,2 @@ +expecting a StringIO or BytesIO object +expecting a StringIO or BytesIO object