{-# LANGUAGE LambdaCase #-}

{-|
Module      : Framework.Lambda.Terms
Description : λ-calculus (with probabilistic programs).
Copyright   : (c) Julian Grove and Aaron Steven White, 2025
License     : MIT
Maintainer  : julian.grove@gmail.com

We encode the untyped λ-calculus with constants, extended with constructs for
probabilistic programs (degenerate distributions and sampling).
-}

module Framework.Lambda.Terms ( betaDeltaNormal
                              , betaEtaNormal
                              , betaNormal
                              , Constant
                              , dCon
                              , DeltaRule
                              , etaNormal
                              , freeVars
                              , fresh
                              , sCon
                              , subst
                              , Term(..)
                              , (@@)
                              , (&)
                              ) where

import Control.Monad.State (evalStateT, get, lift, put, StateT)
import Data.Char           (toLower)

--------------------------------------------------------------------------------
-- * Untyped λ-terms

-- ** Terms

-- *** Definitions

-- | Constants are indexed either by a string ('Left'; see 'sCon') or by a real
-- number ('Right'; see 'dCon').
type Constant = Either String Double

-- | Variable names are represented by strings.
type VarName = String

-- | An infinite supply of variable names, from which 'fresh' draws: @u@ to @z@,
-- then @u1@ to @z1@, then @u2@ to @z2@, and so on.
teVars :: [VarName]
teVars = [ c : suffix | suffix <- "" : map show counters, c <- ['u' .. 'z'] ]
  where
    counters :: [Integer]
    counters = [1 ..]

-- | Untyped λ-terms. Types are assigned separately (i.e., "extrinsically").
-- Besides the usual λ-calculus with constants and pairs, 'Return' and 'Let'
-- encode probabilistic programs.
data Term = Var VarName           -- ^ Variables.
          | Con Constant          -- ^ Constants.
          | Lam VarName Term      -- ^ Abstractions; @Lam v t@ binds @v@ in @t@.
          | App Term Term         -- ^ Applications.
          | TT                    -- ^ The 0-tuple.
          | Pair Term Term        -- ^ Pairing.
          | Pi1 Term              -- ^ First projection.
          | Pi2 Term              -- ^ Second projection.
          | Return Term           -- ^ Construct a degenerate distribution.
          | Let VarName Term Term -- ^ Sample from a distribution and continue:
                                  -- @Let v t u@ samples @v@ from @t@ and
                                  -- continues with @u@, in which @v@ is bound.

-- | Equality of terms up to α-equivalence, i.e., up to the renaming of the
-- variables bound by 'Lam' and 'Let'.
--
-- Bodies under binders are compared after renaming both bound variables to a
-- single name that is free in NEITHER body. (Picking the name without checking
-- this would, e.g., make @λx.u@ — with @u@ free — equal to the identity
-- @λu.u@, because both bodies would become @u@.)
instance Eq Term where
  s == t = alphaEq (s, t)
    where
      alphaEq :: (Term, Term) -> Bool
      alphaEq = \case
        (Var x, Var y)             -> x == y
        (Con x, Con y)             -> x == y
        (Lam x t1, Lam y t2)       -> sameBodies x t1 y t2
        (App t1 u1, App t2 u2)     -> t1 == t2 && u1 == u2
        (TT, TT)                   -> True
        (Pair t1 u1, Pair t2 u2)   -> t1 == t2 && u1 == u2
        (Pi1 t1, Pi1 t2)           -> t1 == t2
        (Pi2 t1, Pi2 t2)           -> t1 == t2
        (Return t1, Return t2)     -> t1 == t2
        (Let x t1 u1, Let y t2 u2) -> t1 == t2 && sameBodies x u1 y u2
        _                          -> False

      -- Compare body @b1@ (which binds @x@) with body @b2@ (which binds @y@)
      -- after renaming both bound variables to a common name that occurs free
      -- in neither body, so the renaming cannot identify a bound variable
      -- with a free one.
      sameBodies :: VarName -> Term -> VarName -> Term -> Bool
      sameBodies x b1 y b2 = case fresh [b1, b2] of
        n : _ -> subst x (Var n) b1 == subst y (Var n) b2
        -- Unreachable: 'fresh' filters the infinite 'teVars' by a finite set.
        []    -> error "Framework.Lambda.Terms.(==): no fresh variable name"

-- | Render a term in conventional λ-calculus notation.
--
-- Applications are written @f(a)@, with the function parenthesized when it is
-- an abstraction; pairs are written @⟨a, b⟩@, degenerate distributions @[t]@,
-- and @'Let' v t u@ as @(v ∼ t; u)@. When @t@ normalizes (via 'betaEtaNormal')
-- to an application of the constant @factor@ or @observe@, the @v ∼@ prefix is
-- left out.
instance Show Term where
  show term = case term of
      Var v             -> v
      Con c             -> either id show c
      Lam v body        -> concat ["λ", v, ".", show body]
      App f@(Lam _ _) a -> concat ["(", show f, ")(", show a, ")"]
      App f a           -> concat [show f, "(", show a, ")"]
      TT                -> "⋄"
      Pair a b          -> concat ["⟨", show a, ", ", show b, "⟩"]
      Pi1 a             -> concat ["π₁(", show a, ")"]
      Pi2 a             -> concat ["π₂(", show a, ")"]
      Return a          -> concat ["[", show a, "]"]
      Let v dist k
        | isEffect (betaEtaNormal dist) -> "(" ++ rest
        | otherwise                     -> concat ["(", v, " ∼ ", rest]
        where
          rest = concat [show dist, "; ", show k, ")"]
    where
      -- Is this (normalized) term an application of @factor@ or @observe@?
      isEffect :: Term -> Bool
      isEffect (App (Con (Left name)) _) = name `elem` ["factor", "observe"]
      isEffect _                         = False

-- **** Smart-ish constructors

-- | Make a constant named by a string. Since types are assigned
-- extrinsically, the resulting constant may be given any type.
sCon :: String -> Term
sCon = Con . Left

-- | Make a numeric constant from a 'Double'.
dCon :: Double -> Term
dCon = Con . Right

-- | Infix abbreviations: @(\@\@)@ builds an application ('App') and @(&)@
-- builds a pair ('Pair').
(@@), (&) :: Term -> Term -> Term
(@@) = App
(&)  = Pair

-- *** Functions, relations on terms

-- | The free variables of a term, in left-to-right order of occurrence. A
-- variable is listed once per free occurrence, so duplicates are possible.
freeVars :: Term -> [VarName]
freeVars term = case term of
    Var v        -> [v]
    Con _        -> []
    Lam v body   -> underBinder v body
    App f a      -> freeVars f ++ freeVars a
    TT           -> []
    Pair a b     -> freeVars a ++ freeVars b
    Pi1 a        -> freeVars a
    Pi2 a        -> freeVars a
    Let v dist k -> freeVars dist ++ underBinder v k
    Return a     -> freeVars a
  where
    -- Free variables of a body in which @v@ is bound.
    underBinder :: VarName -> Term -> [VarName]
    underBinder v body = filter (/= v) (freeVars body)

-- | The names from 'teVars', in order, that occur free in none of the given
-- terms. Because 'teVars' is infinite, so is the result.
fresh :: [Term] -> [VarName]
fresh ts = [ v | v <- teVars, v `notElem` taken ]
  where
    taken = concatMap freeVars ts

-- | Substitutions: @subst x y t@ replaces the free occurrences of @x@ in @t@
-- with @y@.
--
-- Capture is avoided by renaming every binder other than a 'Lam' that rebinds
-- @x@. The new name is taken from 'fresh' and is free in neither the binder's
-- body, @y@, nor @Var x@.
subst :: VarName -> Term -> Term -> Term
subst x y = go
  where
    go :: Term -> Term
    go = \case
      Var v
        | v == x    -> y
        | otherwise -> Var v
      c@(Con _)     -> c
      Lam v body
        -- x is shadowed here, so it has no free occurrences in the body.
        | v == x    -> Lam v body
        | otherwise -> let v' = freshFor body
                       in Lam v' (go (subst v (Var v') body))
      App f a       -> go f @@ go a
      TT            -> TT
      Pair a b      -> go a & go b
      Pi1 a         -> Pi1 (go a)
      Pi2 a         -> Pi2 (go a)
      Return a      -> Return (go a)
      Let v dist k  -> let v' = freshFor k
                       in Let v' (go dist) (go (subst v (Var v') k))

    -- The first name that is free in neither @t@, @y@, nor @Var x@.
    freshFor :: Term -> VarName
    freshFor t = let (v : _) = fresh [t, y, Var x] in v

-- | The type of δ-rules. A rule inspects a term and returns @Just@ the term it
-- rewrites to, or @Nothing@ if it does not apply. 'betaDeltaNormal' tries the
-- rule on each term it has just normalized and normalizes the rule's result
-- again.
type DeltaRule = Term -> Maybe Term

-- | Beta normal forms, taking δ-rules into account.
--
-- Each node is first normalized structurally:
--
--   * applications of abstractions are β-reduced;
--   * projections of pairs are reduced;
--   * @Let@ over @Return@ is reduced by substitution;
--   * nested @Let@s are re-associated.
--
-- The δ-rule is then tried on the result, and if it fires, its output is
-- normalized again. On untyped terms with no normal form (e.g. Ω) this does
-- not terminate.
betaDeltaNormal :: DeltaRule -> Term -> Term
betaDeltaNormal delta = normalize
  where
    -- Normalize one node, then give the δ-rule a chance to fire on it.
    normalize :: Term -> Term
    normalize = tryDelta . reduce

    -- If the δ-rule applies, normalize what it returns; otherwise keep @t@.
    tryDelta :: Term -> Term
    tryDelta t = maybe t normalize (delta t)

    -- Structural normalization, recursing through 'normalize'.
    reduce :: Term -> Term
    reduce = \case
      Lam v body -> Lam v (normalize body)
      App f a -> case normalize f of
        Lam v body -> normalize (subst v a body)
        f'         -> App f' (normalize a)
      Pair a b -> Pair (normalize a) (normalize b)
      Pi1 p -> case normalize p of
        Pair a _ -> a
        p'       -> Pi1 p'
      Pi2 p -> case normalize p of
        Pair _ b -> b
        p'       -> Pi2 p'
      Return a -> Return (normalize a)
      Let v m k -> case normalize m of
        Return a -> normalize (subst v a k)
        Let w m' j ->
          -- Re-associate; rename w to a name free in neither k nor j.
          let (fr : _) = fresh [k, j]
          in normalize (Let fr m' (Let v (subst w (Var fr) j) k))
        m' -> Let v m' (normalize k)
      -- Var, Con and TT are already normal.
      atom -> atom
        
  
-- | Beta normal forms: 'betaDeltaNormal' with a δ-rule that never fires.
betaNormal :: Term -> Term
betaNormal = betaDeltaNormal noDelta
  where
    noDelta :: DeltaRule
    noDelta _ = Nothing

-- | Eta normal forms. Subterms are normalized first, then three η-contractions
-- are applied:
--
--   * @λv. f(v)@ becomes @f@ when @v@ is not free in @f@;
--   * @⟨π₁(p), π₂(p)⟩@ becomes @p@ (the projections compared with '==');
--   * @'Let' v m ('Return' ('Var' v))@ becomes @m@.
etaNormal :: Term -> Term
etaNormal term = case term of
    Lam v body -> case etaNormal body of
      App f (Var w) | w == v && v `notElem` freeVars f -> f
      body'                                           -> Lam v body'
    App f a -> etaNormal f @@ etaNormal a
    Pair a b -> case (etaNormal a, etaNormal b) of
      (Pi1 p, Pi2 q) | p == q -> p
      (a', b')                -> Pair a' b'
    Pi1 p    -> Pi1 (etaNormal p)
    Pi2 p    -> Pi2 (etaNormal p)
    Return a -> Return (etaNormal a)
    Let v m k -> case etaNormal k of
      Return (Var w) | w == v -> etaNormal m
      k'                      -> Let v (etaNormal m) k'
    -- Var, Con and TT contain nothing to contract.
    atom -> atom

-- | Beta-eta normal forms: β-normalize (with no δ-rules), then η-normalize.
betaEtaNormal :: Term -> Term
betaEtaNormal t = etaNormal (betaNormal t)