ws_client tests: Updated example test to use WebsSocket package

Added a new test for closing connection with close frames
pull/5919/head
David Cermak 2020-07-21 16:04:25 +02:00 zatwierdzone przez bot
rodzic 5e9f8b52e7
commit 44c553fd14
1 zmienionych plików z 38 dodań i 148 usunięć

Wyświetl plik

@ -3,12 +3,10 @@ from __future__ import unicode_literals
import re import re
import os import os
import socket import socket
import select
import hashlib
import base64
import queue
import random import random
import string import string
from SimpleWebSocketServer import SimpleWebSocketServer, WebSocket
from tiny_test_fw import Utility
from threading import Thread, Event from threading import Thread, Event
import ttfw_idf import ttfw_idf
@ -26,159 +24,45 @@ def get_my_ip():
return IP return IP
class TestEcho(WebSocket):
def handleMessage(self):
self.sendMessage(self.data)
print('Server sent: {}'.format(self.data))
def handleConnected(self):
print('Connection from: {}'.format(self.address))
def handleClose(self):
print('{} closed the connection'.format(self.address))
# Simple Websocket server for testing purposes # Simple Websocket server for testing purposes
class Websocket: class Websocket(object):
HEADER_LEN = 6
def send_data(self, data):
for nr, conn in self.server.connections.items():
conn.sendMessage(data)
def run(self):
self.server = SimpleWebSocketServer('', self.port, TestEcho)
while not self.exit_event.is_set():
self.server.serveonce()
def __init__(self, port): def __init__(self, port):
self.port = port self.port = port
self.socket = socket.socket() self.exit_event = Event()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.thread = Thread(target=self.run)
self.socket.settimeout(10.0) self.thread.start()
self.send_q = queue.Queue()
self.shutdown = Event()
def __enter__(self): def __enter__(self):
try:
self.socket.bind(('', self.port))
except socket.error as e:
print("Bind failed:{}".format(e))
raise
self.socket.listen(1)
self.server_thread = Thread(target=self.run_server)
self.server_thread.start()
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.shutdown.set() self.exit_event.set()
self.server_thread.join() self.thread.join(10)
self.socket.close() if self.thread.is_alive():
self.conn.close() Utility.console_log('Thread cannot be joined', 'orange')
def run_server(self):
self.conn, address = self.socket.accept() # accept new connection
self.socket.settimeout(10.0)
print("Connection from: {}".format(address))
self.establish_connection()
print("WS established")
# Handle connection until client closes it, will echo any data received and send data from send_q queue
self.handle_conn()
def establish_connection(self):
while not self.shutdown.is_set():
try:
# receive data stream. it won't accept data packet greater than 1024 bytes
data = self.conn.recv(1024).decode()
if not data:
# exit if data is not received
raise
if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
self.handshake(data)
return
except socket.error as err:
print("Unable to establish a websocket connection: {}".format(err))
raise
def handshake(self, data):
# Magic string from RFC
MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
headers = data.split("\r\n")
for header in headers:
if "Sec-WebSocket-Key" in header:
client_key = header.split()[1]
if client_key:
resp_key = client_key + MAGIC_STRING
resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest())
resp = "HTTP/1.1 101 Switching Protocols\r\n" + \
"Upgrade: websocket\r\n" + \
"Connection: Upgrade\r\n" + \
"Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode())
self.conn.send(resp.encode())
def handle_conn(self):
while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try:
if self.conn in r:
self.echo_data()
if not self.send_q.empty():
self._send_data_(self.send_q.get())
except socket.error as err:
print("Stopped echoing data: {}".format(err))
raise
def echo_data(self):
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if socket closed by peer
return
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
if not payload:
# exit if socket closed by peer
return
frame = header + payload
decoded_payload = self.decode_frame(frame)
print("Sending echo...")
self._send_data_(decoded_payload)
def _send_data_(self, data):
frame = self.encode_frame(data)
self.conn.send(frame)
def send_data(self, data):
self.send_q.put(data.encode())
def decode_frame(self, frame):
# Mask out MASK bit from payload length, this len is only valid for short messages (<126)
payload_len = ~(1 << 7) & frame[1]
mask = frame[2:self.HEADER_LEN]
encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len]
payload = bytearray()
for i in range(payload_len):
payload.append(encrypted_payload[i] ^ mask[i % 4])
return payload
def encode_frame(self, payload):
# Set FIN = 1 and OP_CODE = 1 (text)
header = (1 << 7) | (1 << 0)
frame = bytearray([header])
payload_len = len(payload)
# If payload len is longer than 125 then the next 16 bits are used to encode length
if payload_len > 125:
frame.append(126)
frame.append(payload_len >> 8)
frame.append(0xFF & payload_len)
else:
frame.append(payload_len)
frame += payload
return frame
def test_echo(dut): def test_echo(dut):
@ -188,6 +72,11 @@ def test_echo(dut):
print("All echos received") print("All echos received")
def test_close(dut):
code = dut.expect(re.compile(r"WEBSOCKET: Received closed message with code=(\d*)"), timeout=60)[0]
print("Received close frame with code {}".format(code))
def test_recv_long_msg(dut, websocket, msg_len, repeats): def test_recv_long_msg(dut, websocket, msg_len, repeats):
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len)) send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
@ -246,6 +135,7 @@ def test_examples_protocol_websocket(env, extra_data):
test_echo(dut1) test_echo(dut1)
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
test_recv_long_msg(dut1, ws, 2000, 3) test_recv_long_msg(dut1, ws, 2000, 3)
test_close(dut1)
else: else:
print("DUT connecting to {}".format(uri)) print("DUT connecting to {}".format(uri))