syndicate-ssh/asn1-ber.rkt

165 lines
4.7 KiB
Racket

#lang racket/base
;; A very small subset of ASN.1 BER (from ITU-T X.690), suitable for
;; en- and decoding public-key data for the ssh-rsa and ssh-dss
;; algorithms.
(require racket/match)
(require (planet tonyg/bitsyntax))
(provide t:long-ber-tag
t:ber-length-indicator
asn1-ber-decode-all
asn1-ber-decode
asn1-ber-encode)
;; bitsyntax custom segment type for the subsequent octets of a
;; high-tag-number identifier (ITU-T X.690 section 8.1.2.4): the tag
;; number split into 7-bit groups, most significant group first, with
;; bit 8 set on every octet except the last.
;;
;; Parsing form:    (t:long-ber-tag #t input ks kf) -- delegates to
;;                  read-long-tag, which calls (ks tag-number rest) on
;;                  success or (kf) on failure.
;; Formatting form: (t:long-ber-tag #f v) -- delegates to write-long-tag,
;;                  which returns the encoded octets as a byte string.
(define-syntax t:long-ber-tag
(syntax-rules ()
((_ #t input ks kf) (read-long-tag input ks kf))
((_ #f v) (write-long-tag v))))
;; Parser for the subsequent identifier octets of a high-tag-number
;; identifier (ITU-T X.690 section 8.1.2.4.2). Each octet carries seven
;; bits of the tag number, most significant group first; bit 8 is set on
;; every octet except the last one.
;;
;; input: bit-string positioned just after the initial identifier octet.
;; ks:    success continuation, called as (ks tag-number rest).
;; kf:    failure continuation, called as (kf) when the octets are
;;        malformed or the input runs out.
;;
;; X.690 8.1.2.4.2(c) forbids bits 7..1 of the FIRST subsequent octet
;; from all being zero (i.e. no leading zero groups). A zero group in any
;; later octet is legitimate: tag 128, for instance, is encoded as
;; #x81 #x00, which is what write-long-tag produces for it.
(define (read-long-tag input ks kf)
  (let loop ((acc 0)
             (first-octet? #t)
             (input input))
    (bit-string-case input
      ([ (more :: bits 1)
         (x :: bits 7)
         (rest :: binary) ]
       ;; Only the first octet may not contribute an all-zero group.
       (when (not (and first-octet? (zero? x))))
       (let ((value (+ x (arithmetic-shift acc 7))))
         (if (= more 1)
             (loop value #f rest)   ;; continuation bit set: more octets follow
             (ks value rest))))     ;; bit 8 clear: this was the last octet
      (else (kf)))))
;; Encoder for the subsequent identifier octets of a high-tag-number
;; identifier: V is split into 7-bit groups, emitted most significant
;; group first, with bit 8 (#x80) set on every octet except the last.
;; Returns the octets as a byte string.
(define (write-long-tag v)
  (let loop ((digits '())
             (n v)
             (continuation-bit 0))
    (if (< n 128)
        ;; Most significant group reached: it goes at the front.
        (list->bytes (cons (bitwise-ior continuation-bit n) digits))
        (loop (cons (bitwise-ior continuation-bit (bitwise-and n 127)) digits)
              (arithmetic-shift n -7)
              ;; Every group more significant than the lowest one carries #x80.
              128))))
;; Reverse the non-empty list BS, OR-ing #x80 into every element except
;; the original first element (which ends up last, unchanged).
(define (reverse-and-set-high-bits bs)
  (for/fold ((reversed (list (car bs))))
            ((b (in-list (cdr bs))))
    (cons (bitwise-ior 128 b) reversed)))
;; bitsyntax custom segment type for the length octets of a BER term
;; (ITU-T X.690 section 8.1.3).
;;
;; Parsing form:    (t:ber-length-indicator #t input ks kf) calls
;;                  (ks len rest), where len is a non-negative integer or
;;                  the symbol 'indefinite, or calls (kf) if no form matches.
;; Formatting form: (t:ber-length-indicator #f len) returns the encoded
;;                  length: the single octet #x80 for 'indefinite, the
;;                  one-octet short form for len < 128, and the long form
;;                  (count octet followed by the length itself) otherwise.
(define-syntax t:ber-length-indicator
(syntax-rules ()
((_ #t input ks0 kf)
;; Bind the success continuation once so the ks0 expression is not
;; duplicated into each of the three clauses below (avoids code explosion).
(let ((ks ks0))
(bit-string-case input
;; Indefinite form: the single octet #x80. This clause is listed before
;; the long form because #x80 also has its high bit set (lenlen = 0 there).
([ (= 128 :: bits 8)
(rest :: binary) ]
(ks 'indefinite rest))
;; Short form: high bit clear, the low 7 bits are the length.
([ (= 0 :: bits 1)
(len :: bits 7)
(rest :: binary) ]
(ks len rest))
;; Long form: the low 7 bits give the number of following octets that
;; hold the length.
([ (= 1 :: bits 1)
(lenlen :: bits 7)
(len :: integer bytes lenlen)
(rest :: binary) ]
(when (not (= lenlen 127))) ;; restriction from section 8.1.3.5
(ks len rest))
(else (kf)))))
((_ #f len)
(cond
((eq? len 'indefinite)
(bytes 128))
((< len 128)
(bytes len))
(else
;; lenlen = number of whole octets needed to hold len (bits rounded up).
(let ((lenlen (quotient (+ 7 (integer-length len)) 8)))
(bit-string (1 :: bits 1)
(lenlen :: bits 7)
(len :: integer bytes lenlen))))))))
;; Decode exactly one BER term from PACKET and return it as
;; (list class tag value). Raises an error if any input remains after
;; that term.
(define (asn1-ber-decode-all packet)
  (define-values (term leftover) (asn1-ber-decode packet))
  (unless (equal? leftover #"")
    (error 'asn1-ber-decode-all "Trailing bytes present in encoded ASN.1 BER term"))
  term)
;; Decode the first BER term at the front of PACKET.
;; Returns two values: the term as (list class tag value), and the
;; unconsumed remainder of the input converted to a byte string.
(define (asn1-ber-decode packet)
  (define (package-result class tag value rest)
    (values (list class tag value)
            (bit-string->bytes rest)))
  (asn1-ber-decode* packet package-result))
;; Decode one BER term (identifier octets, length octets, contents) from
;; the front of PACKET and hand it to K as (k class tag value rest).
;; - class:       the 2-bit class field.
;; - constructed: the 1-bit primitive/constructed flag, passed on to
;;                asn1-ber-decode-contents.
;; - tag:         either the 5-bit tag number, or the high-tag-number
;;                form when those 5 bits are all ones (31).
;; There is no else clause: input matching neither form is left to
;; bit-string-case's no-match handling.
(define (asn1-ber-decode* packet k)
  (bit-string-case packet
    ;; High-tag-number form: tag number >= 31 follows in subsequent octets.
    ([ (class :: bits 2)
       (constructed :: bits 1)
       (= 31 :: bits 5)
       (tag :: (t:long-ber-tag))
       (length :: (t:ber-length-indicator))
       (rest :: binary) ]
     (asn1-ber-decode-contents class constructed tag length rest k))
    ;; Low-tag-number form: tag numbers 0..30. The value 31 is only ever
    ;; the escape to the high-tag-number form. Without this guard, input
    ;; whose long tag (or length) failed to parse in the clause above would
    ;; fall through here and be misread as a short tag 31.
    ([ (class :: bits 2)
       (constructed :: bits 1)
       (tag :: bits 5)
       (length :: (t:ber-length-indicator))
       (rest :: binary) ]
     (when (not (= tag 31)))
     (asn1-ber-decode-contents class constructed tag length rest k))))
;; Decode the contents octets of a term whose identifier and length have
;; already been parsed, then call (k class tag value after).
;; - Constructed terms (constructed = 1): value is the list of decoded
;;   elements, each (list class tag value).
;; - Primitive terms (constructed = 0): value is the raw contents as bytes.
;;
;; In every case, after is the input that follows this whole term. For
;; definite-length constructed terms, this is the input after the
;; LENGTH-octet block, not the (empty) tail of the block itself. Passing
;; the latter would make nested decoding drop following sibling terms and
;; hide trailing bytes from asn1-ber-decode-all.
;;
;; Raises an error for a primitive term with indefinite length, which
;; X.690 section 8.1.3.2 does not allow.
(define (asn1-ber-decode-contents class constructed tag length rest k)
  (cond
    ((= constructed 1)
     (if (eq? length 'indefinite)
         ;; Elements run until the end-of-contents marker. The sequence
         ;; decoder passes on whatever follows that marker.
         (asn1-ber-decode-seq rest #t
                              (lambda (seq after) (k class tag seq after)))
         ;; Definite form: carve off exactly LENGTH octets, decode the
         ;; elements inside them, and continue with what follows the block.
         (bit-string-case rest
           ([ (block :: binary bytes length)
              (after :: binary) ]
            (asn1-ber-decode-seq block #f
                                 (lambda (seq block-tail)
                                   (k class tag seq after)))))))
    ((= constructed 0)
     (when (eq? length 'indefinite)
       (error 'asn1-ber-decode "Primitive ASN.1 BER term with indefinite length"))
     (bit-string-case rest
       ([ (block :: binary bytes length)
          (after :: binary) ]
        (k class tag (bit-string->bytes block) after))))))
;; Decode consecutive BER terms from PACKET and call (k elements rest),
;; where elements is a list of (list class tag value).
;; - When indefinite? is #f, decoding stops once PACKET is exhausted.
;; - When indefinite? is true, decoding stops at the end-of-contents
;;   marker (class 0, tag 0, empty primitive contents). The marker is
;;   consumed but not included in elements, and rest is the input after it.
(define (asn1-ber-decode-seq packet indefinite? k)
  (define (decode-items input done)
    (if (and (bit-string-empty? input) (not indefinite?))
        (done '() input)
        (asn1-ber-decode*
         input
         (lambda (item-class item-tag item-value after)
           (define end-of-contents?
             (and indefinite?
                  (= item-class 0)
                  (= item-tag 0)
                  (equal? item-value #"")))
           (if end-of-contents?
               (done '() after)
               (decode-items after
                             (lambda (items remaining)
                               (done (cons (list item-class item-tag item-value) items)
                                     remaining))))))))
  (decode-items packet k))
;; Encode ENTRY, a (list class tag value) term, as BER and return the
;; encoding as a byte string. Value is bytes for a primitive term or a
;; list of such terms for a constructed one.
(define (asn1-ber-encode entry)
  (define encoded (asn1-ber-encode* entry))
  (bit-string->bytes encoded))
;; Encode ENTRY = (list class tag value) as a BER bit-string.
;; - If value is a list, the term is constructed: each element is encoded
;;   recursively and the encodings are concatenated into the contents.
;; - Otherwise value is taken to be bytes and emitted as primitive contents.
;; Length octets always use the definite form.
(define (asn1-ber-encode* entry)
  (match entry
    [(list class tag (? list? children))
     (define child-encodings (map asn1-ber-encode* children))
     (define body (foldr bit-string-append #"" child-encodings))
     (define body-length (quotient (bit-string-length body) 8))
     (bit-string (class :: bits 2)
                 (1 :: bits 1) ;; constructed
                 ((asn1-ber-encode-tag tag) :: binary)
                 (body-length :: (t:ber-length-indicator))
                 (body :: binary bytes body-length))]
    [(list class tag contents)
     (bit-string (class :: bits 2)
                 (0 :: bits 1) ;; primitive (not constructed)
                 ((asn1-ber-encode-tag tag) :: binary)
                 ((bytes-length contents) :: (t:ber-length-indicator))
                 (contents :: binary))]))
;; Encode the tag-number part of an identifier (the bits after the class
;; and constructed bits).
;; - Tags 0..30 fit in the 5-bit field.
;; - Larger tags write 31 (all ones) in that field, followed by the
;;   high-tag-number octets.
(define (asn1-ber-encode-tag tag)
  (cond
    ((< tag 31) (bit-string (tag :: bits 5)))
    (else (bit-string (31 :: bits 5) (tag :: (t:long-ber-tag))))))