Implementing probabilistic programs in Haskell


1. Overview

Here we extend our existing data types Term and Constant in order to implement the DSL described in the last set of notes. We'll also update our definitions of reorder and normalForm to account for the new constructors.

2. New types and constructors

Besides our old types E and T and our arrow and product constructors, we add a type R of real numbers, as well as a type I of possible worlds (which will prove useful later on). In addition, we add a new type constructor P, which maps a type α onto the type P α of probabilistic programs returning values of type α.

data Type = E | T | I | R
          | Unit          -- the unit type, needed by Un and Factor below
          | Type :-> Type
          | Type :/\ Type
          | P Type
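To see how the new type constructors combine, here is a minimal sketch that treats Type as an ordinary value-level data type (in the notes it is promoted to a kind), deriving Eq and Show so that types can be compared. The helpers expValType, factorType and programResult are hypothetical names introduced only for illustration; the first two spell out the types that the constants ExpVal and Factor receive below.

```haskell
-- Value-level copy of the Type declaration, with Eq and Show derived so that
-- types can be compared and printed. (In the notes, Type is used as a kind.)
data Type = E | T | I | R | Unit
          | Type :-> Type
          | Type :/\ Type
          | P Type
  deriving (Eq, Show)

-- Hypothetical helper: the type of the expected-value operator at a,
-- i.e. P a :-> ((a :-> R) :-> R).
expValType :: Type -> Type
expValType a = P a :-> ((a :-> R) :-> R)

-- Hypothetical helper: the type of factor, R :-> P Unit.
factorType :: Type
factorType = R :-> P Unit

-- Hypothetical helper: the type of values returned by a program type P a,
-- or Nothing if the argument is not a program type.
programResult :: Type -> Maybe Type
programResult (P a) = Just a
programResult _     = Nothing
```

For instance, programResult applied to the codomain of factorType yields Just Unit: factor is a program that returns only ⋄.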

We should also extend our data types Term and Constant to encode the expressions of our DSL. We'll extend Term by adding the two monadic constructors Let and Return (since they are constructors of Term rather than constants, we can define reduction rules for them),

data Term (γ :: Context) (φ :: Type) where
  Var :: In φ γ -> Term γ φ                        -- variables
  Con :: Constant φ -> Term γ φ                    -- constants
  Lam :: Term (Cons φ γ) ψ -> Term γ (φ :-> ψ)     -- abstraction
  App :: Term γ (φ :-> ψ) -> Term γ φ -> Term γ ψ  -- applications
  Pair :: Term γ φ -> Term γ ψ -> Term γ (φ :/\ ψ) -- pairing
  Pi1 :: Term γ (φ :/\ ψ) -> Term γ φ              -- first projection
  Pi2 :: Term γ (φ :/\ ψ) -> Term γ ψ              -- second projection
  Un :: Term γ Unit                                -- unit
  Let :: Term γ (P φ) -> Term (Cons φ γ) (P ψ) -> Term γ (P ψ) -- monadic bind
  Return :: Term γ φ -> Term γ (P φ)                           -- monadic return

and we'll extend Constant with constructors for the other expressions.

data Constant (φ :: Type) where
  Factor :: Constant (R :-> P Unit)              -- factor
  ExpVal :: Constant (P α :-> ((α :-> R) :-> R)) -- 𝔼_{(·)} (the expected value operator)
  Indi :: Constant (T :-> R)                     -- 𝟙 (the indicator function)
  ...
deriving instance Show (Constant φ)
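As a concrete illustration, the sketch below restates a cut-down, self-contained version of these declarations (only Var, Con, App, Let and Return, with ASCII type variables) and writes down a small program that applies factor to 𝟙(rain) and then returns john. The context constructor name Empty, the constants Rain and John, and the printing helpers idx and render are assumptions made for the example; In follows the usual de Bruijn encoding with First and Next.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators, StandaloneDeriving #-}

-- Cut-down copy of the DSL types, keeping only what this example needs.
data Type = E | T | R | Unit | Type :-> Type | P Type

data Context = Empty | Cons Type Context   -- 'Empty' is an assumed name

-- de Bruijn variables: First is the most recently bound variable.
data In (a :: Type) (g :: Context) where
  First :: In a (Cons a g)
  Next  :: In a g -> In a (Cons b g)

data Constant (a :: Type) where
  Factor :: Constant (R :-> P Unit)
  Indi   :: Constant (T :-> R)
  Rain   :: Constant T   -- hypothetical constant, for illustration only
  John   :: Constant E   -- hypothetical constant, for illustration only
deriving instance Show (Constant a)

data Term (g :: Context) (a :: Type) where
  Var    :: In a g -> Term g a
  Con    :: Constant a -> Term g a
  App    :: Term g (a :-> b) -> Term g a -> Term g b
  Let    :: Term g (P a) -> Term (Cons a g) (P b) -> Term g (P b)
  Return :: Term g a -> Term g (P a)

-- de Bruijn index of a variable.
idx :: In a g -> Int
idx First    = 0
idx (Next i) = 1 + idx i

-- Render a term as a string, writing variables as v0, v1, ...
render :: Term g a -> String
render (Var i)    = "v" ++ show (idx i)
render (Con c)    = show c
render (App t u)  = "(" ++ render t ++ " " ++ render u ++ ")"
render (Let t u)  = "(let " ++ render t ++ " in " ++ render u ++ ")"
render (Return t) = "(return " ++ render t ++ ")"

-- let _ <- factor (𝟙 rain) in return john
program :: Term Empty (P E)
program = Let (App (Con Factor) (App (Con Indi) (Con Rain))) (Return (Con John))
```

Because Term is indexed by its context and type, an ill-typed program (say, passing Rain directly to Factor) is rejected by GHC. In GHCi, render program evaluates to "(let (Factor (Indi Rain)) in (return John))".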

3. Updated operational semantics

Given our new Term constructors, we should extend our functions reorder and normalForm to take them into account. Extending the latter will endow our monadic constructors Let and Return with their own set of reduction rules and, thus, a definition of normal form.

reorder :: forall γ δ ψ. (forall φ. In φ γ -> In φ δ) -> Term γ ψ -> Term δ ψ
reorder f (Var i) = Var (f i)
reorder _ (Con c) = Con c
reorder f (Lam t) = Lam (reorder g t)
  where g :: (forall χ. In χ (Cons φ γ) -> In χ (Cons φ δ))
        g First = First
        g (Next i) = Next (f i)
reorder f (App t u) = App (reorder f t) (reorder f u)
reorder f (Pair t u) = Pair (reorder f t) (reorder f u)
reorder f (Pi1 t) = Pi1 (reorder f t)
reorder f (Pi2 t) = Pi2 (reorder f t)
reorder _ Un = Un
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u) = Let (reorder f t) (reorder g u)
  where g :: (forall χ. In χ (Cons φ γ) -> In χ (Cons φ δ))
        g First = First
        g (Next i) = Next (f i)

While the clause of reorder for Return is uninteresting, the clause for Let does essentially what the clause for Lam does: it accounts for the new variable bound in u (the one at position First) by leaving it untouched and applying f only to the remaining variables.
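The Let clause can be exercised with a small self-contained sketch. It restates a cut-down Term (no pairs or unit) and factors the local helper g out as a top-level function, here called under, so that ScopedTypeVariables is not needed; under, Empty, the constant Same, and the helpers idx and render are hypothetical names. The example weakens the program let x ← return y in return (same x y), whose only free variable is y.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes, StandaloneDeriving, TypeOperators #-}

data Type = E | T | Type :-> Type | P Type
data Context = Empty | Cons Type Context   -- 'Empty' is an assumed name

data In (a :: Type) (g :: Context) where
  First :: In a (Cons a g)
  Next  :: In a g -> In a (Cons b g)

data Constant (a :: Type) where
  Same :: Constant (E :-> (E :-> T))   -- hypothetical constant
deriving instance Show (Constant a)

data Term (g :: Context) (a :: Type) where
  Var    :: In a g -> Term g a
  Con    :: Constant a -> Term g a
  Lam    :: Term (Cons a g) b -> Term g (a :-> b)
  App    :: Term g (a :-> b) -> Term g a -> Term g b
  Let    :: Term g (P a) -> Term (Cons a g) (P b) -> Term g (P b)
  Return :: Term g a -> Term g (P a)

-- Hypothetical top-level version of the local g: lift a renaming under one
-- binder, leaving the newly bound variable (First) alone.
under :: (forall x. In x g -> In x d) -> In a (Cons b g) -> In a (Cons b d)
under _ First    = First
under f (Next i) = Next (f i)

reorder :: (forall x. In x g -> In x d) -> Term g a -> Term d a
reorder f (Var i)    = Var (f i)
reorder _ (Con c)    = Con c
reorder f (Lam t)    = Lam (reorder (under f) t)
reorder f (App t u)  = App (reorder f t) (reorder f u)
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u)  = Let (reorder f t) (reorder (under f) u)  -- u binds one extra variable

weaken :: Term g a -> Term (Cons b g) a
weaken = reorder Next

idx :: In a g -> Int
idx First    = 0
idx (Next i) = 1 + idx i

render :: Term g a -> String
render (Var i)    = "v" ++ show (idx i)
render (Con c)    = show c
render (Lam t)    = "(\\ " ++ render t ++ ")"
render (App t u)  = "(" ++ render t ++ " " ++ render u ++ ")"
render (Let t u)  = "(let " ++ render t ++ " in " ++ render u ++ ")"
render (Return t) = "(return " ++ render t ++ ")"

-- In context [y : E]: let x <- return y in return (same x y)
example :: Term (Cons E Empty) (P T)
example = Let (Return (Var First))
              (Return (App (App (Con Same) (Var First)) (Var (Next First))))

-- The same program in a context extended by a new variable of type T.
weakened :: Term (Cons T (Cons E Empty)) (P T)
weakened = weaken example
```

Here render example is "(let (return v0) in (return ((Same v0) v1)))" and render weakened is "(let (return v1) in (return ((Same v0) v2)))": the free variable y is shifted by one in both subterms, while the Let-bound x stays at index 0 inside the continuation.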

We can give a new definition of normalForm as follows:

normalForm :: Term γ φ -> Term γ φ
normalForm v@(Var _) = v                -- Variables are already in normal form.
normalForm c@(Con _) = c                -- So are constants.
normalForm (Lam t) = Lam (normalForm t) -- Abstractions are in normal form just in case their bodies are in normal form.
normalForm (App t u) =
  case normalForm t of
    Lam t' -> normalForm (subst0 (normalForm u) t') -- If the normal form of t is an abstraction, then we need to substitute and further normalize.
    t' -> App t' (normalForm u)                     -- Otherwise, we just need to take the normal form of the argument.
normalForm (Pair t u) = Pair (normalForm t) (normalForm u) -- Just normalize the components.
normalForm (Pi1 t) = 
  case normalForm t of
    Pair u _ -> u -- If the normal form inside a projection is actually a pair, we should take that pair's projection.
    t' -> Pi1 t'  -- Otherwise, nothing needs to be done.
normalForm (Pi2 t) = -- Ditto.
  case normalForm t of
    Pair _ u -> u
    t' -> Pi2 t'
normalForm Un = Un -- ⋄ is already in normal form.
normalForm (Return t) = Return (normalForm t) -- Return t is in normal form just in case t is.
normalForm (Let t u) =
  case normalForm t of
    Return t' -> normalForm (subst0 t' (normalForm u)) -- Here we apply Left Identity.
    Let t' u' -> normalForm (Let t' (Let u' (weaken2 (normalForm u)))) -- Here we rebracket (Associativity), potentially enabling another reduction based on Left Identity.
    t' -> Let t' (normalForm u) -- Otherwise, we just normalize the continuation u.

where weaken and weaken2 are defined as

-- weaken, targeting the first position in the context
weaken :: Term γ φ -> Term (Cons ψ γ) φ
weaken = reorder Next

-- weaken, targeting the second position in the context
weaken2 :: Term (Cons φ γ) ψ -> Term (Cons φ (Cons χ γ)) ψ
weaken2 = reorder g
  where g :: In ψ (Cons φ γ) -> In ψ (Cons φ (Cons χ γ))
        g First = First
        g (Next i) = Next (Next i)
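The difference between weaken and weaken2 shows up directly in de Bruijn indices. The following self-contained sketch restates a cut-down Term (no Lam, pairs or unit) and writes weaken2 as reorder (under Next), where under is a hypothetical top-level version of the local g used in reorder; Empty, Same, idx and render are likewise illustrative names. It weakens the open term same x y, in context [x : E, y : E], both ways.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes, StandaloneDeriving, TypeOperators #-}

data Type = E | T | Type :-> Type | P Type
data Context = Empty | Cons Type Context   -- 'Empty' is an assumed name

data In (a :: Type) (g :: Context) where
  First :: In a (Cons a g)
  Next  :: In a g -> In a (Cons b g)

data Constant (a :: Type) where
  Same :: Constant (E :-> (E :-> T))   -- hypothetical constant
deriving instance Show (Constant a)

data Term (g :: Context) (a :: Type) where
  Var    :: In a g -> Term g a
  Con    :: Constant a -> Term g a
  App    :: Term g (a :-> b) -> Term g a -> Term g b
  Let    :: Term g (P a) -> Term (Cons a g) (P b) -> Term g (P b)
  Return :: Term g a -> Term g (P a)

-- Hypothetical helper: keep First fixed, rename everything else by f.
under :: (forall x. In x g -> In x d) -> In a (Cons b g) -> In a (Cons b d)
under _ First    = First
under f (Next i) = Next (f i)

reorder :: (forall x. In x g -> In x d) -> Term g a -> Term d a
reorder f (Var i)    = Var (f i)
reorder _ (Con c)    = Con c
reorder f (App t u)  = App (reorder f t) (reorder f u)
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u)  = Let (reorder f t) (reorder (under f) u)

-- New variable at position 0: every index goes up by one.
weaken :: Term g a -> Term (Cons b g) a
weaken = reorder Next

-- New variable at position 1: index 0 is kept, the others go up by one.
weaken2 :: Term (Cons a g) b -> Term (Cons a (Cons c g)) b
weaken2 = reorder (under Next)

idx :: In a g -> Int
idx First    = 0
idx (Next i) = 1 + idx i

render :: Term g a -> String
render (Var i)    = "v" ++ show (idx i)
render (Con c)    = show c
render (App t u)  = "(" ++ render t ++ " " ++ render u ++ ")"
render (Let t u)  = "(let " ++ render t ++ " in " ++ render u ++ ")"
render (Return t) = "(return " ++ render t ++ ")"

-- same x y, in context [x : E, y : E] (x has index 0, y has index 1)
openTerm :: Term (Cons E (Cons E Empty)) T
openTerm = App (App (Con Same) (Var First)) (Var (Next First))

viaWeaken :: Term (Cons T (Cons E (Cons E Empty))) T
viaWeaken = weaken openTerm

viaWeaken2 :: Term (Cons E (Cons T (Cons E Empty))) T
viaWeaken2 = weaken2 openTerm
```

Here render viaWeaken is "((Same v1) v2)", since both x and y move, whereas render viaWeaken2 is "((Same v0) v2)", since x keeps index 0 and only y moves. The second behaviour is what the rebracketing clause of normalForm relies on: the continuation u keeps its own bound variable at First while skipping over the variable that the new outer Let binds for u′.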

Thus, normalForm reduces a term built with Let and Return wherever Left Identity applies, and it rebrackets nested Lets so that further Left Identity reductions become possible.
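Putting the pieces together, here is a minimal, self-contained sketch of normalForm restricted to Var, Con, Let and Return. Since subst0 comes from earlier notes and is not shown here, it is re-implemented through a simultaneous-substitution function subst with helpers lift and single; that implementation is an assumption chosen only so the example runs. Other hypothetical names are Empty, the primitive program Coin, the constant John, under, idx and render. The first example checks Left Identity; the second checks that rebracketing let y ← (let x ← coin in return x) in return z, in context [z : E], leads to let x ← coin in return z, with z correctly re-indexed by weaken2.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes, StandaloneDeriving #-}

-- Cut-down, self-contained restatement: only Var, Con, Let and Return.
data Type = E | T | P Type
data Context = Empty | Cons Type Context   -- 'Empty' is an assumed name

data In (a :: Type) (g :: Context) where
  First :: In a (Cons a g)
  Next  :: In a g -> In a (Cons b g)

data Constant (a :: Type) where
  Coin :: Constant (P T)   -- hypothetical primitive program
  John :: Constant E       -- hypothetical constant
deriving instance Show (Constant a)

data Term (g :: Context) (a :: Type) where
  Var    :: In a g -> Term g a
  Con    :: Constant a -> Term g a
  Let    :: Term g (P a) -> Term (Cons a g) (P b) -> Term g (P b)
  Return :: Term g a -> Term g (P a)

-- Hypothetical top-level version of the local renaming g used in reorder.
under :: (forall x. In x g -> In x d) -> In a (Cons b g) -> In a (Cons b d)
under _ First    = First
under f (Next i) = Next (f i)

reorder :: (forall x. In x g -> In x d) -> Term g a -> Term d a
reorder f (Var i)    = Var (f i)
reorder _ (Con c)    = Con c
reorder f (Return t) = Return (reorder f t)
reorder f (Let t u)  = Let (reorder f t) (reorder (under f) u)

weaken :: Term g a -> Term (Cons b g) a
weaken = reorder Next

weaken2 :: Term (Cons a g) b -> Term (Cons a (Cons c g)) b
weaken2 = reorder (under Next)

-- Assumed substitution machinery standing in for subst0 from earlier notes:
-- subst replaces every free variable at once, lifting under binders.
lift :: (forall x. In x g -> Term d x) -> In a (Cons b g) -> Term (Cons b d) a
lift _ First    = Var First
lift f (Next i) = weaken (f i)

subst :: (forall x. In x g -> Term d x) -> Term g a -> Term d a
subst f (Var i)    = f i
subst _ (Con c)    = Con c
subst f (Return t) = Return (subst f t)
subst f (Let t u)  = Let (subst f t) (subst (lift f) u)

single :: Term g a -> In x (Cons a g) -> Term g x
single u First    = u
single _ (Next i) = Var i

-- subst0 u t replaces variable 0 of t by u and lowers the other indices.
subst0 :: Term g a -> Term (Cons a g) b -> Term g b
subst0 u = subst (single u)

normalForm :: Term g a -> Term g a
normalForm v@(Var _)  = v
normalForm c@(Con _)  = c
normalForm (Return t) = Return (normalForm t)
normalForm (Let t u)  =
  case normalForm t of
    Return t' -> normalForm (subst0 t' (normalForm u))                 -- Left Identity
    Let t' u' -> normalForm (Let t' (Let u' (weaken2 (normalForm u)))) -- rebracket
    t'        -> Let t' (normalForm u)                                 -- nothing to reduce

idx :: In a g -> Int
idx First    = 0
idx (Next i) = 1 + idx i

render :: Term g a -> String
render (Var i)    = "v" ++ show (idx i)
render (Con c)    = show c
render (Let t u)  = "(let " ++ render t ++ " in " ++ render u ++ ")"
render (Return t) = "(return " ++ render t ++ ")"

-- let x <- return john in return x
leftId :: Term Empty (P E)
leftId = Let (Return (Con John)) (Return (Var First))

-- In context [z : E]: let y <- (let x <- coin in return x) in return z
nested :: Term (Cons E Empty) (P E)
nested = Let (Let (Con Coin) (Return (Var First))) (Return (Var (Next First)))
```

In GHCi, render (normalForm leftId) gives "(return John)", and render (normalForm nested) gives "(let Coin in (return v1))": after rebracketing, the inner let y ← return x in return z reduces by Left Identity, and z ends up at index 1, just below the variable bound by coin.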

Author: Julian Grove

Created: 2023-12-10 Sun 15:18
