racket-matrix-2012/relation.rkt

158 lines
4.4 KiB
Racket

#lang racket/base
;; A relation between X and Y is stored as a Hash<X, Set<Y>>; conceptually it is equivalent to a Set<Pair<X, Y>>.
(require racket/set)
(require racket/match)
;; Public interface of this module.
;; The keyword-accepting constructor make-relation is exported under the
;; name `relation`.  The raw struct constructor and the struct's field
;; accessors (relation-table, relation-set-constructor) are not exported;
;; only the relation? predicate is.
(provide (rename-out [make-relation relation])
         relation?
         relation->list
         list->relation
         relation->hash
         in-relation-domain
         in-relation/grouped
         relation-empty?
         relation-count
         relation-add
         relation-add-all
         relation-remove
         relation-remove-all
         relation-ref
         relation-set
         relation-update
         relation-domain-member?
         relation-member?
         relation-domain
         relation-domain-eq?
         relation-domain-eqv?
         relation-domain-equal?
         relation-codomain-eq?
         relation-codomain-eqv?
         relation-codomain-equal?
         relation-for-each
         relation-fold
         relation-map
         ;; TODO: whole-relation set algebra: relation-subtract,
         ;; relation-intersect, relation-symmetric-difference, relation-union
         )
;; A relation value.
;;   table           : immutable hash (built by hash, hasheqv or hasheq)
;;                     mapping each domain element to an immutable set of
;;                     the codomain elements related to it.
;;   set-constructor : one of the procedures set, seteqv or seteq.  It is
;;                     stored uncalled so it can be passed directly as a
;;                     failure-result thunk to hash-ref / hash-update, which
;;                     then call it to obtain an empty codomain set.
(struct relation (table set-constructor))

;; make-relation : [#:domain cmp] [#:codomain cmp] -> relation
;; Creates an empty relation.  This is the constructor exported as `relation`.
;;   #:domain   - equality used for domain keys: equal? (default), eqv? or eq?.
;;   #:codomain - equality used for codomain values: equal? (default), eqv? or eq?.
;; Raises exn:fail:contract if either comparator is anything else.
(define (make-relation #:domain [domain-comparator equal?]
                       #:codomain [codomain-comparator equal?])
  ;; Comparators are matched by identity (eq?), so only the three built-in
  ;; equality procedures are recognized.
  (define make-table
    (cond
      [(eq? domain-comparator equal?) hash]
      [(eq? domain-comparator eqv?) hasheqv]
      [(eq? domain-comparator eq?) hasheq]
      ;; Error is reported under the exported name `relation`.
      [else (raise-argument-error 'relation "(or/c equal? eqv? eq?)"
                                  domain-comparator)]))
  (define make-set
    (cond
      [(eq? codomain-comparator equal?) set]
      [(eq? codomain-comparator eqv?) seteqv]
      [(eq? codomain-comparator eq?) seteq]
      [else (raise-argument-error 'relation "(or/c equal? eqv? eq?)"
                                  codomain-comparator)]))
  (relation (make-table) make-set))
;; relation->list : relation -> (listof pair)
;; Flattens REL into an association list with one (domain . codomain) pair
;; per related element, grouped by domain key.
(define (relation->list rel)
  (apply append
         (for/list ([(key vals) (in-hash (relation-table rel))])
           (for/list ([val (in-set vals)])
             (cons key val)))))
;; list->relation : (listof (cons X Y)) [#:domain cmp] [#:codomain cmp] -> relation
;; Builds a relation containing every (k . v) pair in XS.  The comparators
;; are forwarded to make-relation (equal?, eqv? or eq?; default equal?).
;; Raises an error if XS is not a proper list of pairs.
(define (list->relation xs
                        #:domain [domain-comparator equal?]
                        #:codomain [codomain-comparator equal?])
  ;; Inside this module `relation` is the raw struct constructor, which takes
  ;; no keyword arguments; the keyword-accepting constructor is make-relation.
  (let loop ([xs xs]
             [r (make-relation #:domain domain-comparator
                               #:codomain codomain-comparator)])
    (match xs
      ['() r]
      [(cons (cons k v) rest) (loop rest (relation-add r k v))]
      [_ (error 'list->relation "Expected list of key/value pairs")])))
;; relation->hash : relation -> hash
;; Returns the relation's underlying immutable hash from domain elements to
;; their codomain sets (no copy is made).
(define (relation->hash rel)
  (define tbl (relation-table rel))
  tbl)
;; in-relation-domain : relation -> sequence
;; A sequence of the relation's domain elements (the keys of its hash).
(define (in-relation-domain rel)
  (define tbl (relation-table rel))
  (in-hash-keys tbl))
;; in-relation/grouped : relation -> sequence
;; A two-valued sequence yielding each domain element together with the set
;; of codomain elements related to it.
(define (in-relation/grouped rel)
  (define tbl (relation-table rel))
  (in-hash tbl))
;; relation-empty? : relation -> boolean
;; True when the underlying hash has no keys.  This relies on keys never
;; being mapped to an empty set.
(define (relation-empty? rel)
  (= 0 (hash-count (relation-table rel))))
;; relation-count : relation -> exact-nonnegative-integer
;; The number of related (domain . codomain) pairs: the sum of the sizes of
;; all codomain sets.
(define (relation-count rel)
  (for/sum ([vals (in-hash-values (relation-table rel))])
    (set-count vals)))
;; relation-add : relation X Y -> relation
;; Returns a new relation that additionally relates K to V; R is unchanged.
;; A missing key starts from an empty set made by the relation's set constructor.
(define (relation-add r k v)
  (define tbl (relation-table r))
  ;; hash-ref calls the constructor procedure when K is absent.
  (define current (hash-ref tbl k (relation-set-constructor r)))
  (struct-copy relation r
               [table (hash-set tbl k (set-add current v))]))
;; relation-add-all : relation X (setof Y) -> relation
;; Returns a new relation that additionally relates K to every element of VS.
;; VS should be a set of the same kind as the relation's codomain sets
;; (required by set-union).  Adding an empty set is a no-op, which returns R.
;; This avoids recording K with an empty set: such an entry would make
;; relation-empty? and relation-domain-member? report a key that has no
;; related values.
(define (relation-add-all r k vs)
  (if (set-empty? vs)
      r
      (struct-copy relation r
                   [table (hash-update (relation-table r)
                                       k
                                       (lambda (old-vs) (set-union old-vs vs))
                                       ;; Called to make an empty set when K is absent.
                                       (relation-set-constructor r))])))
;; relation-remove : relation X Y -> relation
;; Returns a new relation in which K is no longer related to V.  If that
;; leaves K with no values, K is dropped from the domain entirely.
;; Returns a relation (not the bare hash), consistent with relation-add.
(define (relation-remove r k v)
  (define tbl (relation-table r))
  (define old-vs (hash-ref tbl k (relation-set-constructor r)))
  (define new-vs (set-remove old-vs v))
  (struct-copy relation r
               [table (if (set-empty? new-vs)
                          (hash-remove tbl k)
                          (hash-set tbl k new-vs))]))
;; relation-remove-all : relation X (setof Y) -> relation
;; Returns a new relation in which K is no longer related to any element of
;; VS.  If that leaves K with no values, K is dropped from the domain.
;; VS should be a set of the same kind as the codomain sets (required by
;; set-subtract).  Returns a relation, consistent with relation-add.
(define (relation-remove-all r k vs)
  (define tbl (relation-table r))
  (define old-vs (hash-ref tbl k (relation-set-constructor r)))
  (define new-vs (set-subtract old-vs vs))
  (struct-copy relation r
               [table (if (set-empty? new-vs)
                          (hash-remove tbl k)
                          (hash-set tbl k new-vs))]))
;; relation-ref : relation X [failure-result] -> any
;; Returns the set of values related to K.  When K is absent, FAILURE-RESULT
;; is used with hash-ref semantics: a procedure is called, any other value is
;; returned.  The default produces an empty codomain set.
(define (relation-ref r k [failure-result (relation-set-constructor r)])
  (define tbl (relation-table r))
  (hash-ref tbl k failure-result))
;; relation-set : relation X (setof Y) -> relation
;; Returns a new relation in which K is related to exactly the elements of
;; VS, replacing any previous values.  Setting an empty set removes K, just
;; as relation-remove does.
;; Returns a relation (not the bare hash), consistent with relation-add.
(define (relation-set r k vs)
  (define tbl (relation-table r))
  (struct-copy relation r
               [table (if (set-empty? vs)
                          (hash-remove tbl k)
                          (hash-set tbl k vs))]))
;; relation-update : relation X (set -> set) [failure-result] -> relation
;; Returns a new relation in which K's value set is replaced by
;; (UPDATER old-set).  When K is absent, the old set comes from
;; FAILURE-RESULT with hash-ref semantics; the default is an empty codomain
;; set.  If UPDATER returns an empty set, K is removed from the domain.
;; Returns a relation, consistent with relation-add.
(define (relation-update r k updater [failure-result (relation-set-constructor r)])
  (define tbl (relation-table r))
  ;; Same as hash-update: look up (or use the failure result), then apply UPDATER.
  (define new-vs (updater (hash-ref tbl k failure-result)))
  (struct-copy relation r
               [table (if (set-empty? new-vs)
                          (hash-remove tbl k)
                          (hash-set tbl k new-vs))]))
;; relation-domain-member? : relation X -> boolean
;; True when K has an entry in the relation's underlying hash.
(define (relation-domain-member? rel k)
  (define tbl (relation-table rel))
  (hash-has-key? tbl k))
;; relation-member? : relation X Y -> boolean
;; True when K is related to V, i.e. K is in the domain and V is in K's set.
(define (relation-member? r k v)
  (define tbl (relation-table r))
  (cond
    [(hash-has-key? tbl k) (set-member? (hash-ref tbl k) v)]
    [else #f]))
;; relation-domain : relation -> list
;; A list of the relation's domain elements (the keys of its hash).
(define (relation-domain rel)
  (define tbl (relation-table rel))
  (hash-keys tbl))
;; Domain comparator queries: each reports whether the relation's underlying
;; hash compares keys with eq?, eqv? or equal? respectively.
(define (relation-domain-eq? rel)
  (define tbl (relation-table rel))
  (hash-eq? tbl))
(define (relation-domain-eqv? rel)
  (define tbl (relation-table rel))
  (hash-eqv? tbl))
(define (relation-domain-equal? rel)
  (define tbl (relation-table rel))
  (hash-equal? tbl))
;; Codomain comparator queries: each reports whether the relation builds its
;; codomain sets with seteq, seteqv or set (i.e. eq?, eqv? or equal?).
;; Procedures are compared by identity.  For procedures, eqv? and equal?
;; coincide with eq?, so eq? is used uniformly.
(define (relation-codomain-eq? rel)
  (eq? seteq (relation-set-constructor rel)))
(define (relation-codomain-eqv? rel)
  (eq? seteqv (relation-set-constructor rel)))
(define (relation-codomain-equal? rel)
  (eq? set (relation-set-constructor rel)))
;; relation-for-each : relation (X Y -> any) -> void
;; Calls (PROC k v) once for every related pair, for side effect only.
(define (relation-for-each r proc)
  ;; Visits every value in one key's codomain set.
  (define (visit-group key vals)
    (set-for-each vals (lambda (val) (proc key val))))
  (hash-for-each (relation-table r) visit-group))
;; relation-fold : relation A (X Y A -> A) -> A
;; Folds PROC over every related pair.  PROC is called as (PROC k v acc),
;; starting from SEED0, in the order relation-for-each visits pairs.
;; Returns the final accumulator.
(define (relation-fold r seed0 proc)
  (define acc (box seed0))
  (relation-for-each r
                     (lambda (k v)
                       (set-box! acc (proc k v (unbox acc)))))
  (unbox acc))
;; relation-map : relation (X Y -> Z) -> (listof Z)
;; Applies PROC to every related pair and returns the list of results, one
;; per pair.  The order follows hash/set iteration order and is otherwise
;; unspecified.
(define (relation-map r proc)
  ;; Collect each result directly.  Consing onto a local without updating it
  ;; would discard every result.
  (for*/list ([(k vs) (in-hash (relation-table r))]
              [v (in-set vs)])
    (proc k v)))