racket-ssh-2012/ssh-message-types.rkt

169 lines
5.8 KiB
Racket

#lang racket/base
(require "ssh-numbers.rkt")
(require (for-syntax racket/base))
(require (for-syntax (only-in racket/list append*)))
(require (for-syntax (only-in srfi/1 iota)))
(require (planet tonyg/bitsyntax))
(require racket/bytes)
(require rackunit)
(provide ssh-message-decode
ssh-message-encode)
;; decoder-map: mutable hash table compared with eqv?, keyed by the
;; message type byte.  Each `define-ssh-message-type' form registers a
;; decoder procedure here with `hash-set!', and `ssh-message-decode'
;; looks decoders up by the first byte of an incoming packet.
(define decoder-map (make-hasheqv))
;; Structure type property carrying the encoder procedure of a message
;; struct.  `make-struct-type-property' returns three values: the
;; property descriptor (attached via #:property in
;; `define-ssh-message-type'), a predicate, and an accessor.
;; `ssh-message-encode' uses the accessor to fetch the encoder.
(define-values (prop:ssh-message-encoder ssh-message-encoder? ssh-message-encoder)
(make-struct-type-property 'ssh-message-encoder))
;; ssh-message-decode : bytes -> message struct
;; Dispatches on the first byte of PACKET: the decoder registered for
;; that type code in `decoder-map' is applied to the whole packet.
;; Raises an error naming the type code when no decoder is registered.
(define (ssh-message-decode packet)
  (define type-code (bytes-ref packet 0))
  ;; Failure thunk handed to `hash-ref'; only invoked on a lookup miss.
  (define (unknown-type)
    (error 'ssh-message-decode
           "Unknown message packet type number ~v"
           type-code))
  (define decode (hash-ref decoder-map type-code unknown-type))
  (decode packet))
;; ssh-message-encode : message struct -> encoded message
;; Fetches the encoder procedure stored on M's struct type through the
;; `prop:ssh-message-encoder' property and applies it to M itself.
(define (ssh-message-encode m)
  (let ((encode (ssh-message-encoder m)))
    (encode m)))
;; (define-ssh-message-type struct-id type-code (field-type field-name) ...)
;;
;; Declares one SSH message type.  The expansion:
;;   - defines a transparent struct STRUCT-ID with the given field names,
;;     carrying an encoder (built by `compute-ssh-message-encoder') under
;;     `prop:ssh-message-encoder';
;;   - registers a decoder (built by `compute-ssh-message-decoder') in
;;     `decoder-map' under TYPE-CODE;
;;   - provides the struct from the enclosing module.
(define-syntax-rule (define-ssh-message-type struct-id type-code (ftype fname) ...)
  (begin
    (struct struct-id (fname ...)
      #:transparent
      #:property prop:ssh-message-encoder
      (compute-ssh-message-encoder type-code ftype ...))
    (hash-set! decoder-map
               type-code
               (compute-ssh-message-decoder struct-id type-code ftype ...))
    (provide (struct-out struct-id))))
;; (compute-ssh-message-encoder type-byte-value field-type ...)
;;
;; Expands to a one-argument procedure that serialises a message struct
;; with `bit-string': first the type byte, then one segment per declared
;; field, in declaration order.  Field values are read positionally out
;; of `(struct->vector message)'.
(define-syntax (compute-ssh-message-encoder stx)
  ;; segment-for : index identifier syntax -> syntax
  ;; Builds the `bit-string' segment that encodes the field stored at
  ;; SLOT of the vector bound to RECORD, according to the field TYPE.
  (define (segment-for slot record type)
    (define ref #`(vector-ref #,record #,slot))
    (syntax-case type (byte boolean uint32 uint64 string mpint name-list)
      ;; A bare expression segment: no explicit `::' type options.
      (byte ref)
      ((byte n) #`(#,ref :: binary bytes n))
      (boolean #`(if #,ref 1 0))
      (uint32 #`(#,ref :: integer bits 32))
      (uint64 #`(#,ref :: integer bits 64))
      ;; The remaining types are emitted as a 32-bit length followed by
      ;; the payload, assembled into a nested bit-string.
      (string
       #`((let ((v #,ref))
            (bit-string ((bytes-length v) :: integer bits 32)
                        (v :: binary)))
          :: binary))
      (mpint
       #`((let* ((v #,ref)
                 (width (mpint-width v))
                 (buf (integer->bit-string v (* 8 width) #t)))
            (bit-string (width :: integer bits 32)
                        (buf :: binary)))
          :: binary))
      (name-list
       #`((let ((v (symbols->name-list #,ref)))
            (bit-string ((quotient (bit-string-length v) 8) :: integer bits 32)
                        (v :: binary)))
          :: binary))))
  (syntax-case stx ()
    ((_ type-byte-value field-type ...)
     ;; Slot 0 of the struct->vector result is the struct's name, so
     ;; field slots are numbered from 1.
     (let ((segments
            (for/list ((slot (in-naturals 1))
                       (type (in-list (syntax->list #'(field-type ...)))))
              (segment-for slot #'vec type))))
       #`(lambda (message)
           (let ((vec (struct->vector message)))
             (bit-string (type-byte-value :: integer bytes 1)
                         #,@segments)))))))
;; (compute-ssh-message-decoder struct-name type-byte-value field-type ...)
;;
;; Expands to a one-argument procedure that parses a packet with
;; `bit-string-case' and returns an instance of STRUCT-NAME.  Each
;; declared field contributes one or two segment patterns (see
;; `field-extractor') and one constructor argument (see
;; `field-transformer').
;;
;; NOTE(review): the leading pattern segment is the bare TYPE-BYTE-VALUE
;; form.  If that form is an identifier, bit-string-case may treat it as
;; a binding pattern rather than comparing against its value -- confirm
;; against the bitsyntax pattern grammar.  Dispatch through `decoder-map'
;; already selects the decoder by type byte, so this is left unchanged.
(define-syntax compute-ssh-message-decoder
  (lambda (stx)
    ;; Segment patterns for a field stored as a 32-bit byte count
    ;; followed by that many bytes.  A fresh temporary names the count so
    ;; that the payload segment can refer to it.
    (define (length-prefixed-segments temp-name)
      (let ((length-name (car (generate-temporaries (list temp-name)))))
        (list #`(#,length-name :: integer bits 32)
              #`(#,temp-name :: binary bytes #,length-name))))
    ;; field-extractor : identifier syntax -> (listof syntax)
    ;; Segment patterns that bind TEMP-NAME to the raw value of a field
    ;; of FIELD-TYPE.
    (define (field-extractor temp-name field-type)
      (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list)
        (byte (list temp-name))
        ((byte n) (list #`(#,temp-name :: binary bytes n)))
        (boolean (list temp-name))
        (uint32 (list #`(#,temp-name :: integer bits 32)))
        (uint64 (list #`(#,temp-name :: integer bits 64)))
        ;; All three length-prefixed types share one helper.  The mpint
        ;; case previously spliced its length identifier with #,@ instead
        ;; of inserting it with #, and so could not expand.
        (string (length-prefixed-segments temp-name))
        (mpint (length-prefixed-segments temp-name))
        (name-list (length-prefixed-segments temp-name))))
    ;; field-transformer : identifier syntax -> syntax
    ;; Expression converting the raw value bound to TEMP-NAME into the
    ;; value stored in the struct.
    (define (field-transformer temp-name field-type)
      (syntax-case field-type (byte boolean uint32 uint64 string mpint name-list)
        ((byte n) #`(bit-string->bytes #,temp-name))
        (boolean #`(not (zero? #,temp-name)))
        (string #`(bit-string->bytes #,temp-name))
        ;; Decode as big-endian and signed, mirroring the encoder's
        ;; big-endian `integer->bit-string' call.
        (mpint #`(bit-string->integer #,temp-name #t #t))
        (name-list #`(name-list->symbols #,temp-name))
        ;; `else' is an ordinary pattern variable here: it matches every
        ;; remaining field type, whose raw value is used unchanged.
        (else temp-name)))
    (syntax-case stx ()
      ((_ struct-name type-byte-value field-type ...)
       (let* ((field-types (syntax->list #'(field-type ...)))
              (temp-names (generate-temporaries field-types)))
         #`(lambda (packet)
             (bit-string-case packet
               ((type-byte-value
                 #,@(append* (map field-extractor temp-names field-types)))
                (struct-name #,@(map field-transformer temp-names field-types))))))))))
;; mpint-width : exact-integer -> exact-nonnegative-integer
;; Number of bytes used for N in the mpint payload: zero for N = 0,
;; otherwise one more than the number of whole bytes covered by
;; `(integer-length n)', which leaves room for the sign bit.
(define (mpint-width n)
  (cond
    ((zero? n) 0)
    (else (add1 (quotient (integer-length n) 8)))))
;; Load-time rackunit checks for `mpint-width', written as a table of
;; (input . expected-width) pairs.  The message argument identifies the
;; failing input, since every row shares one check expression.
(for ((row (in-list '((0 . 0)
                      (#x9a378f9b2e332a7 . 8)
                      (#x7f . 1)
                      (#x80 . 2)
                      (#x81 . 2)
                      (#xff . 2)
                      (#x100 . 2)
                      (#x101 . 2)
                      (#x-1234 . 2)
                      (#x-deadbeef . 5)))))
  (check-eqv? (mpint-width (car row))
              (cdr row)
              (format "mpint-width of ~a" (car row))))
;; symbols->name-list : (listof symbol) -> bytes
;; Renders each symbol's name as UTF-8 bytes and joins the results with
;; a single comma.  An empty list yields empty bytes.
(define (symbols->name-list syms)
  (define encoded-names
    (for/list ((sym (in-list syms)))
      (string->bytes/utf-8 (symbol->string sym))))
  (bytes-join encoded-names #","))
;; name-list->symbols : bit-string -> (listof symbol)
;; Inverse of `symbols->name-list': decodes BS as UTF-8 text, splits it
;; on commas, and interns each piece as a symbol.  A zero-length input
;; gives the empty list; splitting it would give one empty-named symbol.
(define (name-list->symbols bs)
  (cond
    ((zero? (bit-string-length bs)) '())
    (else
     (let* ((text (bytes->string/utf-8 (bit-string->bytes bs)))
            (names (regexp-split #rx"," text)))
       (map string->symbol names)))))
;; ssh-msg-kexinit: message struct registered under type code
;; SSH_MSG_KEXINIT.  Wire layout after the type byte, in field order:
;; a fixed 16-byte cookie, ten length-prefixed name-lists, a one-byte
;; boolean, and a 32-bit unsigned integer.  Name-list fields hold lists
;; of symbols on the Racket side (see `symbols->name-list' and
;; `name-list->symbols').
(define-ssh-message-type ssh-msg-kexinit SSH_MSG_KEXINIT
((byte 16) cookie)
(name-list kex_algorithms)
(name-list server_host_key_algorithms)
(name-list encryption_algorithms_client_to_server)
(name-list encryption_algorithms_server_to_client)
(name-list mac_algorithms_client_to_server)
(name-list mac_algorithms_server_to_client)
(name-list compression_algorithms_client_to_server)
(name-list compression_algorithms_server_to_client)
(name-list languages_client_to_server)
(name-list languages_server_to_client)
(boolean first_kex_packet_follows)
(uint32 reserved))