diff --git a/prospect/comprehensions.rkt b/prospect/comprehensions.rkt index 86a4b75..6d11962 100644 --- a/prospect/comprehensions.rkt +++ b/prospect/comprehensions.rkt @@ -13,14 +13,13 @@ (for-syntax racket/match)) (begin-for-syntax - ; Pattern-Syntax Syntax -> (SyntaxOf TempVar TempVar Projection-Pattern Match-Pattern) + ; Pattern-Syntax Syntax -> + ; (SyntaxOf TempVar TempVar Projection-Pattern Match-Pattern) (define (helper pat-stx outer-stx) (match-define (list temp1 temp2) (generate-temporaries #'(tmp1 tmp2))) (define-values (proj-stx pat match-pat bindings) (analyze-pattern outer-stx pat-stx)) - (datum->syntax - outer-stx - (list temp1 temp2 pat match-pat)))) + (list temp1 temp2 pat match-pat))) ;; trie projection symbol -> (U set exn:fail?) ;; tries to project the trie. If the resulting trie would be infinite, raise an @@ -32,61 +31,55 @@ (error "pattern projection created infinite trie:" pat)) s?) +(begin-for-syntax + (define (build-fold stx ctx) + (syntax-case stx () + [(_ ([acc-id acc-init] ...) + () + body ...) + #'(begin body ...)] + [(_ ([acc-id acc-init] ...) + ((pat_0 trie_0) + clauses ...) + body ...) + (begin + (match-define (list set-tmp loop-tmp proj-stx match-pat) + (helper #'pat_0 ctx)) + (with-syntax ([new-acc (generate-temporary 'acc)]) + #`(let ([#,set-tmp (project-finite trie_0 #,proj-stx 'pat_0)]) + (for/fold/derived #,ctx ([acc-id acc-init] + ...) + ([loop-tmp (in-set #,set-tmp)]) + (match loop-tmp + [(list #,match-pat) + #,(build-fold + #`(_ ([acc-id acc-id] + ...) + (clauses ...) + body ...) + ctx)] + [_ (values acc-id ...)])))))] + [(_ ([acc-id acc-init] ...) + (#:where pred clauses ...) + body ...) + #`(if pred + #,(build-fold #'(_ ([acc-id acc-init] ...) (clauses ...) body ...) + ctx) + (values acc-id ...))]))) + (define-syntax (for-trie/fold stx) - (syntax-case stx () - [(_ ([acc-id acc-init] ...) - ((pat_0 trie_0) - (pat_n trie_n) ... - #:where pred) - body) - (with-syntax* ([(set-tmp loop-tmp proj-stx match-pat) - (helper #'pat_0 #'body)] - [new-acc (generate-temporary 'acc)]) - #`(let ([set-tmp (project-finite trie_0 proj-stx 'pat_0)]) - (for/fold/derived #,stx ([acc-id acc-init] - ...) - ([loop-tmp (in-set set-tmp)]) - (match loop-tmp - [(list match-pat) - (for-trie/fold ([acc-id acc-id] - ...) - ([pat_n trie_n] - ... - #:where pred) - body)] - [_ (values acc-id ...)]))))] - [(_ ([acc-id acc-init] ...) - (#:where pred) - body) - #'(if pred body (values acc-id ...))] - [(_ ([acc-id acc-init] ...) ([pat exp] ...) body) - #'(for-trie/fold ([acc-id acc-init] ...) ([pat exp] ... #:where #t) body)] - [(_ (accs ...) (clauses ...) body_0 body_1 body_n ...) - (with-syntax [(new-body (replace-context #'body_0 - #'(begin body_0 body_1 body_n ...)))] - #'(for-trie/fold (accs ...) (clauses ...) new-body))])) + (build-fold stx stx)) (define-syntax (make-fold stx) (syntax-case stx () [(_ name folder initial) #'(define-syntax (name stx) (syntax-case stx () - [(_ ([pat expr] (... ...) #:where pred) body) - (with-syntax* ([acc (replace-context #'body (generate-temporary 'acc))] - [new-body #'(folder body acc)] - [new-body (replace-context #'body #'new-body)]) - #'(for-trie/fold ([acc initial]) - ([pat expr] - (... ...) - #:where pred) - new-body))] - [(_ ([pat exp] (... ...)) body) - #'(name ([pat exp] (... ...) #:where #t) body)] - [(_ (clauses (... ...)) body_0 body_1 body_n (... ...)) - (with-syntax [(new-body (replace-context #'body_0 - #'(begin body_0 body_1 body_n (... ...))))] - #'(name (clauses (... ...)) new-body))]))])) - + [(_ (clauses (... ...)) body (... ...)) + (with-syntax ([loop #'(for-trie/fold ([acc initial]) + (clauses (... ...)) + (folder (begin body (... ...)) acc))]) + (build-fold #'loop stx))]))])) (make-fold for-trie/list cons empty) @@ -103,8 +96,11 @@ (define-syntax (for-trie stx) (syntax-case stx () - [(_ more-stx ...) - #'(void (for-trie-inner more-stx ...))])) + [(_ (clauses ...) body ...) + (with-syntax ([loop #'(for-trie/fold ([acc (void)]) + (clauses ...) + (begin body ... acc))]) + (build-fold #'loop stx))])) (module+ test (require rackunit) @@ -125,7 +121,12 @@ #:where (even? x)) (+ x 1)) (set 3 5)) - (check-equal? (for-trie/set ([(cons $x _) (make-trie 1 2 (list 0) (list 1 2 3) (cons 'x 'y) (cons 3 4) (cons 'a 'b) "x" 'foo)]) + (check-equal? (for-trie/set ([(cons $x _) (make-trie 1 2 (list 0) + (list 1 2 3) + (cons 'x 'y) + (cons 3 4) + (cons 'a 'b) + "x" 'foo)]) x) (set 'x 3 'a)) (check-equal? (for-trie/fold ([acc 0]) @@ -193,7 +194,7 @@ x))) ;; projecting something finite out is ok (check-equal? (for-trie/list ([1 (pattern->trie 'x (projection->pattern ?))]) - 1) + 1) (list 1)) (let ([a-set (mutable-set)]) ;; for-trie results in (void) @@ -201,5 +202,4 @@ (set-add! a-set x)) (void)) ;; for-trie runs body for effects - (check-equal? a-set (mutable-set 1 2 3 4)))) - + (check-equal? a-set (mutable-set 1 2 3 4)))) \ No newline at end of file