diff --git a/ssh-message-types.rkt b/ssh-message-types.rkt index ff67532..5c8e58a 100644 --- a/ssh-message-types.rkt +++ b/ssh-message-types.rkt @@ -14,6 +14,8 @@ (provide ssh-message-decode ssh-message-encode) +(provide (struct-out ssh-msg-kexinit)) + (define decoder-map (make-hasheqv)) (define-values (prop:ssh-message-encoder ssh-message-encoder? ssh-message-encoder) @@ -29,13 +31,12 @@ packet))) (define (ssh-message-encode m) - ((ssh-message-encoder m) m)) + (bit-string->bytes ((ssh-message-encoder m) m))) (define-syntax define-ssh-message-type (syntax-rules () ((_ name type-byte-value (field-type field-name) ...) (begin - (provide (struct-out name)) (struct name (field-name ...) #:transparent #:property prop:ssh-message-encoder @@ -95,16 +96,16 @@ (boolean #`(#,temp-name)) (uint32 - #`((#,temp-name :: integer bits 32))) + #`((#,temp-name :: unsigned integer bits 32))) (uint64 - #`((#,temp-name :: integer bits 64))) + #`((#,temp-name :: unsigned integer bits 64))) (string (let ((length-name (car (generate-temporaries (list temp-name))))) #`((#,length-name :: integer bits 32) (#,temp-name :: binary bytes #,length-name)))) (mpint (let ((length-name (car (generate-temporaries (list temp-name))))) - #`((#,@length-name :: integer bits 32) + #`((#,length-name :: integer bits 32) (#,temp-name :: binary bytes #,length-name)))) (name-list (let ((length-name (car (generate-temporaries (list temp-name))))) @@ -115,7 +116,9 @@ ((byte n) #`(bit-string->bytes #,temp-name)) (boolean #`(not (zero? #,temp-name))) (string #`(bit-string->bytes #,temp-name)) - (mpint #`(bit-string->integer #,temp-name)) + (mpint #`(if (zero? (bit-string-length #,temp-name)) + 0 + (bit-string->integer #,temp-name #t #t))) (name-list #`(name-list->symbols #,temp-name)) (else temp-name))) (syntax-case stx () @@ -124,8 +127,8 @@ (temp-names (generate-temporaries field-types))) #`(lambda (packet) (bit-string-case packet - (( type-byte-value - #,@(append* (map field-extractor temp-names field-types))) + (( (= type-byte-value) + #,@(append* (map field-extractor temp-names field-types)) ) (struct-name #,@(map field-transformer temp-names field-types)))))))))) (define (mpint-width n) @@ -152,6 +155,31 @@ '() (map string->symbol (regexp-split #rx"," (bytes->string/utf-8 (bit-string->bytes bs)))))) +(struct test-message (value) + #:transparent + #:property prop:ssh-message-encoder (compute-ssh-message-encoder 123 mpint)) +(let ((test-decode (compute-ssh-message-decoder test-message 123 mpint))) + (define (bidi-check msg enc-without-type-tag) + (let ((enc (bytes-append (bytes 123) enc-without-type-tag))) + (let ((msg-enc (ssh-message-encode msg)) + (enc-msg (test-decode enc))) + (if (and (equal? msg-enc enc) + (equal? enc-msg msg)) + 'ok + `(fail ,msg-enc ,enc-msg))))) + (check-eqv? (bidi-check (test-message 0) (bytes 0 0 0 0)) 'ok) + (check-eqv? (bidi-check (test-message #x9a378f9b2e332a7) + (bytes #x00 #x00 #x00 #x08 + #x09 #xa3 #x78 #xf9 + #xb2 #xe3 #x32 #xa7)) 'ok) + (check-eqv? (bidi-check (test-message #x80) + (bytes #x00 #x00 #x00 #x02 #x00 #x80)) 'ok) + (check-eqv? (bidi-check (test-message #x-1234) + (bytes #x00 #x00 #x00 #x02 #xed #xcc)) 'ok) + (check-eqv? (bidi-check (test-message #x-deadbeef) + (bytes #x00 #x00 #x00 #x05 + #xff #x21 #x52 #x41 #x11)) 'ok)) + (define-ssh-message-type ssh-msg-kexinit SSH_MSG_KEXINIT ((byte 16) cookie) (name-list kex_algorithms)