# Expand source code
import errno
import hashlib
import itertools
import os
import random
import socket
from socket import socket as Socket
import ssl
import struct
from base64 import encodebytes, b64encode
from hmac import compare_digest
from logging import Logger
from threading import Lock
from typing import Tuple, Optional, Union, List, Callable, Dict
from urllib.parse import urlparse, unquote
from .frame_header import FrameHeader
def _parse_connect_response(sock: Socket) -> Tuple[Optional[int], str]:
    """Read the reply to a proxy CONNECT request up to the first blank line.

    The reply is consumed one byte at a time, so nothing past the blank line
    that ends the header section is taken off the socket.

    Args:
        sock: The socket the CONNECT request was sent on.

    Returns:
        (status code parsed from the first line, every non-empty line joined
        with a newline)

    Raises:
        ConnectionError: If the peer closes the connection before the blank line.
    """
    status = None
    collected = []
    while True:
        # Accumulate exactly one line, terminator included.
        raw = bytearray()
        while True:
            chunk = sock.recv(1)
            if not chunk:
                raise ConnectionError("Connection is closed")
            raw += chunk
            if chunk == b"\n":
                break
        text = bytes(raw).decode("utf-8").strip()
        if not text:
            # A blank line marks the end of the response headers.
            break
        collected.append(text)
        if not status:
            # Status line looks like "HTTP/1.x <code> <reason>".
            status = int(text.split(" ", 2)[1])
    return status, "\n".join(collected)
def _use_or_create_ssl_context(ssl_context: Optional[ssl.SSLContext] = None):
    """Return the given SSL context, or a new default one when none is supplied."""
    if ssl_context is None:
        return ssl.create_default_context()
    return ssl_context
def _establish_new_socket_connection(
    session_id: str,
    server_hostname: str,
    server_port: int,
    logger: Logger,
    sock_send_lock: Lock,
    receive_timeout: float,
    proxy: Optional[str],
    proxy_headers: Optional[Dict[str, str]],
    trace_enabled: bool,
    ssl_context: Optional[ssl.SSLContext] = None,
) -> Union[ssl.SSLSocket, Socket]:
    """Open the socket for a session, optionally tunnelling through an HTTP proxy.

    Three paths are possible:

    * ``proxy`` is given: connect to the proxy, send a ``CONNECT`` request,
      require a 200 reply, then wrap the tunnel in TLS.
    * no proxy and ``server_port != 443``: return a plain (non-TLS) socket with
      a fixed 3-second timeout; intended for library testing only.
    * otherwise: connect directly with ``receive_timeout`` and wrap in TLS.

    Args:
        session_id: Identifier included in trace log messages.
        server_hostname: Host to reach (also used for TLS server name checks).
        server_port: Port to reach.
        logger: Logger for trace/info output.
        sock_send_lock: Lock held while the CONNECT request is written.
        receive_timeout: Timeout passed to ``socket.create_connection``.
        proxy: Proxy URL, e.g. ``http://user:pass@host:port``; ``None`` for a
            direct connection. The proxy port defaults to 80.
        proxy_headers: Extra headers appended to the CONNECT request.
        trace_enabled: When true, the CONNECT request/response are debug-logged.
        ssl_context: TLS context to use; a default one is created when ``None``.

    Returns:
        The connected socket (TLS-wrapped except on the testing path).

    Raises:
        Exception: If the proxy answers the CONNECT request with a non-200 status.
    """
    ssl_context = _use_or_create_ssl_context(ssl_context)

    if proxy is not None:
        parsed_proxy = urlparse(proxy)
        proxy_host, proxy_port = parsed_proxy.hostname, parsed_proxy.port or 80
        sock = socket.create_connection((proxy_host, proxy_port), receive_timeout)
        try:
            if hasattr(socket, "TCP_NODELAY"):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            if hasattr(socket, "SO_KEEPALIVE"):
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            message = [f"CONNECT {server_hostname}:{server_port} HTTP/1.0"]
            if parsed_proxy.username is not None and parsed_proxy.password is not None:
                # In the case where the proxy is "http://{username}:{password}@{hostname}:{port}"
                raw_value = f"{unquote(parsed_proxy.username)}:{unquote(parsed_proxy.password)}"
                auth = b64encode(raw_value.encode("utf-8")).decode("ascii")
                message.append(f"Proxy-Authorization: Basic {auth}")
            if proxy_headers is not None:
                for k, v in proxy_headers.items():
                    message.append(f"{k}: {v}")
            # Two empty entries produce the blank line that terminates the request.
            message.append("")
            message.append("")
            req: str = "\r\n".join([line.lstrip() for line in message])
            if trace_enabled:
                # NOTE(review): this logs the Proxy-Authorization header as-is.
                logger.debug(f"Proxy connect request (session id: {session_id}):\n{req}")
            with sock_send_lock:
                # sendall() keeps writing until the whole request is sent;
                # send() may stop after a partial write.
                sock.sendall(req.encode("utf-8"))
            status, text = _parse_connect_response(sock)
            if trace_enabled:
                log_message = f"Proxy connect response (session id: {session_id}):\n{text}"
                logger.debug(log_message)
            if status != 200:
                raise Exception(f"Failed to connect to the proxy (proxy: {proxy}, connect status code: {status})")

            return ssl_context.wrap_socket(
                sock,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                server_hostname=server_hostname,
            )
        except BaseException:
            # Do not leak the TCP connection when the tunnel cannot be set up.
            sock.close()
            raise

    if server_port != 443:
        # only for library testing
        logger.info(f"Using non-ssl socket to connect ({server_hostname}:{server_port})")
        return socket.create_connection((server_hostname, server_port), timeout=3)

    sock = socket.create_connection((server_hostname, server_port), receive_timeout)
    try:
        return ssl_context.wrap_socket(
            sock,
            do_handshake_on_connect=True,
            suppress_ragged_eofs=True,
            server_hostname=server_hostname,
        )
    except BaseException:
        # Do not leak the TCP connection when the TLS handshake fails.
        sock.close()
        raise
def _read_http_response_line(sock: ssl.SSLSocket) -> str:
    """Read one line of an HTTP response, one byte at a time.

    Reading stops at the first carriage return. The line feed that follows it
    is left on the socket and skipped at the start of the next call, because
    line feeds are never added to the result.

    The bytes are collected first and decoded once at the end, so multi-byte
    UTF-8 characters in a header value are decoded correctly instead of
    failing on their first byte.

    Args:
        sock: The socket to read from.

    Returns:
        The line without its CR/LF terminator.

    Raises:
        ConnectionError: If the peer closes the connection before a carriage return.
        UnicodeDecodeError: If the line is not valid UTF-8.
    """
    collected = bytearray()
    while True:
        b: bytes = sock.recv(1)
        if not b:
            raise ConnectionError("Connection is closed")
        if b == b"\r":
            break
        if b != b"\n":
            collected += b
    return collected.decode("utf-8")
def _parse_handshake_response(sock: ssl.SSLSocket) -> Tuple[Optional[int], dict, str]:
    """Parses the handshake response.

    Lines are read until the first blank line. The first line made of more
    than two space-separated parts is taken as the status line; every later
    line containing a colon is stored as a header, with the name lower-cased
    and both name and value stripped. When a header name repeats, the last
    value wins.

    Args:
        sock: The current active socket
    Returns:
        (http status, headers, whole response as a str)
    """
    lines = []
    status = None
    headers = {}
    while True:
        line = _read_http_response_line(sock)
        if status is None:
            elements = line.split(" ")
            if len(elements) > 2:
                status = int(elements[1])
        else:
            # Split on the FIRST colon only, so values that contain colons
            # themselves (dates, URLs, host:port) are kept rather than dropped.
            name, separator, value = line.partition(":")
            if separator:
                headers[name.strip().lower()] = value.strip()
        if len(line.strip()) == 0:
            break
        lines.append(line)
    text = "\n".join(lines)
    return (status, headers, text)
def _generate_sec_websocket_key() -> str:
    """Return 16 random bytes from ``os.urandom`` encoded as a base64 string."""
    nonce = os.urandom(16)
    # 16 bytes encode to a single 24-character line, so there is no newline to trim.
    return b64encode(nonce).decode("utf-8")
def _validate_sec_websocket_accept(sec_websocket_key: str, headers: dict) -> bool:
    """Check the server's ``Sec-WebSocket-Accept`` value against the key we sent.

    The expected value is the base64-encoded SHA-1 digest of the key
    concatenated with the fixed GUID below. The comparison is done on bytes
    with ``compare_digest``, so a header value containing non-ASCII characters
    yields ``False`` instead of raising ``TypeError``.

    Args:
        sec_websocket_key: The ``Sec-WebSocket-Key`` sent in the request.
        headers: Response headers keyed by lower-cased name.

    Returns:
        True only when the header is present and matches the expected value.
    """
    v = (sec_websocket_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("utf-8")
    expected = b64encode(hashlib.sha1(v).digest())
    actual = headers.get("sec-websocket-accept", "")
    if not isinstance(actual, str):
        # A missing/None or otherwise malformed value can never be valid.
        return False
    return compare_digest(expected, actual.strip().encode("utf-8"))
def _to_readable_opcode(opcode: int) -> str:
    """Return a lowercase label for a frame opcode, or "-" when it matches no known one."""
    # Ordered pairs so the first matching constant wins.
    labels = (
        (FrameHeader.OPCODE_CONTINUATION, "continuation"),
        (FrameHeader.OPCODE_TEXT, "text"),
        (FrameHeader.OPCODE_BINARY, "binary"),
        (FrameHeader.OPCODE_CLOSE, "close"),
        (FrameHeader.OPCODE_PING, "ping"),
        (FrameHeader.OPCODE_PONG, "pong"),
    )
    return next((label for known, label in labels if opcode == known), "-")
def _parse_text_payload(data: Optional[bytes], logger: Logger) -> str:
    """Decode a payload as UTF-8 text.

    Args:
        data: The raw payload; anything that is not a ``bytes`` instance
            (including ``None``) is treated as empty.
        logger: Receives a debug message when decoding fails.

    Returns:
        The decoded text, or an empty string when there is nothing to decode
        or the bytes are not valid UTF-8.
    """
    if not isinstance(data, bytes):
        return ""
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError as e:
        logger.debug(f"Failed to parse a payload (data: {data}, error: {e})")
        return ""
def _receive_messages(
    sock: ssl.SSLSocket,
    sock_receive_lock: Lock,
    logger: Logger,
    receive_buffer_size: int = 1024,
    all_message_trace_enabled: bool = False,
) -> List[Tuple[Optional[FrameHeader], bytes]]:
    """Read from the socket and return the frames parsed from the received bytes.

    Args:
        sock: The socket to read from.
        sock_receive_lock: Lock held around every ``recv`` call.
        logger: Logger for debug output.
        receive_buffer_size: Default number of bytes requested per ``recv``.
        all_message_trace_enabled: When true, every non-empty chunk is debug-logged.

    Returns:
        The list of (header, payload) tuples produced by ``_fetch_messages``.

    Raises:
        OSError: Any socket error other than EBADF / ENOTSOCK.
    """
    # For Linux/macOS, errno.EBADF is the expected error for bad connections.
    # The errno.ENOTSOCK can be sent when running on Windows OS.
    closed_connection_errnos = (errno.EBADF, errno.ENOTSOCK)

    def receive(specific_buffer_size: Optional[int] = None):
        # An explicit size overrides the default buffer size for this one read.
        buffer_size = receive_buffer_size if specific_buffer_size is None else specific_buffer_size
        with sock_receive_lock:
            try:
                chunk = sock.recv(buffer_size)
                if all_message_trace_enabled and len(chunk) > 0:
                    logger.debug(f"Received bytes: {chunk}")
                return chunk
            except OSError as e:
                if e.errno not in closed_connection_errnos:
                    raise e
                # Note that bad connections can be detected by monitoring threads
                # the Socket Mode client automatically reconnects to a new endpoint later.
                logger.debug("The connection seems to be already closed.")
                return bytes()

    # Start from a clean parser state: nothing buffered, no frame in progress.
    return _fetch_messages(
        messages=[],
        receive=receive,
        logger=logger,
        remaining_bytes=None,
        current_mask_key=None,
        current_header=None,
        current_data=bytes(),
    )
def _unmask_payload(payload: Optional[bytes], mask_key: Optional[bytes]) -> Optional[bytes]:
    """XOR the payload with the repeating masking key (RFC 6455, section 5.3)."""
    if not payload or not mask_key:
        return payload
    return bytes(b ^ m for b, m in zip(payload, itertools.cycle(mask_key)))


def _fetch_messages(
    messages: List[Tuple[Optional[FrameHeader], bytes]],
    receive: Callable[[Optional[int]], bytes],  # buffer size
    logger: Logger,
    remaining_bytes: Optional[bytes] = None,
    current_mask_key: Optional[bytes] = None,
    current_header: Optional[FrameHeader] = None,
    current_data: Optional[bytes] = None,
) -> List[Tuple[Optional[FrameHeader], bytes]]:
    """Parse WebSocket frames out of the received bytes and append them to ``messages``.

    The parser is an explicit loop rather than one recursive call per frame or
    chunk, so a large payload delivered in many small reads cannot exhaust the
    interpreter's recursion limit.

    Behavior:

    * Bytes are taken from ``remaining_bytes`` when given, otherwise from
      ``receive()``.
    * An empty read ends parsing. A frame still in progress at that point is
      appended with whatever payload has been collected so far.
    * A buffer whose first byte is a line feed (10) is not parsed as a frame:
      ``(None, current_data)`` is appended when ``current_data`` is not None,
      then ``(None, b"\\n")``, and parsing continues after that byte.
    * When the payload of a frame is incomplete, more bytes are requested
      until ``header.length`` bytes have been collected.
    * Masked payloads are collected raw and unmasked when the frame is
      appended, so frames split across several reads are unmasked correctly.
    * If the peer stops sending in the middle of a frame header, parsing stops
      and the messages collected so far are returned.

    Args:
        messages: List that parsed ``(header, payload)`` tuples are appended to.
        receive: Callable returning the next chunk of bytes (empty when there
            is nothing more to read).
        logger: Not used by the parser itself; accepted for interface
            compatibility.
        remaining_bytes: Bytes already read but not yet parsed.
        current_mask_key: Masking key of the frame in progress, if any.
        current_header: Header of the frame in progress, if any.
        current_data: Payload collected so far for the frame in progress.

    Returns:
        The same ``messages`` list.
    """

    def read_until(buffer: bytes, needed: int) -> bytes:
        # Keep reading until ``needed`` bytes are buffered or the peer has nothing more.
        while len(buffer) < needed:
            chunk = receive()  # type: ignore
            if not chunk:
                break
            buffer += chunk
        return buffer

    def emit(header: Optional[FrameHeader], payload: Optional[bytes], mask_key: Optional[bytes]) -> None:
        if header is not None and header.masked > 0:
            payload = _unmask_payload(payload, mask_key)
        _append_message(messages, header, payload)  # type: ignore

    while True:
        if remaining_bytes is None:
            # Fetch more to complete the current message
            remaining_bytes = receive()  # type: ignore

        if remaining_bytes is None or len(remaining_bytes) == 0:
            # no more bytes
            if current_header is not None:
                emit(current_header, current_data, current_mask_key)
            return messages

        if current_header is not None:
            # work in progress with the current_header/current_data
            if current_data is None:
                current_data = bytes()
            length_needed = current_header.length - len(current_data)
            if length_needed > len(remaining_bytes):
                # need more bytes to complete this message
                current_data += remaining_bytes
                remaining_bytes = None
                continue
            current_data += remaining_bytes[:length_needed]
            emit(current_header, current_data, current_mask_key)
            remaining_bytes = remaining_bytes[length_needed:]
            if len(remaining_bytes) == 0:
                return messages
            # continue with the remaining data as a brand-new frame
            current_header = current_mask_key = current_data = None
            continue

        # new message: at least the two fixed header bytes are required
        remaining_bytes = read_until(remaining_bytes, 2)

        if remaining_bytes[0] == 10:  # \n
            if current_data is not None:
                _append_message(messages, None, current_data)
            _append_message(messages, None, remaining_bytes[:1])
            remaining_bytes = remaining_bytes[1:]
            if len(remaining_bytes) == 0:
                return messages
            current_mask_key = current_data = None
            continue

        if len(remaining_bytes) < 2:
            # The peer stopped sending in the middle of a frame header.
            return messages

        # https://tools.ietf.org/html/rfc6455#section-5.2
        b1, b2 = remaining_bytes[0], remaining_bytes[1]

        # determine data length and the first index of the data part
        current_data_length: int = b2 & 0b01111111
        idx_after_length_part: int = 2
        if current_data_length == 126:
            # 16-bit extended payload length
            remaining_bytes = read_until(remaining_bytes, 4)
            if len(remaining_bytes) < 4:
                return messages
            current_data_length = struct.unpack("!H", bytes(remaining_bytes[2:4]))[0]
            idx_after_length_part = 4
        elif current_data_length == 127:
            # 64-bit extended payload length
            remaining_bytes = read_until(remaining_bytes, 10)
            if len(remaining_bytes) < 10:
                return messages
            current_data_length = struct.unpack("!Q", bytes(remaining_bytes[2:10]))[0]
            idx_after_length_part = 10

        current_header = FrameHeader(
            fin=b1 & 0b10000000,
            rsv1=b1 & 0b01000000,
            rsv2=b1 & 0b00100000,
            rsv3=b1 & 0b00010000,
            opcode=b1 & 0b00001111,
            masked=b2 & 0b10000000,
            length=current_data_length,
        )
        if current_header.masked > 0:
            if current_mask_key is None:
                remaining_bytes = read_until(remaining_bytes, idx_after_length_part + 4)
                if len(remaining_bytes) < idx_after_length_part + 4:
                    return messages
                idx1, idx2 = idx_after_length_part, idx_after_length_part + 4
                current_mask_key = bytes(remaining_bytes[idx1:idx2])
            idx_after_length_part += 4

        start, end = idx_after_length_part, idx_after_length_part + current_data_length
        # The slice can never be longer than the declared length, so only the
        # "complete" and "need more" cases exist.
        current_data = bytes(remaining_bytes[start:end])

        if len(current_data) == current_data_length:
            emit(current_header, current_data, current_mask_key)
            remaining_bytes = remaining_bytes[end:]
            if len(remaining_bytes) == 0:
                return messages
            # continue with the remaining data
            current_header = current_mask_key = current_data = None
            continue

        # need more bytes to complete this message
        remaining_bytes = None
def _append_message(
    messages: List[Tuple[Optional[FrameHeader], bytes]],
    header: Optional[FrameHeader],
    data: bytes,
) -> None:
    """Append one ``(header, data)`` tuple to the end of ``messages`` in place."""
    entry = (header, data)
    messages.append(entry)
def _build_data_frame_for_sending(
    payload: Union[str, bytes],
    opcode: int,
    fin: int = 1,
    rsv1: int = 0,
    rsv2: int = 0,
    rsv3: int = 0,
    masked: int = 1,
):
    """Serialize a single WebSocket frame (RFC 6455, section 5.2).

    Args:
        payload: Frame payload; a ``str`` is encoded as UTF-8.
        opcode: Frame opcode (low four bits of the first byte).
        fin: FIN bit.
        rsv1: RSV1 bit.
        rsv2: RSV2 bit.
        rsv3: RSV3 bit.
        masked: MASK bit. When set, a fresh 4-byte masking key is generated and
            the payload is XOR-masked with it; when 0, neither a key nor any
            masking is applied.

    Returns:
        The frame as bytes: header, optional masking key, then payload.
    """
    b1 = fin << 7 | rsv1 << 6 | rsv2 << 5 | rsv3 << 4 | opcode
    payload_bytes: bytes = payload.encode("utf-8") if isinstance(payload, str) else payload
    payload_length = len(payload_bytes)
    mask_bit = masked << 7

    # The 7-bit length field holds the length itself (<= 125), or a marker
    # announcing a 16-bit (126) or 64-bit (127) extended length.
    if payload_length <= 125:
        header: bytes = bytes([b1, mask_bit | payload_length])
    elif payload_length <= 0xFFFF:
        header = struct.pack("!BBH", b1, mask_bit | 126, payload_length)
    else:
        header = struct.pack("!BBQ", b1, mask_bit | 127, payload_length)

    if not masked:
        return header + bytes(payload_bytes)

    # The masking key must be unpredictable, so take it from the OS entropy source.
    mask_key: bytes = os.urandom(4)
    masked_payload: bytes = bytes(byte ^ mask for byte, mask in zip(payload_bytes, itertools.cycle(mask_key)))
    return header + mask_key + masked_payload