syndicate-2017/prospect/bitset.rkt

152 lines
4.8 KiB
Racket

#lang racket/base
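;; Sets of small non-negative integers stored in byte strings: integer n is a
;; member when bit (remainder n 8) of byte (quotient n 8) is set.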
(provide bitset
bitset?
list->bitset
bitset-count
bitset-empty
bitset-empty?
bitset-add
bitset-remove
bitset-union
bitset-intersect
bitset-subtract
bitset->list
bitset-member?
)
(require racket/performance-hint)
;; Construct a bitset whose members are exactly the given arguments
;; (delegates to list->bitset).
(define bitset
  (lambda members
    (list->bitset members)))
;; Bitsets are represented as plain byte strings, so this recognizes any
;; byte string as a bitset.
(define bitset?
  (lambda (candidate)
    (bytes? candidate)))
;; Number of bytes a bitset must have so that integer `v` has a bit slot.
;; Member `v` lives in byte (quotient v 8), so the storage must extend one
;; byte past that index. (The bit position within the byte is irrelevant
;; here, so only the quotient is computed.)
(define (length-to-include v)
  (add1 (quotient v 8)))
;; Convert a list of members into a fresh bitset.
;; The storage is allocated once, sized for the largest element (a single
;; zero byte when `vs` is empty); then each member's bit is set in place.
(define (list->bitset vs)
  (define largest
    (for/fold ([best 0]) ([v (in-list vs)])
      (max best v)))
  (define storage (make-bytes (length-to-include largest)))
  ;; Bits are set last-to-first; the order does not affect the result.
  (for ([v (in-list (reverse vs))])
    (bitset-add!* v storage))
  storage)
;; Cribbed from data/bit-vector.rkt
(require (for-syntax racket/base
(only-in data/private/count-bits-in-fixnum fxpopcount)))
;; popcount-table: a 256-entry vector whose entry b is the number of 1 bits
;; in the byte value b. The counts are computed during macro expansion (via
;; the for-syntax import of fxpopcount), so at run time this is just a
;; reference to a literal vector.
(define popcount-table
(let ()
;; Local macro: expands to (quote #(c0 c1 ... c255)), where ci = fxpopcount(i).
(define-syntax (make-table stx)
(with-syntax ([(elt ...)
(for/list ([i (in-range 256)])
(fxpopcount i))])
;; Literal immutable vector allocated once (?)
;; NOTE(review): a quoted vector literal is immutable; the "(?)" records the
;; original author's uncertainty about it being allocated only once -- confirm.
#'(quote #(elt ...))))
(make-table)))
(require (only-in racket/unsafe/ops unsafe-vector-ref))
;; Number of members in `bs`: the sum over all bytes of each byte's
;; population count, looked up in popcount-table. Byte values are always
;; 0..255, so the unchecked vector access stays within the 256-entry table.
(define (bitset-count bs)
  (for/fold ([total 0])
            ([octet (in-bytes bs)])
    (+ total (unsafe-vector-ref popcount-table octet))))
;; A new bitset with no members: a zero-length (mutable) byte string.
(define (bitset-empty)
  (make-bytes 0))
;; True when no bit is set anywhere in `bs`. Bitsets may carry trailing zero
;; bytes, so an all-zero byte string of any length counts as empty.
(define (bitset-empty? bs)
  (not (for/or ([octet (in-bytes bs)])
         (positive? octet))))
;; Return a fresh copy of `bs` that is long enough to hold bit `v`.
;; When `bs` is too short the copy is padded with zero bytes; otherwise it
;; is an exact copy. `bs` itself is never mutated.
(define (bitset-copy/extend bs v)
  (let ([needed (+ (quotient v 8) 1)]   ; bytes required to include bit v
        [have (bytes-length bs)])
    (if (< have needed)
        (let ([grown (make-bytes needed)])
          (bytes-copy! grown 0 bs)
          grown)
        (bytes-copy bs))))
;; Functional insert: returns a new bitset equal to `bs` plus member `v`,
;; growing the storage if needed. `bs` itself is left unchanged.
(define (bitset-add bs v)
  (let ([fresh (bitset-copy/extend bs v)])
    (bitset-add!* v fresh)))
;; Drop trailing all-zero bytes from `bs`.
;; Returns a fresh empty byte string when every byte is zero (or `bs` is
;; empty), `bs` itself when its last byte is non-zero, and otherwise a fresh
;; prefix ending at the last non-zero byte.
(define (bitset-trim bs)
  (define full-len (bytes-length bs))
  ;; `used` = index just past the last non-zero byte (0 when there is none).
  (define used
    (let scan ([n full-len])
      (if (and (positive? n) (zero? (bytes-ref bs (sub1 n))))
          (scan (sub1 n))
          n)))
  (cond
    [(zero? used) (bytes)]
    [(= used full-len) bs]
    [else (subbytes bs 0 used)]))
;; Functional removal: returns a trimmed copy of `bs` without member `v`.
;; `bs` itself is not modified.
(define (bitset-remove bs v)
  (define scratch (bytes-copy bs))
  (bitset-remove! scratch v)
  (bitset-trim scratch))
;; Zero-filled output buffer for a merge, as long as the longer input.
(define (bitset-for-merge bs1 bs2)
  (let ([len1 (bytes-length bs1)]
        [len2 (bytes-length bs2)])
    (make-bytes (if (< len1 len2) len2 len1))))
;; Fill output buffer `bs` byte-by-byte with (combiner short-byte long-byte)
;; and return `bs`. Positions past the end of `bs-short` are combined as if
;; the short input held 0 there. Callers pass the shorter input as
;; `bs-short` and a buffer at least as long as `bs-long`.
(define-inline (bitset-merge* bs bs-short bs-long combiner)
  (define short-len (bytes-length bs-short))
  (define long-len (bytes-length bs-long))
  (for ([k (in-range (max short-len long-len))])
    (bytes-set! bs k
                (combiner (if (< k short-len) (bytes-ref bs-short k) 0)
                          (bytes-ref bs-long k))))
  bs)
;; Combine two bitsets byte-wise with (combiner byte-from-bs1 byte-from-bs2)
;; into a fresh buffer as long as the longer input. Whichever input is
;; shorter is handed to bitset-merge* as the short side; when that is `bs2`,
;; the combiner's arguments are flipped back so it always sees bs1's byte first.
(define-inline (bitset-merge bs1 bs2 combiner)
  (define out (bitset-for-merge bs1 bs2))
  (cond
    [(>= (bytes-length bs1) (bytes-length bs2))
     (bitset-merge* out bs2 bs1
                    (lambda (from-bs2 from-bs1) (combiner from-bs1 from-bs2)))]
    [else
     (bitset-merge* out bs1 bs2 combiner)]))
;; Members present in either input (result is not trimmed).
(define (bitset-union bs1 bs2)
  (bitset-merge bs1 bs2 (lambda (a b) (bitwise-ior a b))))
;; Members present in both inputs; trailing all-zero bytes are trimmed.
(define (bitset-intersect bs1 bs2)
  (let ([common (bitset-merge bs1 bs2 bitwise-and)])
    (bitset-trim common)))
;; Members of `bs1` that are not in `bs2`; trailing all-zero bytes are trimmed.
(define (bitset-subtract bs1 bs2)
  (define (keep-unless-removed kept removed)
    (bitwise-and kept (bitwise-not removed)))
  (bitset-trim (bitset-merge bs1 bs2 keep-unless-removed)))
;; List the members of `bs`. Bytes are scanned from index 0 upward and bits
;; from 0 to 7, each member being consed onto the front, so the result is in
;; descending order. All-zero bytes are skipped without inspecting their bits.
(define (bitset->list bs)
  (for/fold ([members '()])
            ([octet (in-bytes bs)]
             [offset (in-naturals)]
             #:unless (zero? octet)
             [bit (in-range 8)]
             #:when (bitwise-bit-set? octet bit))
    (cons (+ (* offset 8) bit) members)))
;; True when `v` is a member of `bs`. Values whose byte lies past the end of
;; `bs` are simply not members.
(define (bitset-member? bs v)
  (define byte-pos (quotient v 8))
  (and (< byte-pos (bytes-length bs))
       (bitwise-bit-set? (bytes-ref bs byte-pos) (remainder v 8))))
;; Destructively set the bit for `v` in `bs`, then return `bs`.
;; The value-first argument order lets it serve directly as a fold step
;; function. `bs` must already be long enough to contain byte (quotient v 8).
(define (bitset-add!* v bs)
  (let-values ([(slot bit) (quotient/remainder v 8)])
    (let ([mask (arithmetic-shift 1 bit)])
      (bytes-set! bs slot (bitwise-ior mask (bytes-ref bs slot))))
    bs))
;; Destructively clear the bit for `v` in `bs`, then return `bs`.
;; A value whose byte index lies past the end of `bs` cannot be a member,
;; so it is ignored. The previous version called bytes-ref unconditionally
;; and raised an index error, e.g. (bitset-remove (bitset 1) 100).
;; This range check mirrors the one in bitset-member?.
(define (bitset-remove! bs v)
  (define-values (y i) (quotient/remainder v 8))
  (when (< y (bytes-length bs))
    (bytes-set! bs y (bitwise-and (bytes-ref bs y) (bitwise-not (arithmetic-shift 1 i)))))
  bs)
(module+ test
  (require rackunit)
  (require racket/set)

  ;; bitset->list does not promise an element order, so compare as sets.
  (define-syntax-rule (check-set-equal? actual expected)
    (check-equal? (list->set actual) (list->set expected)))

  ;; Construction and enumeration
  (check-set-equal? (bitset->list (bitset 1 2 3 4 5)) '(1 2 3 4 5))
  (check-set-equal? (bitset->list (bitset 10 20 30 40 50)) '(10 20 30 40 50))
  (check-set-equal? (bitset->list (bitset 5 4 3 2 1)) '(1 2 3 4 5))
  (check-set-equal? (bitset->list (list->bitset '(0 8 9 63))) '(0 8 9 63))
  (check-true (bitset-empty? (list->bitset '())))

  ;; Set algebra
  (check-set-equal? (bitset->list (bitset-union (bitset 1 2 3) (bitset 2 3 4))) '(1 2 3 4))
  (check-set-equal? (bitset->list (bitset-intersect (bitset 1 2 3) (bitset 2 3 4))) '(2 3))
  (check-set-equal? (bitset->list (bitset-subtract (bitset 1 2 3) (bitset 2 3 4))) '(1))
  ;; Operands of different lengths, and results that trim down to nothing
  (check-true (bitset-empty? (bitset-subtract (bitset 1) (bitset 1 20))))
  (check-true (bitset-empty? (bitset-intersect (bitset 1) (bitset 20))))
  (check-set-equal? (bitset->list (bitset-union (bitset 20) (bitset 1))) '(1 20))

  ;; Membership
  (check-true (bitset-member? (bitset 1 2 3) 2))
  (check-false (bitset-member? (bitset 1 2 3) 4))

  ;; Emptiness
  (check-true (bitset-empty? (bitset)))
  (check-false (bitset-empty? (bitset 1)))

  ;; Counting
  (check-equal? (bitset-count (bitset 1 2 3)) 3)
  (check-equal? (bitset-count (bitset)) 0)
  (check-equal? (bitset-count (bitset-union (bitset 1 2 3) (bitset 2 3 4))) 4)

  ;; bitset-add: functional, grows storage as needed, leaves its input alone
  (let* ([original (bitset 1)]
         [grown (bitset-add original 17)])
    (check-true (bitset-member? grown 17))
    (check-true (bitset-member? grown 1))
    (check-false (bitset-member? original 17)))
  (check-set-equal? (bitset->list (bitset-add (bitset-empty) 0)) '(0))
  (check-equal? (bitset-count (bitset-add (bitset 1 2) 2)) 2)

  ;; bitset-remove: functional, trims trailing zero bytes, leaves its input alone
  (let* ([original (bitset 1 2 3)]
         [shrunk (bitset-remove original 2)])
    (check-set-equal? (bitset->list shrunk) '(1 3))
    (check-true (bitset-member? original 2)))
  (check-true (bitset-empty? (bitset-remove (bitset 3) 3)))
  (check-equal? (bitset-remove (bitset 1 9) 9) (bitset 1))

  ;; Predicate
  (check-true (bitset? (bitset-empty)))
  (check-true (bitset? (bitset)))
  (check-false (bitset? 123))
  (check-false (bitset? (list 1 2 3)))
  (check-false (bitset? 'a)))