Sunday 1 February 2009

CPS transformation using delimited continuations


;; Idea: Use reified continuations to implement a CPS converter
;;
;; in the expression: (+ (* x x) (* y y))
;; the continuation of (* x x) is (lambda (x-squared) (+ x-squared (* y y)))
;;
;; similarly,
;;
;; in the expression: (list '+ (list '* 'x 'x) (list '* 'y 'y))
;; the continuation of (list '* 'x 'x) is (lambda (x-squared) (list '+ x-squared (list '* 'y 'y)))

(require scheme/control) ;; shift/reset delimited continuations

;; > (reset (list '+ (shift k (list '* 'x 'x)) (list '* 'y 'y)))
;; (* x x)
;; > (reset (list '+ (shift k (k '?)) (list '* 'y 'y)))
;; (+ ? (* y y))
;; > (reset (list '+ (shift k (k (list '* 'x 'x))) (list '* 'y 'y)))
;; (+ (* x x) (* y y))
;; > (reset (list '+ (shift k `(let ((x-squared ,(list '* 'x 'x))) ,(k 'x-squared))) (list '* 'y 'y)))
;; (let ((x-squared (* x x))) (+ x-squared (* y y)))


;; CPS applications should throw the result value into a continuation
;; so (f x y z) turns into (f x y z (lambda (result) (continuation result)))

;; apply# : build the CPS form of a procedure call.
;;
;; F is the (CPS) operator and ARGS are the already-converted operands.
;; The code-building context surrounding the call is captured as k; a fresh
;; variable is generated to stand for the call's result, and the form
;;   (f arg ... (lambda (<fresh>) <context applied to <fresh>>))
;; is returned in place of that context, so the code that was being built
;; around the call ends up inside the continuation lambda.
(define (apply# f . args)
  (shift k
    (let ((result (gensym "g")))
      (reset
        (append (list f)
                args
                (list (list 'lambda (list result) (k result))))))))

;; Some examples of apply# in action:
;;
;; > `(k ,(apply# 'f--> 'x 'y 'z))
;; (f--> x y z (lambda (g349) (k g349)))
;; > (apply# 'f--> (apply# '+--> 2 3) 'y 'z)
;; (+--> 2 3 (lambda (g345) (f--> g345 y z (lambda (g346) g346))))
;; > (apply# 'f--> (apply# '+--> 2 3) (apply# '*--> 'x 'y) 'z)
;; (+--> 2 3 (lambda (g415) (*--> x y (lambda (g416) (f--> g415 g416 z (lambda (g417) g417))))))

;; What about syntax like IF? Clearly apply# would be wrong: it is a procedure, so
;; both branches would be evaluated (and their calls hoisted) before the test is
;; even looked at. So define new syntax!
;; if# : build the CPS form of a conditional.
;;
;; The surrounding code-building context is captured as k and copied into
;; BOTH branches, each under its own reset, so that calls made while
;; building a branch stay inside that branch. The test expression is
;; evaluated first, directly inside the shift body, so any call it makes
;; wraps the whole generated `if`.
(define-syntax if#
  (syntax-rules ()
    ((_ test-expr then-expr else-expr)
     (shift k
       (let* ((test-code test-expr)
              (then-code (reset (k then-expr)))
              (else-code (reset (k else-expr))))
         (list 'if test-code then-code else-code))))))

;; Examples:
;;
;; > (if# (apply# 'zero?--> 'n) ''yes ''no)
;; (zero?--> n (lambda (g374) (if g374 'yes 'no)))
;; > `(display ,(if# (apply# 'zero?--> 'n) ''yes ''no))
;; (zero?--> n (lambda (g850) (if g850 (display 'yes) (display 'no))))

;; define# : build the CPS form of a procedure definition.
;;
;; The parameter list gets one extra trailing parameter, literally named
;; k-->, and the body is built under a reset whose outermost context is
;; "pass the value to k-->".
(define-syntax define#
  (syntax-rules ()
    ((_ (name/args ...) body)
     (list 'define
           '(name/args ... k-->)
           (reset (list 'k--> body))))))

;; That's enough now to CPS convert entire procedures:

;; > (define# (fact-iter--> n acc)
;; (if# (apply# 'zero?--> 'n)
;; 'acc
;; (apply# 'fact-iter--> (apply# '---> 'n 1) (apply# '*--> 'acc 'n))))
;; Result of the define# form above, pretty-printed.
;; Every intermediate value is named by a generated variable and handed to
;; an explicit continuation; k--> receives the final result.
(define (fact-iter--> n acc k-->)
  (zero?--> n
    (lambda (g875)
      (if g875
          (k--> acc)
          (---> n 1
            (lambda (g876)
              (*--> acc n
                (lambda (g877)
                  (fact-iter--> g876 g877
                    (lambda (g878) (k--> g878)))))))))))

;; Test it! (this CPS format is a subset of Scheme)
;; CPS zero test: passes (zero? n) to the continuation k.
(define zero?-->
  (lambda (n k)
    (k (zero? n))))
;; CPS multiplication: passes (* x y) to the continuation k.
(define *-->
  (lambda (x y k)
    (k (* x y))))
;; CPS subtraction: passes (- x y) to the continuation k.
(define --->
  (lambda (x y k)
    (k (- x y))))
;; > (fact-iter--> 7 1 display)
;; 5040


;; Now a function to CPS convert based on all this is trivial, it's just a fold
;; that replaces if with if#, define with define# and applications with apply#!

;; Procedure front-end to the if# syntax, for callers (such as cps) that
;; need to pass it around as a value. The branches arrive as thunks so that
;; they are only run in the positions where if# places them.
(define if#-thunked
  (lambda (cond then-thunk else-thunk)
    (if# cond
         (then-thunk)
         (else-thunk))))
;; Procedure counterpart of define#.
;;
;; NAME/ARGS is the list (name arg ...) from the source definition and
;; BODY-THUNK, when called, builds the converted body. A freshly generated
;; continuation parameter is appended to the parameter list, and the body
;; is built under a reset whose outermost context passes the body's value
;; to that parameter.
(define (define#-thunked name/args body-thunk)
  (let* ((k--> (gensym "k-->"))
         (header (append name/args (list k-->)))
         (body (reset (list k--> (body-thunk)))))
    (list 'define header body)))
;; cps : s-expression -> s-expression
;;
;; Converts TERM to continuation-passing style by walking it and handing
;; each construct to the matching code builder:
;;   (quote d)         -> emitted unchanged
;;   (if c t e)        -> if#-thunked
;;   (if c t)          -> as above, with (void) as the missing alternative
;;   (define (f . a) b)-> define#-thunked (only the first body form is used)
;;   (op arg ...)      -> apply# on the converted operator and operands
;;   anything else     -> returned as-is (variables, literals)
;; Operator names are not rewritten, and the continuation parameter added
;; by define#-thunked is a generated symbol.
(define (cps term)
  (if (pair? term)
      (case (car term)
        ;; TERM is already the form (quote d); emitting it as-is makes the
        ;; generated code evaluate to d. Wrapping it in a second quote would
        ;; make the generated code evaluate to the list (quote d) instead.
        ((quote) term)
        ((if)
         (if#-thunked (cps (cadr term))
                      (lambda () (cps (caddr term)))
                      (lambda ()
                        ;; A one-armed `if` has no fourth element; supply
                        ;; (void) so the continuation still receives a value
                        ;; on that path instead of crashing in cadddr.
                        (if (pair? (cdddr term))
                            (cps (cadddr term))
                            '(void)))))
        ((define)
         (define#-thunked (cadr term)
                          (lambda () (cps (caddr term)))))
        (else
         (apply apply# (map cps term))))
      term))

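;; > (cps '(define (fact-iter-2 n acc) (if (zero? n) acc (fact-iter-2 (- n 1) (* n acc)))))
;;
;; Note: cps as defined above does not rename operators, and define#-thunked
;; generates the continuation parameter with gensym, so the raw result mentions
;; zero?, -, *, fact-iter-2 and a k-->NNN variable. The listing below is that
;; result with the names rewritten to their --> counterparts (and the generated
;; parameter shortened to k-->) so that it runs against the CPS primitives
;; defined earlier.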
;; CPS form of fact-iter-2, pretty-printed.
;; Each intermediate value is bound by a generated lambda parameter;
;; k--> receives the final result.
(define (fact-iter-2--> n acc k-->)
  (zero?--> n
    (lambda (g443)
      (if g443
          (k--> acc)
          (---> n 1
            (lambda (g444)
              (*--> n acc
                (lambda (g445)
                  (fact-iter-2--> g444 g445
                    (lambda (g446)
                      (k--> g446)))))))))))
;; > (fact-iter-2--> 8 1 display)
;; 40320

No comments: