diff --git a/tools/mpremote/mpremote/commands.py b/tools/mpremote/mpremote/commands.py index aae612765b..b063ec587c 100644 --- a/tools/mpremote/mpremote/commands.py +++ b/tools/mpremote/mpremote/commands.py @@ -56,7 +56,11 @@ def do_connect(state, args=None): # Connect to the given device. if dev.startswith("port:"): dev = dev[len("port:") :] - state.transport = SerialTransport(dev, baudrate=115200) + opts = {} + if args and args.insecure_certificate: + import ssl + opts = {"sslopt": {"cert_reqs": ssl.CERT_NONE}} + state.transport = SerialTransport(dev, baudrate=115200, **opts) return except TransportError as er: msg = er.args[0] diff --git a/tools/mpremote/mpremote/main.py b/tools/mpremote/mpremote/main.py index eeb9cbd989..d3d2d95a13 100644 --- a/tools/mpremote/mpremote/main.py +++ b/tools/mpremote/mpremote/main.py @@ -97,6 +97,7 @@ def _bool_flag(cmd_parser, name, short_name, default, description): def argparse_connect(): cmd_parser = argparse.ArgumentParser(description="connect to given device") + _bool_flag(cmd_parser, "insecure-certificate", "k", False, "Don't check certificate validity.") cmd_parser.add_argument( "device", nargs=1, help="Either list, auto, id:x, port:x, or any valid device name/path" ) diff --git a/tools/mpremote/mpremote/transport_serial.py b/tools/mpremote/mpremote/transport_serial.py index 3b4cd00078..cfcc9bb302 100644 --- a/tools/mpremote/mpremote/transport_serial.py +++ b/tools/mpremote/mpremote/transport_serial.py @@ -57,9 +57,52 @@ def reraise_filesystem_error(e, info): raise FileNotFoundError(info) raise +class WebsockSerial: + def __init__(self, device, sslopt = None): + import websocket, urllib.parse + url_tup = urllib.parse.urlsplit(device) + port = url_tup.port + if port is None: + port = 8266 + hostname = url_tup.hostname + if ":" in hostname: + hostname = "[" + hostname + "]" + self.websock = websocket.create_connection(url_tup.scheme + "://" + hostname + ":" + str(port) + "/", sslopt=sslopt) + self.buf = b"" + if url_tup.password is not None: + self.write(url_tup.password.encode("utf8") + b"\r") + + @property + def fd(self): + return self.websock.sock.fileno() + + def close(self): + self.websock.close() + + def inWaiting(self): + while True: + import array, fcntl + buf = array.array('h', [0]) + FIONREAD = 0x541B + fcntl.ioctl(self.fd, FIONREAD, buf) + if buf[0] == 0: + break + self.buf += self.websock.recv() + return len(self.buf) + + def read(self, n): + while len(self.buf) < n: + self.buf += self.websock.recv() + out = self.buf[:n] + self.buf = self.buf[n:] + return out + + def write(self, buf): + self.websock.send(buf) + class SerialTransport(Transport): - def __init__(self, device, baudrate=115200, wait=0, exclusive=True): + def __init__(self, device, baudrate=115200, wait=0, exclusive=True, sslopt={}): self.in_raw_repl = False self.use_raw_paste = True self.device_name = device @@ -78,6 +121,8 @@ class SerialTransport(Transport): try: if device.startswith("rfc2217://"): self.serial = serial.serial_for_url(device, **serial_kwargs) + elif device.startswith("ws://") or device.startswith("wss://"): + self.serial = WebsockSerial(device, sslopt=sslopt) elif os.name == "nt": self.serial = serial.Serial(**serial_kwargs) self.serial.port = device @@ -93,6 +138,8 @@ class SerialTransport(Transport): self.serial = serial.Serial(device, **serial_kwargs) break except OSError: + import traceback + traceback.print_exc() if wait == 0: continue if attempt == 0: @@ -816,6 +863,8 @@ class RemoteFile(io.IOBase): c = self.cmd c.begin(CMD_WRITE) c.wr_s8(self.fd) + if self.is_text: + buf = bytes(buf, 'utf8') c.wr_bytes(buf) n = c.rd_s32() c.end() @@ -1204,7 +1253,8 @@ class SerialIntercept: self.orig_serial.close() def inWaiting(self): - self._check_input(False) + while self.orig_serial.inWaiting() > 0: + self._check_input(False) return len(self.buf) def read(self, n): diff --git a/tools/mpremote/requirements.txt b/tools/mpremote/requirements.txt index 6209cde5c3..74f58d72e8 100644 --- a/tools/mpremote/requirements.txt +++ b/tools/mpremote/requirements.txt @@ -1,2 +1,3 @@ pyserial >= 3.3 importlib_metadata >= 1.4; python_version < "3.8" +websocket-client >= 1.7.0