Implementing probabilistic programs in Haskell
1. Overview
Here we extend our existing data types Term and Constant in order to implement
the DSL described in the last set of notes. We'll also update our definitions
of reorder and normalForm to take into account the new constructors.
2. New types and constructors
Besides our old types E and T and our arrow and product constructors, we add a
type R to represent the type of real numbers, as well as a type I to represent
the type of possible worlds (which will end up being useful to us). In
addition, we add a new type constructor P, which maps a type to the type of
probabilistic programs returning values of that type.
-- Unit is listed because Un and Factor (below) mention it.
data Type = E | T | I | R | Unit | Type :-> Type | Type :/\ Type | P Type
We should also extend our data types Term and Constant to encode the
expressions of our DSL. We'll extend Term by adding the two monadic
constructors (for which we can therefore define reduction rules),
data Term (γ :: Context) (φ :: Type) where
  Var :: In φ γ -> Term γ φ                                        -- variables
  Con :: Constant φ -> Term γ φ                                    -- constants
  Lam :: Term (Cons φ γ) ψ -> Term γ (φ :-> ψ)                     -- abstraction
  App :: Term γ (φ :-> ψ) -> Term γ φ -> Term γ ψ                  -- applications
  Pair :: Term γ φ -> Term γ ψ -> Term γ (φ :/\ ψ)                 -- pairing
  Pi1 :: Term γ (φ :/\ ψ) -> Term γ φ                              -- first projection
  Pi2 :: Term γ (φ :/\ ψ) -> Term γ ψ                              -- second projection
  Un :: Term γ Unit                                                -- unit
  Let :: Term γ (P φ) -> Term (Cons φ γ) (P ψ) -> Term γ (P ψ)     -- monadic bind
  Return :: Term γ φ -> Term γ (P φ)                               -- monadic return
and we'll extend Constant with constructors for the other expressions.
data Constant (φ :: Type) where
  Factor :: Constant (R :-> P Unit)              -- factor
  ExpVal :: Constant (P α :-> ((α :-> R) :-> R)) -- 𝔼_{(·)} (the expected value operator)
  Indi :: Constant (T :-> R)                     -- 𝟙 (the indicator function)
  ...
deriving instance Show (Constant φ)
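To see how the new constructors fit together, here is a minimal sketch that builds two terms of the extended DSL. It makes several assumptions: Context and In are taken to be defined as in the earlier notes (they are not shown in this section), Unit is included in Type because Un and Factor mention it, Constant contains only the three constructors listed above (the elided ones are left out), and the names observe and conditioned are hypothetical, chosen only for illustration.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, StandaloneDeriving,
             TypeOperators #-}

-- Assumption: Unit is part of Type, since Un and Factor refer to it.
data Type = E | T | I | R | Unit | Type :-> Type | Type :/\ Type | P Type

-- Assumption: contexts and variable indices as in the earlier notes.
data Context = Empty | Cons Type Context

data In (φ :: Type) (γ :: Context) where
  First :: In φ (Cons φ γ)
  Next  :: In φ γ -> In φ (Cons ψ γ)
deriving instance Show (In φ γ)

-- Only the three constants shown in the notes.
data Constant (φ :: Type) where
  Factor :: Constant (R :-> P Unit)
  ExpVal :: Constant (P α :-> ((α :-> R) :-> R))
  Indi   :: Constant (T :-> R)
deriving instance Show (Constant φ)

data Term (γ :: Context) (φ :: Type) where
  Var    :: In φ γ -> Term γ φ
  Con    :: Constant φ -> Term γ φ
  Lam    :: Term (Cons φ γ) ψ -> Term γ (φ :-> ψ)
  App    :: Term γ (φ :-> ψ) -> Term γ φ -> Term γ ψ
  Pair   :: Term γ φ -> Term γ ψ -> Term γ (φ :/\ ψ)
  Pi1    :: Term γ (φ :/\ ψ) -> Term γ φ
  Pi2    :: Term γ (φ :/\ ψ) -> Term γ ψ
  Un     :: Term γ Unit
  Let    :: Term γ (P φ) -> Term (Cons φ γ) (P ψ) -> Term γ (P ψ)
  Return :: Term γ φ -> Term γ (P φ)
deriving instance Show (Term γ φ)

-- Hypothetical helper: weight a program by the indicator of a truth value,
-- i.e. λφ. factor (𝟙 φ).
observe :: Term γ (T :-> P Unit)
observe = Lam (App (Con Factor) (App (Con Indi) (Var First)))

-- Hypothetical example in a context with a program m : P E (First) and a
-- predicate p : E :-> T (Next First). It encodes
--   m >>= \x -> observe (p x) >>= \_ -> return x
conditioned :: Term (Cons (P E) (Cons (E :-> T) Empty)) (P E)
conditioned =
  Let (Var First)                                    -- bind x from m
      (Let (App observe
                (App (Var (Next (Next First)))       -- p, two binders up
                     (Var First)))                   -- x
           (Return (Var (Next First))))              -- x, under the Unit binder
```

Note how each Let pushes a new variable onto the context, so references to variables bound further out gain one Next per enclosing binder.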
3. Updated operational semantics
Given our new Term constructors, we should extend our functions reorder and
normalForm to take them into account. Extending the latter will endow our
monadic constructors Let and Return with their own set of reduction rules and,
thus, a definition of normal form.
reorder :: forall γ δ ψ. (forall φ. In φ γ -> In φ δ) -> Term γ ψ -> Term δ ψ
reorder f (Var i) = Var (f i)
reorder _ (Con c) = Con c
reorder f (Lam t) = Lam (reorder g t)
  where g :: (forall χ. In χ (Cons φ γ) -> In χ (Cons φ δ))
        g First = First
        g (Next i) = Next (f i)
reorder f (App t u) = App (reorder f t) (reorder f u)
reorder f (Pair t u) = Pair (reorder f t) (reorder f u)
reorder f (Pi1 t) = Pi1 (reorder f t)
reorder f (Pi2 t) = Pi2 (reorder f t)
reorder _ Un = Un
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u) = Let (reorder f t) (reorder g u)
  where g :: (forall χ. In χ (Cons φ γ) -> In χ (Cons φ δ))
        g First = First
        g (Next i) = Next (f i)
While the definition of reorder on Return is pretty uninteresting, the branch
handling Let does more or less the same thing as the branch handling Lam: it
takes into account the new bound variable, free in u, by ensuring that it is
not affected by the behavior of reorder.
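The effect on bound variables can be checked on a small case. The following is a sketch over a cut-down fragment of the DSL (only Var, Lam, Let, and Return, with a reduced Type), assuming Context and In are defined as in the earlier notes; the example terms lamEx and letEx are hypothetical. Weakening with reorder Next shifts free variables by one position while leaving the variable bound by Lam or Let at First.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes,
             ScopedTypeVariables, StandaloneDeriving, TypeOperators #-}

-- Cut-down fragment, for illustration only.
data Type = E | Type :-> Type | P Type
data Context = Empty | Cons Type Context

data In (φ :: Type) (γ :: Context) where
  First :: In φ (Cons φ γ)
  Next  :: In φ γ -> In φ (Cons ψ γ)
deriving instance Show (In φ γ)

data Term (γ :: Context) (φ :: Type) where
  Var    :: In φ γ -> Term γ φ
  Lam    :: Term (Cons φ γ) ψ -> Term γ (φ :-> ψ)
  Let    :: Term γ (P φ) -> Term (Cons φ γ) (P ψ) -> Term γ (P ψ)
  Return :: Term γ φ -> Term γ (P φ)
deriving instance Show (Term γ φ)

reorder :: forall γ δ ψ. (forall φ. In φ γ -> In φ δ) -> Term γ ψ -> Term δ ψ
reorder f (Var i)    = Var (f i)
reorder f (Lam t)    = Lam (reorder g t)
  where g :: forall χ φ. In χ (Cons φ γ) -> In χ (Cons φ δ)
        g First    = First          -- the bound variable stays put
        g (Next i) = Next (f i)     -- free variables are renamed by f
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u)  = Let (reorder f t) (reorder g u)
  where g :: forall χ φ. In χ (Cons φ γ) -> In χ (Cons φ δ)
        g First    = First          -- same treatment as under Lam
        g (Next i) = Next (f i)

weaken :: Term γ φ -> Term (Cons ψ γ) φ
weaken = reorder Next

-- λy. x, where x is the only free variable.
lamEx :: Term (Cons E Empty) (E :-> E)
lamEx = Lam (Var (Next First))

-- m >>= \x -> return x, where m is the only free variable.
letEx :: Term (Cons (P E) Empty) (P E)
letEx = Let (Var First) (Return (Var First))

-- After weakening by one extra variable of type E:
--   weaken lamEx  shows as  Lam (Var (Next (Next First)))
--   weaken letEx  shows as  Let (Var (Next First)) (Return (Var First))
weakLam :: Term (Cons E (Cons E Empty)) (E :-> E)
weakLam = weaken lamEx

weakLet :: Term (Cons E (Cons (P E) Empty)) (P E)
weakLet = weaken letEx
```

In weakLet, the free occurrence of m moves from First to Next First, whereas the occurrence of First under Return, which refers to the variable bound by Let, is unchanged.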
We can give a new definition of normalForm as follows:
normalForm :: Term γ φ -> Term γ φ
normalForm v@(Var _) = v -- Variables are already in normal form.
normalForm c@(Con _) = c -- So are constants.
-- Abstractions are in normal form just in case their bodies are in normal form.
normalForm (Lam t) = Lam (normalForm t)
normalForm (App t u) =
  case normalForm t of
    -- If the normal form of t is an abstraction, then we need to substitute
    -- and further normalize.
    Lam t' -> normalForm (subst0 (normalForm u) t')
    -- Otherwise, we just need to take the normal form of the argument.
    t' -> App t' (normalForm u)
normalForm (Pair t u) = Pair (normalForm t) (normalForm u) -- Just normalize the components.
normalForm (Pi1 t) =
  case normalForm t of
    -- If the normal form inside a projection is actually a pair, we should
    -- take that pair's projection.
    Pair u _ -> u
    t' -> Pi1 t' -- Otherwise, nothing needs to be done.
normalForm (Pi2 t) = -- Ditto.
  case normalForm t of
    Pair _ u -> u
    t' -> Pi2 t'
normalForm Un = Un -- ⋄ is already in normal form.
-- Returning something doesn't change whether or not it is in normal form.
normalForm (Return t) = Return (normalForm t)
normalForm (Let t u) =
  case normalForm t of
    -- Here we apply Left Identity.
    Return t' -> normalForm (subst0 t' (normalForm u))
    -- Here we rebracket, potentially leading to another reduction based on
    -- Left Identity.
    Let t' u' -> normalForm (Let t' (Let u' (weaken2 (normalForm u))))
    -- Here we don't do anything.
    t' -> Let t' (normalForm u)
where weaken and weaken2 are defined as
-- weaken, targeting the first position in the context
weaken :: Term γ φ -> Term (Cons ψ γ) φ
weaken = reorder Next

-- weaken, targeting the second position in the context
weaken2 :: Term (Cons φ γ) ψ -> Term (Cons φ (Cons χ γ)) ψ
weaken2 = reorder g
  where g :: In ψ (Cons φ γ) -> In ψ (Cons φ (Cons χ γ))
        g First = First
        g (Next i) = Next (Next i)
As can be seen, normalForm reduces a term built with Let and Return wherever
Left Identity allows, and it re-brackets nested programs (following
Associativity) so that further such reductions may become applicable.
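The two monadic rules can be traced on small inputs. The sketch below restricts the DSL to Var, Let, and Return (with a reduced Type) so that it stays short. It assumes Context and In as in the earlier notes, and it supplies its own subst and subst0, since subst0 is used above but not defined in this section; the definitions here are a guess at a standard simultaneous substitution, not necessarily the one from the earlier notes. The terms ex1 and ex2 are hypothetical.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes,
             ScopedTypeVariables, StandaloneDeriving #-}

-- Cut-down fragment, for illustration only.
data Type = E | P Type
data Context = Empty | Cons Type Context

data In (φ :: Type) (γ :: Context) where
  First :: In φ (Cons φ γ)
  Next  :: In φ γ -> In φ (Cons ψ γ)
deriving instance Show (In φ γ)

data Term (γ :: Context) (φ :: Type) where
  Var    :: In φ γ -> Term γ φ
  Let    :: Term γ (P φ) -> Term (Cons φ γ) (P ψ) -> Term γ (P ψ)
  Return :: Term γ φ -> Term γ (P φ)
deriving instance Show (Term γ φ)

reorder :: forall γ δ ψ. (forall φ. In φ γ -> In φ δ) -> Term γ ψ -> Term δ ψ
reorder f (Var i)    = Var (f i)
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u)  = Let (reorder f t) (reorder g u)
  where g :: forall χ φ. In χ (Cons φ γ) -> In χ (Cons φ δ)
        g First    = First
        g (Next i) = Next (f i)

-- Assumed definition: simultaneous substitution for all free variables.
subst :: forall γ δ ψ. (forall φ. In φ γ -> Term δ φ) -> Term γ ψ -> Term δ ψ
subst f (Var i)    = f i
subst f (Return t) = Return (subst f t)
subst f (Let t u)  = Let (subst f t) (subst g u)
  where g :: forall χ φ. In χ (Cons φ γ) -> Term (Cons φ δ) χ
        g First    = Var First
        g (Next i) = reorder Next (f i)   -- weaken under the binder

-- Assumed definition: substitute for the first variable in the context.
subst0 :: forall γ φ ψ. Term γ φ -> Term (Cons φ γ) ψ -> Term γ ψ
subst0 t = subst f
  where f :: forall χ. In χ (Cons φ γ) -> Term γ χ
        f First    = t
        f (Next i) = Var i

weaken2 :: Term (Cons φ γ) ψ -> Term (Cons φ (Cons χ γ)) ψ
weaken2 = reorder g
  where g :: In ψ (Cons φ γ) -> In ψ (Cons φ (Cons χ γ))
        g First    = First
        g (Next i) = Next (Next i)

normalForm :: Term γ φ -> Term γ φ
normalForm v@(Var _)  = v
normalForm (Return t) = Return (normalForm t)
normalForm (Let t u)  = case normalForm t of
  Return t' -> normalForm (subst0 t' (normalForm u))                  -- Left Identity
  Let t' u' -> normalForm (Let t' (Let u' (weaken2 (normalForm u))))  -- rebracket
  t'        -> Let t' (normalForm u)

-- Context: m : P E at First, x : E at Next First.
type Ctx = Cons (P E) (Cons E Empty)

-- return x >>= \y -> return y
ex1 :: Term Ctx (P E)
ex1 = Let (Return (Var (Next First))) (Return (Var First))

-- (m >>= \y -> return y) >>= \z -> return z
ex2 :: Term Ctx (P E)
ex2 = Let (Let (Var First) (Return (Var First))) (Return (Var First))

-- normalForm ex1  shows as  Return (Var (Next First))
-- normalForm ex2  shows as  Let (Var First) (Return (Var First))
```

For ex1, Left Identity fires directly. For ex2, the outer Let is first re-bracketed, which exposes a Return in binding position inside the new inner Let, and Left Identity then removes it. The result still has the shape of m bound to a Return, since no rule for Right Identity is part of normalForm as defined above.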