diff --git a/racket/typed/roles.rkt b/racket/typed/roles.rkt index 75e08b6..42cfb5d 100644 --- a/racket/typed/roles.rkt +++ b/racket/typed/roles.rkt @@ -65,6 +65,7 @@ (require (for-meta 2 macrotypes/stx-utils racket/list syntax/stx syntax/parse racket/base)) (require (for-syntax turnstile/examples/util/filter-maximal)) +(require (for-syntax macrotypes/type-constraints macrotypes/variance-constraints)) (require (for-syntax racket/struct-info)) (require macrotypes/postfix-in) (require (postfix-in - racket/list)) @@ -357,21 +358,195 @@ ------------------------ [⊢ (x-) (⇒ : τ)]) -(define-typed-syntax (typed-app e_fn e_arg ...) ≫ - [⊢ e_fn ≫ e_fn- (⇒ : (~→ τ_in ... (~Computation (~Value τ-out) - (~Endpoints τ-ep ...) - (~Roles τ-f ...) - (~Spawns τ-s ...))))] - #:fail-unless (pure? #'e_fn-) "expression not allowed to have effects" - #:fail-unless (stx-length=? #'[τ_in ...] #'[e_arg ...]) - (num-args-fail-msg #'e_fn #'[τ_in ...] #'[e_arg ...]) - [⊢ e_arg ≫ e_arg- (⇐ : τ_in)] ... - #:fail-unless (stx-andmap pure? #'(e_arg- ...)) "expressions not allowed to have effects" - ------------------------------------------------------------------------ - [⊢ (#%app- e_fn- e_arg- ...) (⇒ : τ-out) - (⇒ ν-ep (τ-ep ...)) - (⇒ ν-s (τ-s ...)) - (⇒ ν-f (τ-f ...))]) +(define-typed-syntax typed-app + ;; Polymorphic, Pure Function - Perform Local Inference + [(_ e_fn e_arg ...) ≫ + ;; compute fn type (ie ∀ and →) + [⊢ e_fn ≫ e_fn- ⇒ (~∀ Xs (~→fn tyX_in ... tyX_out))] + ;; successfully matched a polymorphic fn type, don't backtrack + #:cut + #:with tyX_args #'(tyX_in ... tyX_out) + ;; solve for type variables Xs + #:with [[e_arg- ...] Xs* cs] (solve #'Xs #'tyX_args this-syntax) + ;; instantiate polymorphic function type + #:with [τ_in ... τ_out] (inst-types/cs #'Xs* #'cs #'tyX_args) + #:with (unsolved-X ...) (find-free-Xs #'Xs* #'τ_out) + ;; arity check + #:fail-unless (stx-length=? #'[τ_in ...] #'[e_arg ...]) + (num-args-fail-msg #'e_fn #'[τ_in ...] #'[e_arg ...]) + ;; purity check + #:fail-unless (all-pure? #'(e_fn- e_arg- ...)) "expressions must be pure" + ;; compute argument types + #:with (τ_arg ...) (stx-map typeof #'(e_arg- ...)) + ;; typecheck args + [τ_arg τ⊑ τ_in #:for e_arg] ... + #:with τ_out* (if (stx-null? #'(unsolved-X ...)) + #'τ_out + (syntax-parse #'τ_out + [(~?∀ (Y ...) τ_out) + #:fail-unless (→? #'τ_out) + (mk-app-poly-infer-error this-syntax #'(τ_in ...) #'(τ_arg ...) #'e_fn) + (for ([X (in-list (syntax->list #'(unsolved-X ...)))]) + (unless (covariant-X? X #'τ_out) + (raise-syntax-error + #f + (mk-app-poly-infer-error this-syntax #'(τ_in ...) #'(τ_arg ...) #'e_fn) + this-syntax))) + (mk-∀- #'(unsolved-X ... Y ...) #'(τ_out))])) + -------- + [⊢ (#%plain-app- e_fn- e_arg- ...) ⇒ τ_out*]] + ;; All Other Functions + [(_ e_fn e_arg ...) ≫ + [⊢ e_fn ≫ e_fn- (⇒ : (~→ τ_in ... (~Computation (~Value τ-out) + (~Endpoints τ-ep ...) + (~Roles τ-f ...) + (~Spawns τ-s ...))))] + #:fail-unless (pure? #'e_fn-) "expression not allowed to have effects" + #:fail-unless (stx-length=? #'[τ_in ...] #'[e_arg ...]) + (num-args-fail-msg #'e_fn #'[τ_in ...] #'[e_arg ...]) + [⊢ e_arg ≫ e_arg- (⇐ : τ_in)] ... + #:fail-unless (stx-andmap pure? #'(e_arg- ...)) "expressions not allowed to have effects" + ------------------------------------------------------------------------ + [⊢ (#%app- e_fn- e_arg- ...) (⇒ : τ-out) + (⇒ ν-ep (τ-ep ...)) + (⇒ ν-s (τ-s ...)) + (⇒ ν-f (τ-f ...))]]) + +(begin-for-syntax + ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id) + ;; finds the free Xs in the type + (define (find-free-Xs Xs ty) + (for/list ([X (in-stx-list Xs)] + #:when (stx-contains-id? ty X)) + X)) + + ;; Type -> Bool + ;; checks if the type contains any unions + (define (contains-union? ty) + (syntax-parse ty + [(~U* _ ...) + #t] + [(~Base _) #f] + [X:id #f] + [(~Any/bvs _ _ τ ...) + (stx-ormap contains-union? #'(τ ...))] + [_ + (type-error #:src (get-orig ty) + #:msg "contains-union?: unrecognized-type: ~a" + ty)])) + + ;; solve for Xs by unifying quantified fn type with the concrete types of stx's args + ;; stx = the application stx = (#%app e_fn e_arg ...) + ;; tyXs = input and output types from fn type + ;; ie (typeof e_fn) = (-> . tyXs) + ;; It infers the types of arguments from left-to-right, + ;; and it expands and returns all of the arguments. + ;; It returns list of 3 values if successful, else throws a type error + ;; - a list of all the arguments, expanded + ;; - a list of all the type variables + ;; - the constraints for substituting the types + (define (solve Xs tyXs stx) + (syntax-parse tyXs + [(τ_inX ... τ_outX) + ;; generate initial constraints with expected type and τ_outX + #:with (~?∀ Vs expected-ty) + (and (get-expected-type stx) + ((current-type-eval) (get-expected-type stx))) + (define initial-cs + (if (and (syntax-e #'expected-ty) (stx-null? #'Vs)) + (add-constraints Xs '() (list (list #'expected-ty #'τ_outX))) + '())) + (syntax-parse stx + [(_ e_fn . args) + (define-values (as- cs) + (for/fold ([as- null] [cs initial-cs]) + ([a (in-stx-list #'args)] + [tyXin (in-stx-list #'(τ_inX ...))]) + (define ty_in (inst-type/cs/orig Xs cs tyXin datum=?)) + (when (contains-union? ty_in) + (type-error #:src a + #:msg (format "can't infer types with unions: ~a\nraw: ~a" + (type->str ty_in) ty_in))) + (define/with-syntax [a- ty_a] + (infer+erase (if (null? (find-free-Xs Xs ty_in)) + (add-expected-type a ty_in) + a))) + (when (contains-union? #'ty_a) + (type-error #:src a + #:msg (format "can't infer types with unions: ~a\nraw: ~a" + (type->str #'ty_a) #'ty_a))) + (values + (cons #'a- as-) + (add-constraints Xs cs (list (list ty_in #'ty_a)) + (list (list (inst-type/cs/orig + Xs cs ty_in + datum=?) + #'ty_a)))))) + + (list (reverse as-) Xs cs)])])) + + (define (mk-app-poly-infer-error stx expected-tys given-tys e_fn) + (format (string-append + "Could not infer instantiation of polymorphic function ~s.\n" + " expected: ~a\n" + " given: ~a") + (syntax->datum (get-orig e_fn)) + (string-join (stx-map type->str expected-tys) ", ") + (string-join (stx-map type->str given-tys) ", "))) + + ;; covariant-Xs? : Type -> Bool + ;; Takes a possibly polymorphic type, and returns true if all of the + ;; type variables are in covariant positions within the type, false + ;; otherwise. + (define (covariant-Xs? ty) + (syntax-parse ((current-type-eval) ty) + [(~?∀ Xs ty) + (for/and ([X (in-stx-list #'Xs)]) + (covariant-X? X #'ty))])) + + ;; find-X-variance : Id Type [Variance] -> Variance + ;; Returns the variance of X within the type ty + (define (find-X-variance X ty [ctxt-variance covariant]) + (car (find-variances (list X) ty ctxt-variance))) + + ;; covariant-X? : Id Type -> Bool + ;; Returns true if every place X appears in ty is a covariant position, false otherwise. + (define (covariant-X? X ty) + (variance-covariant? (find-X-variance X ty covariant))) + + ;; contravariant-X? : Id Type -> Bool + ;; Returns true if every place X appears in ty is a contravariant position, false otherwise. + (define (contravariant-X? X ty) + (variance-contravariant? (find-X-variance X ty covariant))) + + ;; find-variances : (Listof Id) Type [Variance] -> (Listof Variance) + ;; Returns the variances of each of the Xs within the type ty, + ;; where it's already within a context represented by ctxt-variance. + (define (find-variances Xs ty [ctxt-variance covariant]) + (syntax-parse ty + [A:id + (for/list ([X (in-list Xs)]) + (cond [(free-identifier=? X #'A) ctxt-variance] + [else irrelevant]))] + [(~Any tycons) + (stx-map (λ _ irrelevant) Xs)] + [(~?∀ () (~Any tycons τ ...)) + #:when (get-arg-variances #'tycons) + #:when (stx-length=? #'[τ ...] (get-arg-variances #'tycons)) + (define τ-ctxt-variances + (for/list ([arg-variance (in-list (get-arg-variances #'tycons))]) + (variance-compose ctxt-variance arg-variance))) + (for/fold ([acc (stx-map (λ _ irrelevant) Xs)]) + ([τ (in-stx-list #'[τ ...])] + [τ-ctxt-variance (in-list τ-ctxt-variances)]) + (map variance-join + acc + (find-variances Xs τ τ-ctxt-variance)))] + [ty + #:when (not (for/or ([X (in-list Xs)]) + (stx-contains-id? #'ty X))) + (stx-map (λ _ irrelevant) Xs)] + [_ (stx-map (λ _ invariant) Xs)]))) (define-typed-syntax (tuple e:expr ...) ≫ [⊢ e ≫ e- (⇒ : τ)] ... diff --git a/racket/typed/tests/inference.rkt b/racket/typed/tests/inference.rkt new file mode 100644 index 0000000..3148edc --- /dev/null +++ b/racket/typed/tests/inference.rkt @@ -0,0 +1,29 @@ +#lang typed/syndicate/roles + +(require rackunit/turnstile) + +(define (∀ (X) (poly-cons [x : X] + [xs : (List X)] + -> (List X))) + (cons x xs)) + +(define int-list : (List Int) (list 1 2 3)) + +(check-type (poly-cons 0 int-list) + : (List Int) + -> (list 0 1 2 3)) + +(define string-list : (List String) (list "group" "of" "helpful" "badgers")) + +(check-type (poly-cons "a" string-list) + : (List String) + -> (list "a" "group" "of" "helpful" "badgers")) + +(typecheck-fail (poly-cons "hello" int-list)) + +(define string-int-list : (List (U String Int)) + (list "hi" 42 "badgers")) + +;; shouldn't mess about with unions +(typecheck-fail (poly-cons "go" string-int-list)) +