Implementing probabilistic programs in Haskell
1. Overview
Here we extend our existing data types Term and Constant in order to
implement the DSL described in the last set of notes. We'll also update our
definitions of reorder and normalForm to take the new constructors into
account.
2. New types and constructors
Besides our old types E and T and our arrow and product constructors, we may
add a type R to represent the type of real numbers, as well as a type I to
represent the type of possible worlds (which will end up being useful to us).
In addition, we add a new type constructor P which takes a type onto the type
of probabilistic programs returning values of that type.
data Type = E | T | Unit | I | R | Type :-> Type | Type :/\ Type | P Type
We should also extend our data types Term and Constant to encode the
expressions of our DSL. We'll extend Term with the two monadic constructors
Let and Return (as constructors of Term, they can be given their own
reduction rules),
data Term (γ :: Context) (φ :: Type) where
  Var    :: In φ γ -> Term γ φ                                   -- variables
  Con    :: Constant φ -> Term γ φ                               -- constants
  Lam    :: Term (Cons φ γ) ψ -> Term γ (φ :-> ψ)                -- abstraction
  App    :: Term γ (φ :-> ψ) -> Term γ φ -> Term γ ψ             -- applications
  Pair   :: Term γ φ -> Term γ ψ -> Term γ (φ :/\ ψ)             -- pairing
  Pi1    :: Term γ (φ :/\ ψ) -> Term γ φ                         -- first projection
  Pi2    :: Term γ (φ :/\ ψ) -> Term γ ψ                         -- second projection
  Un     :: Term γ Unit                                          -- unit
  Let    :: Term γ (P φ) -> Term (Cons φ γ) (P ψ) -> Term γ (P ψ) -- monadic bind
  Return :: Term γ φ -> Term γ (P φ)                             -- monadic return
and we'll extend Constant with constructors for the other expressions.
data Constant (φ :: Type) where
  Factor :: Constant (R :-> P Unit)                  -- factor
  ExpVal :: Constant (P α :-> ((α :-> R) :-> R))     -- 𝔼_{(·)} (the expected value operator)
  Indi   :: Constant (T :-> R)                       -- 𝟙 (the indicator function)
  ...
deriving instance Show (Constant φ)
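To make the types of the new constants more concrete, here is a minimal sketch of one possible reading of them. The assumptions are ours, not the notes': a program of type P α is read as a finite list of weighted outcomes, T as Bool, R as Double, and 𝔼 as the unnormalized weighted sum. All names (Dist, returnD, bindD, factorD, indicator, expVal, prob, coin, observed) are hypothetical.

```haskell
-- Hypothetical discrete reading of P a: a finite list of weighted outcomes.
-- This is only one possible interpretation, chosen for illustration.
newtype Dist a = Dist { runDist :: [(a, Double)] }

-- Counterpart of Return: a single outcome with weight 1.
returnD :: a -> Dist a
returnD x = Dist [(x, 1)]

-- Counterpart of Let: run the continuation on every outcome; weights multiply.
bindD :: Dist a -> (a -> Dist b) -> Dist b
bindD (Dist xs) k = Dist [ (y, w * v) | (x, w) <- xs, (y, v) <- runDist (k x) ]

-- factor r : P Unit, read as a single unit outcome carrying weight r.
factorD :: Double -> Dist ()
factorD r = Dist [((), r)]

-- The indicator function T -> R, with T read as Bool and R as Double.
indicator :: Bool -> Double
indicator b = if b then 1 else 0

-- The expected value operator, read here (assumption) as the unnormalized
-- weighted sum of f over the outcomes.
expVal :: Dist a -> (a -> Double) -> Double
expVal (Dist xs) f = sum [ w * f x | (x, w) <- xs ]

-- Normalized probability of a predicate (assumption): the expected value of
-- its indicator divided by the expected value of the constant function 1.
prob :: Dist a -> (a -> Bool) -> Double
prob m p = expVal m (indicator . p) / expVal m (const 1)

-- A fair coin, reweighted by factor so that True is favoured.
coin :: Dist Bool
coin = Dist [(True, 0.5), (False, 0.5)]

observed :: Dist Bool
observed =
  coin `bindD` \b ->
    factorD (if b then 0.9 else 0.1) `bindD` \_ ->
      returnD b
```

Loaded in GHCi, expVal observed indicator comes to about 0.45, and prob observed id, which divides by expVal observed (const 1) = 0.5, to about 0.9: the factor statement has shifted weight toward True.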
3. Updated operational semantics
Given our new Term constructors, we should extend our functions reorder and
normalForm to take them into account. Extending the latter will endow our
monadic constructors Let and Return with their own set of reduction rules and,
thus, a definition of normal form.
reorder :: forall γ δ ψ. (forall φ. In φ γ -> In φ δ) -> Term γ ψ -> Term δ ψ
reorder f (Var i) = Var (f i)
reorder _ (Con c) = Con c
reorder f (Lam t) = Lam (reorder g t)
  where g :: (forall χ. In χ (Cons φ γ) -> In χ (Cons φ δ))
        g First = First
        g (Next i) = Next (f i)
reorder f (App t u) = App (reorder f t) (reorder f u)
reorder f (Pair t u) = Pair (reorder f t) (reorder f u)
reorder f (Pi1 t) = Pi1 (reorder f t)
reorder f (Pi2 t) = Pi2 (reorder f t)
reorder _ Un = Un
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u) = Let (reorder f t) (reorder g u)
  where g :: (forall χ. In χ (Cons φ γ) -> In χ (Cons φ δ))
        g First = First
        g (Next i) = Next (f i)
While the definition of reorder on Return is uninteresting, the branch
handling Let does essentially what the branch handling Lam does: Let binds a
new variable, free in u, so reorder is applied to u with a renaming g that
leaves that variable (First) in place and applies f only to the remaining
variables (under Next).
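The effect of the Let branch can be seen in a simplified, untyped analogue of reorder. This is a sketch under the assumption that variables are plain de Bruijn indices, with 0 playing the role of First and i + 1 that of Next i; the names Tm, under, and reorderU are hypothetical and do not appear in the notes.

```haskell
-- Untyped analogue of Term (sketch): variables are plain de Bruijn indices.
data Tm
  = V Int       -- variable: 0 is the nearest binder (like First)
  | C String    -- constant
  | Lam Tm      -- abstraction: binds index 0 in its body
  | App Tm Tm   -- application
  | Ret Tm      -- monadic return
  | Let Tm Tm   -- monadic bind: binds index 0 in its second argument
  deriving (Eq, Show)

-- Push a renaming under one binder: the newly bound variable (0) is left
-- alone, and every other variable is renamed by f and shifted past the binder.
under :: (Int -> Int) -> Int -> Int
under _ 0 = 0
under f i = f (i - 1) + 1

reorderU :: (Int -> Int) -> Tm -> Tm
reorderU f (V i)     = V (f i)
reorderU _ (C c)     = C c
reorderU f (Lam t)   = Lam (reorderU (under f) t)
reorderU f (App t u) = App (reorderU f t) (reorderU f u)
reorderU f (Ret t)   = Ret (reorderU f t)
-- As with Lam, only the second argument of Let sees a new binder.
reorderU f (Let t u) = Let (reorderU f t) (reorderU (under f) u)
```

For instance, reorderU (+1) (Let (V 0) (App (V 0) (V 1))) yields Let (V 1) (App (V 0) (V 2)): the variable bound by Let is left alone, while the free variables are shifted, just as in the Lam case.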
We can give a new definition of normalForm as follows:
normalForm :: Term γ φ -> Term γ φ
normalForm v@(Var _) = v  -- Variables are already in normal form.
normalForm c@(Con _) = c  -- So are constants.
normalForm (Lam t) = Lam (normalForm t)  -- Abstractions are in normal form just in case their bodies are.
normalForm (App t u) =
  case normalForm t of
    -- If the normal form of t is an abstraction, then we need to substitute and further normalize.
    Lam t' -> normalForm (subst0 (normalForm u) t')
    -- Otherwise, we just need to take the normal form of the argument.
    t' -> App t' (normalForm u)
normalForm (Pair t u) = Pair (normalForm t) (normalForm u)  -- Just normalize the components.
normalForm (Pi1 t) =
  case normalForm t of
    -- If the normal form inside a projection is actually a pair, we take that pair's component.
    Pair u _ -> u
    -- Otherwise, nothing needs to be done.
    t' -> Pi1 t'
normalForm (Pi2 t) =  -- Ditto.
  case normalForm t of
    Pair _ u -> u
    t' -> Pi2 t'
normalForm Un = Un  -- ⋄ is already in normal form.
normalForm (Return t) = Return (normalForm t)  -- Return t is in normal form just in case t is.
normalForm (Let t u) =
  case normalForm t of
    -- Here we apply Left Identity.
    Return t' -> normalForm (subst0 t' (normalForm u))
    -- Here we rebracket (Associativity), potentially leading to another reduction based on Left Identity.
    Let t' u' -> normalForm (Let t' (Let u' (weaken2 (normalForm u))))
    -- Otherwise, no monadic reduction applies; we just normalize the continuation.
    t' -> Let t' (normalForm u)
where weaken2, along with the simpler weaken, is defined as
-- weaken, targeting the first position in the context
weaken :: Term γ φ -> Term (Cons ψ γ) φ
weaken = reorder Next

-- weaken, targeting the second position in the context
weaken2 :: Term (Cons φ γ) ψ -> Term (Cons φ (Cons χ γ)) ψ
weaken2 = reorder g
  where g :: In ψ (Cons φ γ) -> In ψ (Cons φ (Cons χ γ))
        g First = First
        g (Next i) = Next (Next i)
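The renamings that weaken and weaken2 pass to reorder can be observed directly by converting indices to numbers. This sketch assumes that Context and In are defined roughly as below (a reconstruction of the earlier notes, not a copy of them), and the helper names shift1, shift2, and toInt are hypothetical.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

-- Assumed reconstructions of the earlier definitions.
data Type = E | T | Unit | I | R | Type :-> Type | Type :/\ Type | P Type
data Context = Empty | Cons Type Context

-- In a g: a typed de Bruijn index pointing at a variable of type a in context g.
data In (a :: Type) (g :: Context) where
  First :: In a (Cons a g)
  Next  :: In a g -> In a (Cons b g)

-- The renaming behind weaken: every variable moves one position deeper.
shift1 :: In a g -> In a (Cons b g)
shift1 = Next

-- The renaming behind weaken2: First stays put, and every other variable
-- skips over a newly inserted second position.
shift2 :: In a (Cons b g) -> In a (Cons b (Cons c g))
shift2 First    = First
shift2 (Next i) = Next (Next i)

-- Read an index back as a number (0 for First), to observe a renaming.
toInt :: In a g -> Int
toInt First    = 0
toInt (Next i) = 1 + toInt i
```

Read as numbers, shift1 adds one to every index, while shift2 keeps 0 fixed and adds one to every other index. That is what the re-bracketing branch of normalForm needs: the continuation u moves under one extra binder, inserted in the second position.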
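In other words, normalForm contracts every program of the form
Let (Return t) u by Left Identity, and it re-brackets nested programs of the
form Let (Let t u) v to the right (by Associativity), so that further Left
Identity reductions may be applied. Right Identity, by contrast, is not used
as a reduction rule here.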
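The two monadic reductions can be traced in a small untyped sketch that mirrors the Let and Return branches of normalForm, assuming plain de Bruijn indices and leaving out Lam, App, pairs, and projections. All names here (Tm, rename, weaken2U, subst0U, normalFormU) are hypothetical, and subst0U is our own stand-in for subst0, whose definition is not repeated in these notes.

```haskell
-- Untyped monadic fragment (sketch): plain de Bruijn indices.
data Tm
  = V Int      -- variable: 0 is the nearest binder
  | C String   -- opaque constant, e.g. a primitive program
  | Ret Tm     -- monadic return
  | Let Tm Tm  -- monadic bind: binds index 0 in its second argument
  deriving (Eq, Show)

-- Renaming, pushed under Let's binder as in reorder.
rename :: (Int -> Int) -> Tm -> Tm
rename f (V i)     = V (f i)
rename _ (C c)     = C c
rename f (Ret t)   = Ret (rename f t)
rename f (Let t u) = Let (rename f t) (rename g u)
  where g 0 = 0
        g i = f (i - 1) + 1

-- Analogue of weaken2: index 0 stays, all others skip a new second slot.
weaken2U :: Tm -> Tm
weaken2U = rename (\i -> if i == 0 then 0 else i + 1)

-- Analogue of subst0 v t: replace index 0 in t by v, lowering the others.
subst0U :: Tm -> Tm -> Tm
subst0U v = go sigma0
  where
    sigma0 0 = v
    sigma0 i = V (i - 1)
    go s (V i)     = s i
    go _ (C c)     = C c
    go s (Ret t)   = Ret (go s t)
    go s (Let t u) = Let (go s t) (go (lift s) u)
    -- Under a binder, index 0 is untouched; substituted terms are shifted.
    lift _ 0 = V 0
    lift s i = rename (+ 1) (s (i - 1))

normalFormU :: Tm -> Tm
normalFormU (Ret t)   = Ret (normalFormU t)
normalFormU (Let t u) =
  case normalFormU t of
    Ret t'    -> normalFormU (subst0U t' (normalFormU u))                -- Left Identity
    Let t' u' -> normalFormU (Let t' (Let u' (weaken2U (normalFormU u)))) -- Associativity
    t'        -> Let t' (normalFormU u)
normalFormU t = t  -- variables and constants
```

For example, normalFormU (Let (Let (C "m") (Ret (V 0))) (Ret (V 0))) first re-brackets to Let (C "m") (Let (Ret (V 0)) (Ret (V 0))) and then contracts the inner bind by Left Identity, giving Let (C "m") (Ret (V 0)), which is left as is.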