diff --git a/micropython/usb/usb-device-cdc/usb/device/cdc.py b/micropython/usb/usb-device-cdc/usb/device/cdc.py index 46509ad0..4abeffc3 100644 --- a/micropython/usb/usb-device-cdc/usb/device/cdc.py +++ b/micropython/usb/usb-device-cdc/usb/device/cdc.py @@ -289,38 +289,32 @@ class CDCInterface(io.IOBase, Interface): bmRequestType, bRequest, wValue, wIndex, wLength = struct.unpack("BBHHH", request) recipient, req_type, req_dir = split_bmRequestType(bmRequestType) - # Only for the control interface if wIndex != self._c_itf: - return False + return False # Only for the control interface (may be redundant check?) + + if req_type != _REQ_TYPE_CLASS: + return False # Unsupported request type if stage == _STAGE_SETUP: - if req_type == _REQ_TYPE_CLASS: - if bRequest == _SET_LINE_CODING_REQ: - if wLength == len(self._line_coding): - return self._line_coding - return False # wrong length - elif bRequest == _GET_LINE_CODING_REQ: - return self._line_coding - elif bRequest == _SET_CONTROL_LINE_STATE: - if wLength == 0: - self._line_state = wValue - if self.line_state_cb: - self.line_state_cb(wValue) - return b"" - else: - return False # wrong length - elif bRequest == _SEND_BREAK_REQ: - if self.break_cb: - self.break_cb(wValue) - return b"" + if bRequest in (_SET_LINE_CODING_REQ, _GET_LINE_CODING_REQ): + return self._line_coding # Buffer to read or write - if stage == _STAGE_DATA: - if req_type == _REQ_TYPE_CLASS: - if bRequest == _SET_LINE_CODING_REQ: - if self.line_coding_cb: - self.line_coding_cb(self._line_coding) + # Continue on other supported requests, stall otherwise + return bRequest in (_SET_CONTROL_LINE_STATE, _SEND_BREAK_REQ) - return True + if stage == _STAGE_ACK: + if bRequest == _SET_LINE_CODING_REQ: + if self.line_coding_cb: + self.line_coding_cb(self._line_coding) + elif bRequest == _SET_CONTROL_LINE_STATE: + self._line_state = wValue + if self.line_state_cb: + self.line_state_cb(wValue) + elif bRequest == _SEND_BREAK_REQ: + if self.break_cb: + self.break_cb(wValue) + + return True # allow DATA/ACK stages to complete normally def _wr_xfer(self): # Submit a new data IN transfer from the _wb buffer, if needed diff --git a/micropython/usb/usb-device-hid/usb/device/hid.py b/micropython/usb/usb-device-hid/usb/device/hid.py index 760e8b32..9d74b39b 100644 --- a/micropython/usb/usb-device-hid/usb/device/hid.py +++ b/micropython/usb/usb-device-hid/usb/device/hid.py @@ -206,37 +206,25 @@ class HIDInterface(Interface): return bytes([self.idle_rate]) if bRequest == _REQ_CONTROL_GET_PROTOCOL: return bytes([self.protocol]) + if bRequest in (_REQ_CONTROL_SET_IDLE, _REQ_CONTROL_SET_PROTOCOL): + return True + if bRequest == _REQ_CONTROL_SET_REPORT: + return self._set_report_buf # If None, request will stall + return False # Unsupported request + + if stage == _STAGE_ACK: + if req_type == _REQ_TYPE_CLASS: if bRequest == _REQ_CONTROL_SET_IDLE: self.idle_rate = wValue >> 8 - return b"" - if bRequest == _REQ_CONTROL_SET_PROTOCOL: + elif bRequest == _REQ_CONTROL_SET_PROTOCOL: self.protocol = wValue - return b"" - if bRequest == _REQ_CONTROL_SET_REPORT: - # Return the _set_report_buf to be filled with the - # report data - if not self._set_report_buf: - return False - elif wLength >= len(self._set_report_buf): - # Saves an allocation if the size is exactly right (or will be a short read) - return self._set_report_buf - else: - # Otherwise, need to wrap the buffer in a memoryview of the correct length - # - # TODO: check this is correct, maybe TinyUSB won't mind if we ask for more - # bytes than the host has offered us. - return memoryview(self._set_report_buf)[:wLength] - return False # Unsupported - - if stage == _STAGE_DATA: - if req_type == _REQ_TYPE_CLASS: - if bRequest == _REQ_CONTROL_SET_REPORT and self._set_report_buf: + elif bRequest == _REQ_CONTROL_SET_REPORT: report_id = wValue & 0xFF report_type = wValue >> 8 report_data = self._set_report_buf if wLength < len(report_data): - # as above, need to truncate the buffer if we read less - # bytes than what was provided + # need to truncate the response in the callback if we got less bytes + # than allowed for in the buffer report_data = memoryview(self._set_report_buf)[:wLength] self.on_set_report(report_data, report_id, report_type)