diff --git a/extmod/modre.c b/extmod/modre.c index 7f00b1c23c..8697e1c6ef 100644 --- a/extmod/modre.c +++ b/extmod/modre.c @@ -80,7 +80,13 @@ STATIC mp_obj_t match_group(mp_obj_t self_in, mp_obj_t no_in) { // no match for this group return mp_const_none; } - return mp_obj_new_str_of_type(mp_obj_get_type(self->str), + const mp_obj_type_t *str_type = mp_obj_get_type(self->str); + if (str_type != &mp_type_str) { + // bytes, bytearray etc. args should return bytes + str_type = &mp_type_bytes; + } + + return mp_obj_new_str_of_type(str_type, (const byte *)start, self->caps[no * 2 + 1] - start); } MP_DEFINE_CONST_FUN_OBJ_2(match_group_obj, match_group); @@ -120,7 +126,9 @@ STATIC void match_span_helper(size_t n_args, const mp_obj_t *args, mp_obj_t span const char *start = self->caps[no * 2]; if (start != NULL) { // have a match for this group - const char *begin = mp_obj_str_get_str(self->str); + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(self->str, &bufinfo, MP_BUFFER_READ); + const char *begin = bufinfo.buf; s = start - begin; e = self->caps[no * 2 + 1] - begin; } @@ -203,9 +211,10 @@ STATIC mp_obj_t re_exec(bool is_anchored, uint n_args, const mp_obj_t *args) { self = MP_OBJ_TO_PTR(mod_re_compile(1, args)); } Subject subj; - size_t len; - subj.begin_line = subj.begin = mp_obj_str_get_data(args[1], &len); - subj.end = subj.begin + len; + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_READ); + subj.begin_line = subj.begin = bufinfo.buf; + subj.end = subj.begin + bufinfo.len; int caps_num = (self->re.sub + 1) * 2; mp_obj_match_t *match = m_new_obj_var(mp_obj_match_t, char *, caps_num); // cast is a workaround for a bug in msvc: it treats const char** as a const pointer instead of a pointer to pointer to const char @@ -235,10 +244,15 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(re_search_obj, 2, 4, re_search); STATIC mp_obj_t re_split(size_t n_args, const mp_obj_t *args) { mp_obj_re_t *self = MP_OBJ_TO_PTR(args[0]); Subject subj; - size_t len; + mp_buffer_info_t bufinfo; const mp_obj_type_t *str_type = mp_obj_get_type(args[1]); - subj.begin_line = subj.begin = mp_obj_str_get_data(args[1], &len); - subj.end = subj.begin + len; + if (str_type != &mp_type_str) { + // bytes, bytearray etc. args should return bytes + str_type = &mp_type_bytes; + } + mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_READ); + subj.begin_line = subj.begin = bufinfo.buf; + subj.end = subj.begin + bufinfo.len; int caps_num = (self->re.sub + 1) * 2; int maxsplit = 0; @@ -294,11 +308,11 @@ STATIC mp_obj_t re_sub_helper(size_t n_args, const mp_obj_t *args) { // Note: flags are currently ignored } - size_t where_len; - const char *where_str = mp_obj_str_get_data(where, &where_len); Subject subj; - subj.begin_line = subj.begin = where_str; - subj.end = subj.begin + where_len; + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(where, &bufinfo, MP_BUFFER_READ); + subj.begin_line = subj.begin = bufinfo.buf; + subj.end = subj.begin + bufinfo.len; int caps_num = (self->re.sub + 1) * 2; vstr_t vstr_return; @@ -327,10 +341,13 @@ STATIC mp_obj_t re_sub_helper(size_t n_args, const mp_obj_t *args) { vstr_add_strn(&vstr_return, subj.begin, match->caps[0] - subj.begin); // Get replacement string - const char *repl = mp_obj_str_get_str((mp_obj_is_callable(replace) ? mp_call_function_1(replace, MP_OBJ_FROM_PTR(match)) : replace)); + mp_obj_t repl_obj = (mp_obj_is_callable(replace) ? mp_call_function_1(replace, MP_OBJ_FROM_PTR(match)) : replace); + mp_get_buffer_raise(repl_obj, &bufinfo, MP_BUFFER_READ); + const char *repl = bufinfo.buf; + const char *repl_top = repl + bufinfo.len; // Append replacement string to result, substituting any regex groups - while (*repl != '\0') { + while (repl < repl_top) { if (*repl == '\\') { ++repl; bool is_g_format = false; @@ -423,8 +440,11 @@ STATIC MP_DEFINE_CONST_OBJ_TYPE( STATIC mp_obj_t mod_re_compile(size_t n_args, const mp_obj_t *args) { (void)n_args; - const char *re_str = mp_obj_str_get_str(args[0]); - int size = re1_5_sizecode(re_str); + + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ); + const char *re_str = bufinfo.buf; + int size = re1_5_sizecode(re_str, bufinfo.len); if (size == -1) { goto error; } @@ -435,7 +455,7 @@ STATIC mp_obj_t mod_re_compile(size_t n_args, const mp_obj_t *args) { flags = mp_obj_get_int(args[1]); } #endif - int error = re1_5_compilecode(&o->re, re_str); + int error = re1_5_compilecode(&o->re, re_str, bufinfo.len); if (error != 0) { error: mp_raise_ValueError(MP_ERROR_TEXT("error in regex")); diff --git a/tests/extmod/re1.py b/tests/extmod/re1.py index 7e3839ae24..4fd5820a15 100644 --- a/tests/extmod/re1.py +++ b/tests/extmod/re1.py @@ -93,6 +93,23 @@ m = re.match(rb"a+?", b"ab") print(m.group(0)) print("===") +# bytearray / memoryview objects +m = re.match(rb"a.", bytearray(b"ab")) +print(m.group(0)) +m = re.match(rb"a.", memoryview(b"ab")) +print(m.group(0)) +# While micropython supports bytearray pattern, cpython does not. +# m = re.match(bytearray(b"a."), b"ab") +# print(m.group(0)) +print("===") + +# null chars +m = re.match("ab.d", "ab\x00d") +print(list(m.group(0))) +m = re.match("ab\x00d", "ab\x00d") +print(list(m.group(0))) +print("===") + # escaping m = re.match(r"a\.c", "a.c") print(m.group(0) if m else "") diff --git a/tests/extmod/re_split.py b/tests/extmod/re_split.py index 7769e1a121..486b1c3881 100644 --- a/tests/extmod/re_split.py +++ b/tests/extmod/re_split.py @@ -38,3 +38,8 @@ print(s) r = re.compile("^ab|cab") s = r.split("abababcabab") print(s) + +# bytearray objects +r = re.compile(b"x") +s = r.split(bytearray(b"fooxbar")) +print(s) diff --git a/tests/extmod/re_sub.py b/tests/extmod/re_sub.py index 229c0e63ee..779d32374f 100644 --- a/tests/extmod/re_sub.py +++ b/tests/extmod/re_sub.py @@ -26,6 +26,13 @@ def A(): print(re.sub("a", A(), "aBCBABCDabcda.")) + +def B(): + return bytearray(b"B") + + +print(re.sub(b"a", B(), b"aBCBABCDabcda.")) + print( re.sub( r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):", @@ -61,10 +68,11 @@ try: except: print("invalid group") -# Module function takes str/bytes/re. +# Module function takes str/bytes/re/bytearray. print(re.sub("a", "a", "a")) print(re.sub(b".", b"a", b"a")) print(re.sub(re.compile("a"), "a", "a")) +print(re.sub(b"a", bytearray(b"b"), bytearray(b"a"))) try: re.sub(123, "a", "a") except TypeError: