#lang racket/base
;; SSH wire-format message codec.
;;
;; Each message type is declared with `define-ssh-message-type`. That macro:
;;   - defines a transparent struct whose `prop:ssh-message-encoder` property
;;     serializes an instance (type byte first, then the fields in order), and
;;   - registers a decoder for the message's type byte in `decoder-map`.
;;
;; Supported field types are byte, (byte n), boolean, uint32, uint64, string,
;; mpint and name-list. The names mirror the data types of RFC 4251 section 5.

(require "ssh-numbers.rkt")
(require (for-syntax racket/base))
(require (for-syntax (only-in racket/list append*)))
(require (for-syntax (only-in srfi/1 iota)))
(require (planet tonyg/bitsyntax))
(require racket/bytes)
(require rackunit)

(provide ssh-message-decode
         ssh-message-encode)
(provide (struct-out ssh-msg-kexinit))

;; Maps a message type byte to a procedure that decodes a whole packet (bytes,
;; type byte first) into the corresponding message struct. It is filled in at
;; module load time by `define-ssh-message-type`.
(define decoder-map (make-hasheqv))

;; Struct-type property holding the procedure (message -> bit-string) that
;; serializes an instance of a message struct, including its type byte.
(define-values (prop:ssh-message-encoder ssh-message-encoder? ssh-message-encoder)
  (make-struct-type-property 'ssh-message-encoder))

;; ssh-message-decode : bytes -> message-struct
;; Dispatches on the first byte of `packet` to the decoder registered for that
;; type byte.
;; Raises exn:fail:contract when `packet` is not a non-empty byte string.
;; Raises exn:fail when no decoder is registered for the type byte.
;; NOTE(review): a payload that does not match the declared field layout gets
;; whatever `bit-string-case` does when no clause matches -- confirm.
(define (ssh-message-decode packet)
  (unless (and (bytes? packet) (positive? (bytes-length packet)))
    ;; Fail with a clear message instead of an opaque `bytes-ref` error.
    (raise-argument-error 'ssh-message-decode "non-empty bytes?" packet))
  (let* ((type-code (bytes-ref packet 0))
         (decode (hash-ref decoder-map
                           type-code
                           (lambda ()
                             (error 'ssh-message-decode
                                    "Unknown message packet type number ~v"
                                    type-code)))))
    (decode packet)))

;; ssh-message-encode : message-struct -> bytes
;; Serializes a message (type byte followed by its fields) to a byte string,
;; using the encoder attached to the message's struct type.
(define (ssh-message-encode m)
  (bit-string->bytes ((ssh-message-encoder m) m)))

;; (define-ssh-message-type name type-byte-value (field-type field-name) ...)
;;
;; Defines a transparent struct `name` with the given fields and attaches an
;; encoder for the listed field types. It also registers a decoder for
;; `type-byte-value` in `decoder-map`. Declaring the same type byte twice
;; replaces the earlier decoder (hash-set!).
(define-syntax define-ssh-message-type
  (syntax-rules ()
    ((_ name type-byte-value (field-type field-name) ...)
     (begin
       (struct name (field-name ...)
         #:transparent
         #:property prop:ssh-message-encoder
         (compute-ssh-message-encoder type-byte-value field-type ...))
       (hash-set! decoder-map
                  type-byte-value
                  (compute-ssh-message-decoder name type-byte-value field-type ...))))))

;; (compute-ssh-message-encoder type-byte-value field-type ...)
;;
;; Expands to a procedure (message -> bit-string). It emits `type-byte-value`
;; as one byte, followed by each struct field encoded per its field type.
(define-syntax compute-ssh-message-encoder
  (lambda (stx)
    ;; Template for a uint32 byte count followed by the bytes of `bytes-expr`.
    ;; Shared by the string and name-list field types.
    (define (length-prefixed bytes-expr)
      #`((let ((v #,bytes-expr))
           (bit-string ((bytes-length v) :: integer bits 32) (v :: binary)))
         :: binary))
    ;; Builds the bit-string segment for the field at vector position `index`
    ;; of the struct->vector result named by `vec`.
    (define (encoder-field index vec field-type)
      (define ref #`(vector-ref #,vec #,index))
      (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list)
        (byte ref)
        ((byte n) #`(#,ref :: binary bytes n))
        (boolean #`(if #,ref 1 0))
        (uint32 #`(#,ref :: integer bits 32))
        (uint64 #`(#,ref :: integer bits 64))
        (string (length-prefixed ref))
        ;; Minimal two's-complement big-endian payload, preceded by its length.
        (mpint #`((let* ((v #,ref)
                         (width (mpint-width v))
                         (buf (integer->bit-string v (* 8 width) #t)))
                    (bit-string (width :: integer bits 32) (buf :: binary)))
                  :: binary))
        ;; symbols->name-list returns bytes, so bytes-length is the byte count.
        (name-list (length-prefixed #`(symbols->name-list #,ref)))))
    (syntax-case stx ()
      ((_ type-byte-value field-type ...)
       (let* ((type-list (syntax->list #'(field-type ...)))
              ;; struct->vector yields #(struct:name f1 f2 ...), so field i
              ;; lives at vector index i, counting from 1.
              (field-specs (map (lambda (index type) (encoder-field index #'vec type))
                                (iota (length type-list) 1)
                                type-list)))
         #`(lambda (message)
             (let ((vec (struct->vector message)))
               (bit-string (type-byte-value :: integer bytes 1) #,@field-specs))))))))

;; (compute-ssh-message-decoder struct-name type-byte-value field-type ...)
;;
;; Expands to a procedure (packet -> struct-name instance). It matches a leading
;; byte equal to `type-byte-value`, then the fields in order, and builds the
;; struct from the (transformed) field values.
(define-syntax compute-ssh-message-decoder
  (lambda (stx)
    ;; Pattern segments for a uint32 byte count followed by that many bytes,
    ;; bound to `temp-name`. Shared by string, mpint and name-list.
    (define (length-prefixed temp-name)
      (let ((length-name (car (generate-temporaries (list temp-name)))))
        #`((#,length-name :: unsigned integer bits 32)
           (#,temp-name :: binary bytes #,length-name))))
    ;; List of bit-string-case pattern segments that bind one field.
    (define (field-extractor temp-name field-type)
      (syntax->list
       (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list)
         (byte #`(#,temp-name))
         ((byte n) #`((#,temp-name :: binary bytes n)))
         (boolean #`(#,temp-name))
         (uint32 #`((#,temp-name :: unsigned integer bits 32)))
         (uint64 #`((#,temp-name :: unsigned integer bits 64)))
         (string (length-prefixed temp-name))
         (mpint (length-prefixed temp-name))
         (name-list (length-prefixed temp-name)))))
    ;; Expression converting the raw matched value into the struct field value.
    (define (field-transformer temp-name field-type)
      (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list)
        ((byte n) #`(bit-string->bytes #,temp-name))
        (boolean #`(not (zero? #,temp-name)))
        (string #`(bit-string->bytes #,temp-name))
        ;; An empty mpint payload encodes zero.
        (mpint #`(if (zero? (bit-string-length #,temp-name))
                     0
                     (bit-string->integer #,temp-name #t #t)))
        (name-list #`(name-list->symbols #,temp-name))
        ;; byte, uint32, uint64: the matched integer is used as is.
        (_ temp-name)))
    (syntax-case stx ()
      ((_ struct-name type-byte-value field-type ...)
       (let* ((field-types (syntax->list #'(field-type ...)))
              (temp-names (generate-temporaries field-types)))
         #`(lambda (packet)
             (bit-string-case packet
               (((= type-byte-value)
                 #,@(append* (map field-extractor temp-names field-types)))
                (struct-name #,@(map field-transformer temp-names field-types))))))))))

;; mpint-width : exact-integer -> exact-nonnegative-integer
;; Byte count of the minimal two's-complement big-endian encoding of `n`, and
;; 0 when `n` is zero. integer-length excludes the sign bit, so one extra bit
;; is needed: ceiling((len + 1) / 8) = floor(len / 8) + 1.
(define (mpint-width n)
  (if (zero? n)
      0
      (+ 1 (quotient (integer-length n) 8))))

(check-eqv? (mpint-width 0) 0)
(check-eqv? (mpint-width #x9a378f9b2e332a7) 8)
(check-eqv? (mpint-width #x7f) 1)
(check-eqv? (mpint-width #x80) 2)
(check-eqv? (mpint-width #x81) 2)
(check-eqv? (mpint-width #xff) 2)
(check-eqv? (mpint-width #x100) 2)
(check-eqv? (mpint-width #x101) 2)
(check-eqv? (mpint-width #x-1234) 2)
(check-eqv? (mpint-width #x-deadbeef) 5)

;; symbols->name-list : (listof symbol) -> bytes
;; Joins the symbols' UTF-8 names with commas, e.g. '(a b) => #"a,b".
(define (symbols->name-list syms)
  (bytes-join (map (lambda (s) (string->bytes/utf-8 (symbol->string s))) syms)
              #","))

;; name-list->symbols : bit-string -> (listof symbol)
;; Inverse of symbols->name-list. It splits the UTF-8 text on commas, and
;; empty input yields '().
(define (name-list->symbols bs)
  (if (zero? (bit-string-length bs))
      '()
      (map string->symbol
           (regexp-split #rx"," (bytes->string/utf-8 (bit-string->bytes bs))))))

(check-equal? (symbols->name-list '()) #"")
(check-equal? (symbols->name-list '(a bc)) #"a,bc")
(check-equal? (name-list->symbols #"") '())
(check-equal? (name-list->symbols (symbols->name-list '(a b c))) '(a b c))

;; Local message type used only to exercise the mpint encoder/decoder pair.
(struct test-message (value)
  #:transparent
  #:property prop:ssh-message-encoder
  (compute-ssh-message-encoder 123 mpint))

(let ((test-decode (compute-ssh-message-decoder test-message 123 mpint)))
  ;; Checks that `msg` encodes to type byte 123 followed by
  ;; `enc-without-type-tag`, and that those bytes decode back to `msg`.
  (define (bidi-check msg enc-without-type-tag)
    (let ((enc (bytes-append (bytes 123) enc-without-type-tag)))
      (let ((msg-enc (ssh-message-encode msg))
            (enc-msg (test-decode enc)))
        (if (and (equal? msg-enc enc) (equal? enc-msg msg))
            'ok
            `(fail ,msg-enc ,enc-msg)))))
  (check-eqv? (bidi-check (test-message 0) (bytes 0 0 0 0)) 'ok)
  (check-eqv? (bidi-check (test-message #x9a378f9b2e332a7)
                          (bytes #x00 #x00 #x00 #x08
                                 #x09 #xa3 #x78 #xf9 #xb2 #xe3 #x32 #xa7))
              'ok)
  (check-eqv? (bidi-check (test-message #x80)
                          (bytes #x00 #x00 #x00 #x02 #x00 #x80))
              'ok)
  (check-eqv? (bidi-check (test-message #x-1234)
                          (bytes #x00 #x00 #x00 #x02 #xed #xcc))
              'ok)
  (check-eqv? (bidi-check (test-message #x-deadbeef)
                          (bytes #x00 #x00 #x00 #x05 #xff #x21 #x52 #x41 #x11))
              'ok))

;; Key-exchange initialization message. The field order below is the wire order.
(define-ssh-message-type ssh-msg-kexinit SSH_MSG_KEXINIT
  ((byte 16) cookie)
  (name-list kex_algorithms)
  (name-list server_host_key_algorithms)
  (name-list encryption_algorithms_client_to_server)
  (name-list encryption_algorithms_server_to_client)
  (name-list mac_algorithms_client_to_server)
  (name-list mac_algorithms_server_to_client)
  (name-list compression_algorithms_client_to_server)
  (name-list compression_algorithms_server_to_client)
  (name-list languages_client_to_server)
  (name-list languages_server_to_client)
  (boolean first_kex_packet_follows)
  (uint32 reserved))

;; An empty packet is rejected with a contract error, not an opaque bytes-ref one.
(check-exn exn:fail:contract? (lambda () (ssh-message-decode #"")))