# Patch overview (reviewer annotation; text before the first "diff --git"
# header is leading commentary, not part of any hunk).
#
# NOTE(review): this copy of the patch has had its line breaks collapsed, so
# the per-line "+", "-" and context markers appear inline below. Restore the
# original line structure before attempting to apply it.
#
# py/objstr.c
#   * Adds a STATIC, noreturn forward declaration of bad_implicit_conversion
#     next to the other forward declarations at the top of the file.
#   * str_split:
#       - `splits` changes from int to machine_int_t and is now read with
#         mp_obj_get_int(args[2]) instead of MP_OBJ_SMALL_INT_VALUE(args[2]).
#       - The assert(sep == mp_const_none) and the (void)sep cast are removed.
#       - When sep is mp_const_none the existing whitespace-splitting loop is
#         kept, moved under an `if` branch with `start` declared per iteration.
#       - Otherwise the separator bytes are fetched with mp_obj_str_get_data;
#         a zero-length separator raises ValueError("empty separator").
#         The string is then scanned with memcmp for occurrences of sep.
#         Once `splits` reaches 0, or fewer than sep_len bytes remain, the
#         rest of the string is appended as the final list element.
#       - NOTE(review): no explicit type check on `sep` is visible in this
#         hunk before mp_obj_str_get_data is called -- confirm that helper
#         rejects non-str separators.
#   * str_modulo_format: one line before `default:` is removed and re-added;
#     both look empty here, so this is presumably a whitespace-only change
#     (TODO confirm against the uncollapsed patch).
#   * str_partitioner: the inline TypeError nlr_raise is replaced by a call to
#     bad_implicit_conversion(arg).
#   * bad_implicit_conversion: the separate non-static prototype is dropped
#     and the definition becomes STATIC; its body is unchanged context.
#
# tests/basics/string_split.py
#   * Adds a comment heading for the existing default-separator cases.
#   * Adds a case expecting ValueError from "abc".split('').
#   * Adds non-empty separator cases, including a separator longer than the
#     string ("abcd") and maxsplit values 0, 1 and 2 on "abcabc".
diff --git a/py/objstr.c b/py/objstr.c index 7000ed1fb5..329dfe6dd9 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -33,6 +33,7 @@ const mp_obj_t mp_const_empty_bytes; STATIC mp_obj_t mp_obj_new_str_iterator(mp_obj_t str); STATIC mp_obj_t mp_obj_new_bytes_iterator(mp_obj_t str); STATIC mp_obj_t str_new(const mp_obj_type_t *type, const byte* data, uint len); +STATIC void bad_implicit_conversion(mp_obj_t self_in) __attribute__((noreturn)); /******************************************************************************/ /* str */ @@ -367,38 +368,71 @@ bad_arg: #define is_ws(c) ((c) == ' ' || (c) == '\t') STATIC mp_obj_t str_split(uint n_args, const mp_obj_t *args) { - int splits = -1; + machine_int_t splits = -1; mp_obj_t sep = mp_const_none; if (n_args > 1) { sep = args[1]; if (n_args > 2) { - splits = MP_OBJ_SMALL_INT_VALUE(args[2]); + splits = mp_obj_get_int(args[2]); } } - assert(sep == mp_const_none); - (void)sep; // unused; to hush compiler warning + mp_obj_t res = mp_obj_new_list(0, NULL); GET_STR_DATA_LEN(args[0], s, len); const byte *top = s + len; - const byte *start; - // Initial whitespace is not counted as split, so we pre-do it - while (s < top && is_ws(*s)) s++; - while (s < top && splits != 0) { - start = s; - while (s < top && !is_ws(*s)) s++; - mp_obj_list_append(res, mp_obj_new_str(start, s - start, false)); - if (s >= top) { - break; - } + if (sep == mp_const_none) { + // sep not given, so separate on whitespace + + // Initial whitespace is not counted as split, so we pre-do it while (s < top && is_ws(*s)) s++; - if (splits > 0) { - splits--; + while (s < top && splits != 0) { + const byte *start = s; + while (s < top && !is_ws(*s)) s++; + mp_obj_list_append(res, mp_obj_new_str(start, s - start, false)); + if (s >= top) { + break; + } + while (s < top && is_ws(*s)) s++; + if (splits > 0) { + splits--; + } } - } - if (s < top) { - mp_obj_list_append(res, mp_obj_new_str(s, 
top - s, false)); + if (s < top) { + mp_obj_list_append(res, mp_obj_new_str(s, top - s, false)); + } + + } else { + // sep given + + uint sep_len; + const char *sep_str = mp_obj_str_get_data(sep, &sep_len); + + if (sep_len == 0) { + nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "empty separator")); + } + + for (;;) { + const byte *start = s; + for (;;) { + if (splits == 0 || s + sep_len > top) { + s = top; + break; + } else if (memcmp(s, sep_str, sep_len) == 0) { + break; + } + s++; + } + mp_obj_list_append(res, mp_obj_new_str(start, s - start, false)); + if (s >= top) { + break; + } + s += sep_len; + if (splits > 0) { + splits--; + } + } } return res; @@ -1052,7 +1086,7 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t } pfenv_print_int(&pfenv_vstr, arg_as_int(arg), 1, 16, 'A', flags, fill, width); break; - + default: nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "unsupported format character '%c' (0x%x) at index %d", @@ -1191,8 +1225,7 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) { STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, machine_int_t direction) { assert(MP_OBJ_IS_STR(self_in)); if (!MP_OBJ_IS_STR(arg)) { - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, - "Can't convert '%s' object to str implicitly", mp_obj_get_type_str(arg))); + bad_implicit_conversion(arg); } GET_STR_DATA_LEN(self_in, str, str_len); @@ -1365,8 +1398,7 @@ bool mp_obj_str_equal(mp_obj_t s1, mp_obj_t s2) { } } -void bad_implicit_conversion(mp_obj_t self_in) __attribute__((noreturn)); -void bad_implicit_conversion(mp_obj_t self_in) { +STATIC void bad_implicit_conversion(mp_obj_t self_in) { nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "Can't convert '%s' object to str implicitly", mp_obj_get_type_str(self_in))); } diff --git a/tests/basics/string_split.py b/tests/basics/string_split.py index f73cb4291e..398a115397 100644 --- a/tests/basics/string_split.py +++ b/tests/basics/string_split.py @@ -1,3 +1,4 @@ +# default separator (whitespace) print("a 
b".split()) print(" a b ".split(None)) print(" a b ".split(None, 1)) @@ -5,3 +6,23 @@ print(" a b ".split(None, 2)) print(" a b c ".split(None, 1)) print(" a b c ".split(None, 0)) print(" a b c ".split(None, -1)) + +# empty separator should fail +try: + "abc".split('') +except ValueError: + print("ValueError") + +# non-empty separator +print("abc".split("a")) +print("abc".split("b")) +print("abc".split("c")) +print("abc".split("z")) +print("abc".split("ab")) +print("abc".split("bc")) +print("abc".split("abc")) +print("abc".split("abcd")) +print("abcabc".split("bc")) +print("abcabc".split("bc", 0)) +print("abcabc".split("bc", 1)) +print("abcabc".split("bc", 2))