racket-ssh-2012/ssh-message-types.rkt

229 lines
7.3 KiB
Racket

#lang racket/base
(require "ssh-numbers.rkt")
(require (for-syntax racket/base))
(require (for-syntax (only-in racket/list append*)))
(require (for-syntax (only-in srfi/1 iota)))
(require (planet tonyg/bitsyntax))
(require racket/bytes)
(require rackunit)
(provide ssh-message-decode
ssh-message-encode)
(provide t:boolean
t:string
t:mpint
mpint-width
t:name-list)
(provide (struct-out ssh-msg-kexinit)
(struct-out ssh-msg-kexdh-init)
(struct-out ssh-msg-kexdh-reply)
(struct-out ssh-msg-disconnect)
(struct-out ssh-msg-unimplemented)
(struct-out ssh-msg-newkeys)
(struct-out ssh-msg-debug)
(struct-out ssh-msg-ignore))
;; Decoder registry: a mutable, eqv?-keyed hash from a message type byte to
;; a procedure that parses a whole packet (type byte included) into the
;; corresponding message struct.  Entries are added by
;; define-ssh-message-type and consulted by ssh-message-decode.
(define decoder-map (make-hasheqv))
;; Struct type property that carries a message type's encoder.
;; make-struct-type-property yields, in order, the property descriptor, a
;; predicate and an accessor.  ssh-message-encode fetches the property value
;; with the accessor and applies it to the message instance itself.
(define-values (prop:ssh-message-encoder ssh-message-encoder? ssh-message-encoder)
(make-struct-type-property 'ssh-message-encoder))
;; ssh-message-decode : bytes -> any
;; Dispatches on the first byte of PACKET (the message type code).  When a
;; decoder is registered for that code in decoder-map, the result of
;; applying it to the complete packet is returned; otherwise the result
;; is #f.
(define (ssh-message-decode packet)
  (let ((decode (hash-ref decoder-map (bytes-ref packet 0) #f)))
    ;; `and` yields #f for an unregistered type code and otherwise
    ;; tail-calls the registered decoder.
    (and decode (decode packet))))
;; ssh-message-encode : message-struct -> bytes
;; Looks up the encoder attached to M through prop:ssh-message-encoder,
;; applies it to M, and flattens the resulting bit-string into bytes.
(define (ssh-message-encode m)
  (let* ((encode (ssh-message-encoder m))
         (encoded (encode m)))
    (bit-string->bytes encoded)))
;; (define-ssh-message-type name type-byte-value (field-type field-name) ...)
;;
;; Declares one SSH message type:
;;   * defines a transparent struct NAME with the listed fields, carrying an
;;     encoder (built by compute-ssh-message-encoder) as its
;;     prop:ssh-message-encoder property value;
;;   * registers a decoder (built by compute-ssh-message-decoder) in
;;     decoder-map under TYPE-BYTE-VALUE.
;; Fields are encoded and decoded in the order they are listed.
(define-syntax-rule (define-ssh-message-type name type-byte-value
                      (field-type field-name) ...)
  (begin
    (struct name (field-name ...)
      #:transparent
      #:property prop:ssh-message-encoder
      (compute-ssh-message-encoder type-byte-value field-type ...))
    (hash-set! decoder-map
               type-byte-value
               (compute-ssh-message-decoder name type-byte-value field-type ...))))
;; t:boolean -- custom bit-syntax codec for a boolean field.
;;
;;   (t:boolean #t) expands to a parser taking the input plus success and
;;                  failure continuations.  It reads one default-sized
;;                  integer segment and reports any non-zero value as #t.
;;   (t:boolean #f) expands to a formatter that emits 1 for a true value
;;                  and 0 for #f.
(define-syntax t:boolean
  (syntax-rules ()
    ((_ #t)
     (lambda (src on-success on-failure)
       (bit-string-case src
         ([ flag (tail :: binary) ]
          (on-success (not (zero? flag)) tail))
         (else (on-failure)))))
    ((_ #f)
     (lambda (flag)
       (bit-string (if flag 1 0))))))
;; t:string -- custom bit-syntax codec for a length-prefixed byte string
;; (a 32-bit length followed by that many bytes).
;;
;;   (t:string #t #:pack) parser; like (t:string #t) but converts the body
;;                        to bytes with bit-string->bytes before handing it
;;                        to the success continuation.
;;   (t:string #t)        parser; passes the body on exactly as the binary
;;                        segment produced it.
;;   (t:string #f)        formatter; emits the byte length as a 32-bit
;;                        integer followed by the body.
(define-syntax t:string
  (syntax-rules ()
    ((_ #t #:pack)
     (lambda (src on-success on-failure)
       ((t:string #t) src
                      (lambda (payload tail)
                        (on-success (bit-string->bytes payload) tail))
                      on-failure)))
    ((_ #t)
     (lambda (src on-success on-failure)
       (bit-string-case src
         ([ (len :: integer bits 32) (payload :: binary bytes len) (tail :: binary) ]
          (on-success payload tail))
         (else (on-failure)))))
    ((_ #f)
     (lambda (payload)
       (bit-string ((bytes-length (bit-string->bytes payload)) :: integer bits 32)
                   (payload :: binary))))))
;; t:mpint -- custom bit-syntax codec for a multiple-precision integer:
;; a 32-bit byte count followed by that many bytes of integer data.
;;
;;   (t:mpint #t) parser; an empty body decodes to 0, anything else is
;;                converted with (bit-string->integer body #t #t).
;;   (t:mpint #f) formatter; the byte count comes from mpint-width and the
;;                body from integer->bit-string.
;; The round-trip checks in this file show the resulting layout:
;; most-significant byte first, two's complement for negative values.
(define-syntax t:mpint
  (syntax-rules ()
    ((_ #t)
     (lambda (src on-success on-failure)
       (bit-string-case src
         ([ (len :: integer bits 32) (digits :: binary bytes len) (tail :: binary) ]
          (on-success (if (zero? (bit-string-length digits))
                          0
                          (bit-string->integer digits #t #t))
                      tail))
         (else (on-failure)))))
    ((_ #f)
     (lambda (n)
       (let* ((octets (mpint-width n))
              (payload (integer->bit-string n (* 8 octets) #t)))
         (bit-string (octets :: integer bits 32)
                     (payload :: binary)))))))
;; t:name-list -- custom bit-syntax codec for a comma-separated list of
;; names, layered on top of the t:string codec.
;;
;;   (t:name-list #t) parser; reads a t:string and turns its body into a
;;                    list of symbols via name-list->symbols.
;;   (t:name-list #f) formatter; joins the symbols with symbols->name-list
;;                    and emits the result as a t:string.
(define-syntax t:name-list
  (syntax-rules ()
    ((_ #t)
     (lambda (src on-success on-failure)
       (let ((parse-string (t:string #t)))
         (parse-string src
                       (lambda (payload tail)
                         (on-success (name-list->symbols payload) tail))
                       on-failure))))
    ((_ #f)
     (lambda (names)
       (let ((format-string (t:string #f)))
         (format-string (symbols->name-list names)))))))
;; codec-options : syntax -> syntax
;; Compile-time helper.  Translates a field-type designator, as written in
;; define-ssh-message-type, into the list of bit-syntax segment options that
;; follows `::` in a bit-string / bit-string-case segment.
;;
;;   byte       -> (integer bits 8)
;;   (byte n)   -> (binary bytes n)
;;   boolean    -> ((t:boolean))
;;   uint32     -> (integer bits 32)
;;   uint64     -> (integer bits 64)
;;   string     -> ((t:string))
;;   mpint      -> ((t:mpint))
;;   name-list  -> ((t:name-list))
;;
;; Any other designator raises a syntax error that names the offending form
;; and lists the supported types, instead of a generic "bad syntax" failure.
(define-for-syntax (codec-options field-type)
  (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list)
    (byte #'(integer bits 8))
    ((byte n) #'(binary bytes n))
    (boolean #'((t:boolean)))
    (uint32 #'(integer bits 32))
    (uint64 #'(integer bits 64))
    (string #'((t:string)))
    (mpint #'((t:mpint)))
    (name-list #'((t:name-list)))
    ;; Fallback: report unsupported designators explicitly.
    (_ (raise-syntax-error
        'codec-options
        "unsupported SSH field type (expected byte, (byte n), boolean, uint32, uint64, string, mpint or name-list)"
        field-type))))
;; (compute-ssh-message-encoder type-byte-value field-type ...)
;;
;; Expands to a one-argument procedure that encodes a message struct as a
;; bit-string: first TYPE-BYTE-VALUE as a one-byte integer, then each field
;; in declaration order using the segment options that codec-options gives
;; for its type.  Field values are read positionally out of
;; (struct->vector message); slot 0 of that vector is not a field, so field
;; i (counting from 0) is taken from slot i+1.
(define-syntax (compute-ssh-message-encoder stx)
  (syntax-case stx ()
    ((_ type-byte-value field-type ...)
     (let* ((types (syntax->list #'(field-type ...)))
            (slots (iota (length types) 1)))
       (with-syntax (((field-spec ...)
                      (map (lambda (slot type)
                             #`((vector-ref vec #,slot) :: #,@(codec-options type)))
                           slots
                           types)))
         #'(lambda (message)
             (let ((vec (struct->vector message)))
               (bit-string (type-byte-value :: integer bytes 1)
                           field-spec ...))))))))
;; (compute-ssh-message-decoder struct-name type-byte-value field-type ...)
;;
;; Expands to a one-argument procedure that parses a packet with a single
;; bit-string-case clause: the leading segment must equal TYPE-BYTE-VALUE,
;; and each following segment is parsed with the options codec-options
;; gives for the corresponding field type.  The parsed values are passed,
;; in order, to the STRUCT-NAME constructor.  The clause list has no `else`
;; branch.
(define-syntax (compute-ssh-message-decoder stx)
  (syntax-case stx ()
    ((_ struct-name type-byte-value field-type ...)
     (let ((types (syntax->list #'(field-type ...))))
       (with-syntax (((temp-name ...) (generate-temporaries #'(field-type ...)))
                     (((codec-option ...) ...) (map codec-options types)))
         #'(lambda (packet)
             (bit-string-case packet
               ([ (= type-byte-value) (temp-name :: codec-option ...) ... ]
                (struct-name temp-name ...)))))))))
;; mpint-width : exact-integer -> exact-nonnegative-integer
;; Number of bytes used for the body of N when it is written as an mpint.
;; Zero takes no bytes.  Any other value takes one byte more than the count
;; of whole bytes in (integer-length n), which leaves room for a sign bit.
(define (mpint-width n)
  (cond
    ((zero? n) 0)
    (else (add1 (quotient (integer-length n) 8)))))
;; Sanity checks for mpint-width.  They are ordinary module-level
;; expressions, so they are evaluated whenever this module is instantiated.
;; Covered: zero, a large positive value, the values around #x7f/#x80 where
;; an extra leading byte becomes necessary, and two negative values.
(check-eqv? (mpint-width 0) 0)
(check-eqv? (mpint-width #x9a378f9b2e332a7) 8)
(check-eqv? (mpint-width #x7f) 1)
(check-eqv? (mpint-width #x80) 2)
(check-eqv? (mpint-width #x81) 2)
(check-eqv? (mpint-width #xff) 2)
(check-eqv? (mpint-width #x100) 2)
(check-eqv? (mpint-width #x101) 2)
(check-eqv? (mpint-width #x-1234) 2)
(check-eqv? (mpint-width #x-deadbeef) 5)
;; symbols->name-list : (listof symbol) -> bytes
;; Renders each symbol as UTF-8 bytes and joins them with commas.
;; An empty list produces the empty byte string.
(define (symbols->name-list syms)
  (define (symbol->utf8 sym)
    (string->bytes/utf-8 (symbol->string sym)))
  (bytes-join (map symbol->utf8 syms) #","))
;; name-list->symbols : bit-string -> (listof symbol)
;; Inverse of symbols->name-list.  A zero-length input gives the empty
;; list; otherwise the input is decoded as UTF-8, split at commas, and each
;; piece is turned into a symbol.
(define (name-list->symbols bs)
  (cond
    ((zero? (bit-string-length bs)) '())
    (else
     (let* ((raw (bit-string->bytes bs))
            (text (bytes->string/utf-8 raw)))
       (map string->symbol (regexp-split #rx"," text))))))
;; Test-only message type with a single mpint field and type byte 123.
;; It is built directly with compute-ssh-message-encoder rather than
;; define-ssh-message-type, so no decoder is registered in decoder-map.
(struct test-message (value)
#:transparent
#:property prop:ssh-message-encoder (compute-ssh-message-encoder 123 mpint))
;; Round-trip checks for the mpint codec, using test-message.  For each case
;; bidi-check verifies both directions: encoding the message must give the
;; expected bytes (type byte 123 followed by the given mpint encoding), and
;; decoding those bytes must give back an equal message.
(let ((test-decode (compute-ssh-message-decoder test-message 123 mpint)))
  ;; Returns 'ok when both directions agree, otherwise a list
  ;; (fail <actual-encoding> <actual-decoding>) for diagnosis.
  (define (bidi-check msg enc-without-type-tag)
    (let* ((expected-enc (bytes-append (bytes 123) enc-without-type-tag))
           (actual-enc (ssh-message-encode msg))
           (decoded (test-decode expected-enc)))
      (cond
        ((and (equal? actual-enc expected-enc)
              (equal? decoded msg))
         'ok)
        (else `(fail ,actual-enc ,decoded)))))
  (check-eqv? (bidi-check (test-message 0) (bytes 0 0 0 0)) 'ok)
  (check-eqv? (bidi-check (test-message #x9a378f9b2e332a7)
                          (bytes #x00 #x00 #x00 #x08
                                 #x09 #xa3 #x78 #xf9
                                 #xb2 #xe3 #x32 #xa7))
              'ok)
  (check-eqv? (bidi-check (test-message #x80)
                          (bytes #x00 #x00 #x00 #x02 #x00 #x80))
              'ok)
  (check-eqv? (bidi-check (test-message #x-1234)
                          (bytes #x00 #x00 #x00 #x02 #xed #xcc))
              'ok)
  (check-eqv? (bidi-check (test-message #x-deadbeef)
                          (bytes #x00 #x00 #x00 #x05
                                 #xff #x21 #x52 #x41 #x11))
              'ok))
;; SSH_MSG_KEXINIT: a 16-byte cookie, ten name-lists, a boolean and a
;; uint32, encoded and decoded in the order listed.  Name-list fields hold
;; lists of symbols on the Racket side.
(define-ssh-message-type ssh-msg-kexinit SSH_MSG_KEXINIT
((byte 16) cookie)
(name-list kex_algorithms)
(name-list server_host_key_algorithms)
(name-list encryption_algorithms_client_to_server)
(name-list encryption_algorithms_server_to_client)
(name-list mac_algorithms_client_to_server)
(name-list mac_algorithms_server_to_client)
(name-list compression_algorithms_client_to_server)
(name-list compression_algorithms_server_to_client)
(name-list languages_client_to_server)
(name-list languages_server_to_client)
(boolean first_kex_packet_follows)
(uint32 reserved))
;; SSH_MSG_KEXDH_INIT: a single mpint field, e.
(define-ssh-message-type ssh-msg-kexdh-init SSH_MSG_KEXDH_INIT
(mpint e))
;; SSH_MSG_KEXDH_REPLY: a string, an mpint and another string, in that
;; order.  The string fields use the plain t:string codec (no #:pack
;; option), so decoded values are whatever the binary segment yields.
(define-ssh-message-type ssh-msg-kexdh-reply SSH_MSG_KEXDH_REPLY
(string host-key)
(mpint f)
(string h-signature))
;; SSH_MSG_DISCONNECT: a uint32 reason code followed by two strings
;; (description and language tag).
(define-ssh-message-type ssh-msg-disconnect SSH_MSG_DISCONNECT
(uint32 reason-code)
(string description)
(string language-tag))
;; SSH_MSG_UNIMPLEMENTED: a single uint32 sequence number.
(define-ssh-message-type ssh-msg-unimplemented SSH_MSG_UNIMPLEMENTED
(uint32 sequence-number))
;; SSH_MSG_NEWKEYS: no fields; the encoded message is just the type byte.
(define-ssh-message-type ssh-msg-newkeys SSH_MSG_NEWKEYS)
;; SSH_MSG_DEBUG: a boolean flag followed by two strings (message text and
;; language tag).
(define-ssh-message-type ssh-msg-debug SSH_MSG_DEBUG
(boolean always-display?)
(string message)
(string language-tag))
;; SSH_MSG_IGNORE: a single string field, data.
(define-ssh-message-type ssh-msg-ignore SSH_MSG_IGNORE
(string data))