diff --git a/ssh-message-types.rkt b/ssh-message-types.rkt index 5c8e58a..133050c 100644 --- a/ssh-message-types.rkt +++ b/ssh-message-types.rkt @@ -44,41 +44,69 @@ (hash-set! decoder-map type-byte-value (compute-ssh-message-decoder name type-byte-value field-type ...)))))) +(define-syntax t:boolean + (syntax-rules () + ((_ #t) (lambda (input ks kf) + (bit-string-case input + ([ v (rest :: binary) ] + (ks (not (zero? v)) rest)) + (else (kf))))) + ((_ #f) (lambda (v) (bit-string (if v 1 0)))))) + +(define-syntax t:string + (syntax-rules () + ((_ #t) (lambda (input ks kf) + (bit-string-case input + ([ (length :: integer bits 32) (body :: binary bytes length) (rest :: binary) ] + (ks body rest)) + (else (kf))))) + ((_ #f) (lambda (bs) + (bit-string ((bytes-length (bit-string->bytes bs)) :: integer bits 32) + (bs :: binary)))))) + +(define-syntax t:mpint + (syntax-rules () + ((_ #t) (lambda (input ks kf) + (bit-string-case input + ([ (length :: integer bits 32) (body :: binary bytes length) (rest :: binary) ] + (ks (if (zero? (bit-string-length body)) 0 (bit-string->integer body #t #t)) + rest)) + (else (kf))))) + ((_ #f) (lambda (n) + (let* ((width (mpint-width n)) + (buf (integer->bit-string n (* 8 width) #t))) + (bit-string (width :: integer bits 32) (buf :: binary))))))) + +(define-syntax t:name-list + (syntax-rules () + ((_ #t) (lambda (input ks kf) + ((t:string #t) input + (lambda (body rest) (ks (name-list->symbols body) rest)) + kf))) + ((_ #f) (lambda (ns) + ((t:string #f) (symbols->name-list ns)))))) + +(define-for-syntax (codec-options field-type) + (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list) + (byte #'(integer bits 8)) + ((byte n) #'(binary bytes n)) + (boolean #'((t:boolean))) + (uint32 #'(integer bits 32)) + (uint64 #'(integer bits 64)) + (string #'((t:string))) + (mpint #'((t:mpint))) + (name-list #'((t:name-list))))) + (define-syntax compute-ssh-message-encoder (lambda (stx) - (define (encoder-field index vec field-type) - (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list) - (byte - #`(vector-ref #,vec #,index)) - ((byte n) - #`((vector-ref #,vec #,index) :: binary bytes n)) - (boolean - #`(if (vector-ref #,vec #,index) 1 0)) - (uint32 - #`((vector-ref #,vec #,index) :: integer bits 32)) - (uint64 - #`((vector-ref #,vec #,index) :: integer bits 64)) - (string - #`((let ((v (vector-ref #,vec #,index))) - (bit-string ((bytes-length v) :: integer bits 32) - (v :: binary))) :: binary)) - (mpint - #`((let* ((v (vector-ref #,vec #,index)) - (width (mpint-width v)) - (buf (integer->bit-string v (* 8 width) #t))) - (bit-string (width :: integer bits 32) - (buf :: binary))) :: binary)) - (name-list - #`((let ((v (symbols->name-list (vector-ref #,vec #,index)))) - (bit-string ((quotient (bit-string-length v) 8) :: integer bits 32) - (v :: binary))) :: binary)))) (syntax-case stx () ((_ type-byte-value field-type ...) #`(lambda (message) (let ((vec (struct->vector message))) #,(with-syntax (((field-spec ...) (let ((type-list (syntax->list #'(field-type ...)))) - (map (lambda (index type) (encoder-field index #'vec type)) + (map (lambda (index type) + #`((vector-ref vec #,index) :: #,@(codec-options type))) (iota (length type-list) 1) type-list)))) #'(bit-string (type-byte-value :: integer bytes 1) @@ -86,50 +114,15 @@ (define-syntax compute-ssh-message-decoder (lambda (stx) - (define (field-extractor temp-name field-type) - (syntax->list - (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list) - (byte - #`(#,temp-name)) - ((byte n) - #`((#,temp-name :: binary bytes n))) - (boolean - #`(#,temp-name)) - (uint32 - #`((#,temp-name :: unsigned integer bits 32))) - (uint64 - #`((#,temp-name :: unsigned integer bits 64))) - (string - (let ((length-name (car (generate-temporaries (list temp-name))))) - #`((#,length-name :: integer bits 32) - (#,temp-name :: binary bytes #,length-name)))) - (mpint - (let ((length-name (car (generate-temporaries (list temp-name))))) - #`((#,length-name :: integer bits 32) - (#,temp-name :: binary bytes #,length-name)))) - (name-list - (let ((length-name (car (generate-temporaries (list temp-name))))) - #`((#,length-name :: integer bits 32) - (#,temp-name :: binary bytes #,length-name))))))) - (define (field-transformer temp-name field-type) - (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list) - ((byte n) #`(bit-string->bytes #,temp-name)) - (boolean #`(not (zero? #,temp-name))) - (string #`(bit-string->bytes #,temp-name)) - (mpint #`(if (zero? (bit-string-length #,temp-name)) - 0 - (bit-string->integer #,temp-name #t #t))) - (name-list #`(name-list->symbols #,temp-name)) - (else temp-name))) (syntax-case stx () ((_ struct-name type-byte-value field-type ...) - (let* ((field-types (syntax->list #'(field-type ...))) - (temp-names (generate-temporaries field-types))) + (with-syntax (((temp-name ...) (generate-temporaries #'(field-type ...))) + (((codec-option ...) ...) + (map codec-options (syntax->list #'(field-type ...))))) #`(lambda (packet) (bit-string-case packet - (( (= type-byte-value) - #,@(append* (map field-extractor temp-names field-types)) ) - (struct-name #,@(map field-transformer temp-names field-types)))))))))) + ([ (= type-byte-value) (temp-name :: codec-option ...) ... ] + (struct-name temp-name ...))))))))) (define (mpint-width n) (if (zero? n)