{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}

{-|
Module      : Framework.Lambda.Convenience
Description : Convenience functions, etc.
Copyright   : (c) Julian Grove and Aaron Steven White, 2025
License     : MIT
Maintainer  : julian.grove@gmail.com

Convenience functions, smart constructors, etc.
-}

module Framework.Lambda.Convenience where

import Control.Applicative
import Framework.Lambda.Terms
import Framework.Lambda.Types

--------------------------------------------------------------------------------
-- * Convenience functions, smart constructors, etc.

-- ** Type abbreviations.

-- | Type abbreviations: type variables ('TyVar') and atomic types ('Atom'),
-- each carrying the string it is named by. @ω@ was previously defined without
-- a signature; it is now listed here with the others.
α, β, ι, ω, κ, σ, e, t, r :: Type

-- Type variables.
α = TyVar "α"
β = TyVar "β"
ι = TyVar "ι"
ω = TyVar "ω"
κ = TyVar "κ"
σ = TyVar "σ"

-- Atomic types.
e = Atom "e"
t = Atom "t"
r = Atom "r"

-- | Type constructors named @Q@ and @popQ@. Each is applied to exactly three
-- type arguments, kept in the order given.
q, popQ :: Type -> Type -> Type -> Type
q ty1 ty2 ty3 =
  TyCon "Q" [ty1, ty2, ty3]
popQ ty1 ty2 ty3 =
  TyCon "popQ" [ty1, ty2, ty3]

-- | A constant term carrying a string (the 'Left' side of 'Con').
pattern SCon :: String -> Term
pattern SCon name <- Con (Left name)
  where SCon name = Con (Left name)

-- | A constant term carrying a 'Double' (the 'Right' side of 'Con').
pattern DCon :: Double -> Term
pattern DCon val <- Con (Right val)
  where DCon val = Con (Right val)

-- | String constants @T@ and @F@ (presumably the truth values), and the
-- constant @#@.
pattern Tr, Fa, Undefined :: Term
pattern Tr        = SCon "T"
pattern Fa        = SCon "F"
pattern Undefined = SCon "#"

-- | The numeric constants zero and one.
pattern Zero, One :: Term
pattern Zero = DCon 0
pattern One  = DCon 1

-- | The lambda term that binds @s@ and returns the pair of @s@ with itself.
pattern GetPP :: Term
pattern GetPP = Lam "s" (Return (Pair (Var "s") (Var "s")))

-- | Unary constant applications. Each pattern is a fixed string constant
-- ('SCon') applied ('App') to one argument. The exception is 'PopQUD', which
-- is @'Pop' "QUD"@.
--
-- Max, TauKnow, Ling, Phil, Prop1 and PopQUD were previously defined without
-- signatures (-Wmissing-pattern-synonym-signatures); they are listed here now.
pattern Bern, CG, Factor, Indi, Max, Neg, Epi, TauKnow, Ling, Phil, Height,
        DTall, SocPla, Observe, Pr, Prop1, PopQUD :: Term -> Term
pattern Bern p    = SCon "Bernoulli" `App` p
pattern CG s      = SCon "CG" `App` s
pattern Factor x  = SCon "factor" `App` x
pattern Indi p    = SCon "𝟙" `App` p
pattern Max f     = SCon "max" `App` f   -- renamed from 'pred' (shadowed Prelude)
pattern Neg x     = SCon "neg" `App` x
pattern Epi i     = SCon "epi" `App` i
pattern TauKnow s = SCon "tau_know" `App` s
pattern Ling i    = SCon "ling" `App` i
pattern Phil i    = SCon "phil" `App` i
pattern Height i  = SCon "height" `App` i
pattern DTall s   = SCon "d_tall" `App` s
pattern SocPla i  = SCon "soc_pla" `App` i
pattern Observe p = SCon "observe" `App` p
pattern Pr t      = SCon "Pr" `App` t
pattern Prop1 i   = SCon "prop1" `App` i
pattern PopQUD s  = Pop "QUD" s

-- | Binary constant applications. 'Add', 'Eq', 'Mult', 'Beta', 'Normal' and
-- 'LogitNormal' apply their constant to a 'Pair' of the two arguments. All the
-- others apply the constant to the arguments one at a time.
--
-- Beta, LogitNormal, UpdLing and UpdTauKnow were previously defined without
-- signatures; they are listed here now.
pattern Add, And, Or, Eq, GE, Mult, Beta, Normal, LogitNormal, UpdEpi, UpdCG,
        UpdLing, UpdTauKnow, UpdHeight, UpdDTall, UpdSocPla, UpdProp1, PushQUD
  :: Term -> Term -> Term
pattern Add x y         = SCon "add" `App` (Pair x y)
pattern And p q         = SCon "(∧)" `App` p `App` q
pattern Or p q          = SCon "(∨)" `App` p `App` q
pattern Eq x y          = SCon "(=)" `App` (Pair x y)
pattern GE a b          = SCon "(≥)" `App` a `App` b
pattern Mult x y        = SCon "mult" `App` (Pair x y)
pattern Beta x y        = SCon "Beta" `App` (Pair x y)
pattern Normal x y      = SCon "Normal" `App` (Pair x y)
pattern LogitNormal x y = SCon "Logit_normal" `App` (Pair x y)
pattern UpdEpi acc i    = SCon "upd_epi" `App` acc `App` i
pattern UpdCG cg s      = SCon "upd_CG" `App` cg `App` s
pattern UpdLing p i     = SCon "upd_ling" `App` p `App` i
pattern UpdTauKnow b s  = SCon "upd_tau_know" `App` b `App` s
pattern UpdHeight p i   = SCon "upd_height" `App` p `App` i
pattern UpdDTall d s    = SCon "upd_d_tall" `App` d `App` s
pattern UpdSocPla p i   = SCon "upd_soc_pla" `App` p `App` i
pattern UpdProp1 b i    = SCon "upd_prop1" `App` b `App` i
pattern PushQUD q s     = SCon "push_QUD" `App` q `App` s

-- | Ternary constant applications.
--
-- 'Disj' and 'ITE' pass a left-nested pair of their three arguments.
-- 'Truncate' and 'NormalCDF' apply the constant first to a pair of two
-- arguments and then to the remaining one. Note that for 'Truncate' the
-- remaining one is its /first/ parameter.
pattern Disj, ITE, Truncate, NormalCDF :: Term -> Term -> Term -> Term
pattern Disj x m n      = SCon "disj" `App` (Pair (Pair x m) n)
pattern ITE p x y       = SCon "if_then_else" `App` (Pair (Pair p x) y)
pattern Truncate m x y  = SCon "Truncate" `App` (Pair x y) `App` m
pattern NormalCDF x y z = SCon "Normal_cdf" `App` (Pair x y) `App` z

-- | @Pr@ applied to a program that binds @v@ to @Normal x y@ and returns
-- @GE t (Var v')@.
--
-- NOTE: a pattern synonym cannot require two variables to be equal. When
-- matching, @v@ and @v'@ are therefore bound independently, so a match does
-- not guarantee that the comparison refers to the bound variable. When
-- building, pass the same name for both.
pattern NormalCDF' :: String -> String -> Term -> Term -> Term -> Term
pattern NormalCDF' v v' x y t = Pr (Let v (Normal x y) (Return (GE t (Var v'))))

-- | The constant named @name@ applied to a term.
pattern LkUp :: String -> Term -> Term
pattern LkUp name st = SCon name `App` st

-- The prefixes below are written as explicit character conses because a
-- pattern cannot use (++) to split off a string prefix.

-- | The constant @"upd_" ++ name@ applied to a value and then to a term.
pattern Upd :: String -> Term -> Term -> Term
pattern Upd name val st = SCon ('u':'p':'d':'_':name) `App` val `App` st

-- | The constant @"pop_" ++ name@ applied to a term.
pattern Pop :: String -> Term -> Term
pattern Pop name st = SCon ('p':'o':'p':'_':name) `App` st

-- | The constant @"push_" ++ name@ applied to a value and then to a term.
pattern Push :: String -> Term -> Term -> Term
pattern Push name val st = SCon ('p':'u':'s':'h':'_':name) `App` val `App` st

-- *** Convenience and smart constructors

-- | Variable shorthands. Each name stands for the variable with the same name,
-- except @_'@, which is the variable @"_"@. @j@ was previously defined without
-- a signature. 'getPP' shares this signature and is defined further down.
getPP, a, b, c, d, i, j, k, m, n, p, s, u, v, w, x, y, z, _' :: Term
a  = Var "a"
b  = Var "b"
c  = Var "c"
d  = Var "d"
i  = Var "i"
j  = Var "j"
k  = Var "k"
m  = Var "m"
n  = Var "n"
p  = Var "p"
s  = Var "s"
u  = Var "u"
v  = Var "v"
w  = Var "w"
x  = Var "x"
y  = Var "y"
z  = Var "z"
_' = Var "_"

-- | Named string constants, built with 'sCon'.
_0 :: Term
_0 = sCon "@"

ϵ :: Term
ϵ = sCon "ϵ"

prop1 :: Term
prop1 = sCon "prop1"

prop2 :: Term
prop2 = sCon "prop2"

-- The program that binds the variable @s@ and returns @s & s@. It presumably
-- yields the current state as its result and leaves the state unchanged.
-- Written directly with 'Lam'/'Var'; this is the same term as @lam s (...)@.
getPP = Lam "s" (Return (Var "s" & Var "s"))

epi, cg, factor, observe, normalL, max', purePP, putPP, pr :: Term -> Term

-- | Build the program that runs @φ@ and uses its result to update the CG.
--
-- The steps are:
--
-- 1. Run @φ@ and bind its result to @p@.
-- 2. Read the current state @s@ with 'getPP'.
-- 3. Bind @c@ to @cg s@.
-- 4. 'putPP' the state @upd_CG body s@, where @body@ binds @i@ from @c@,
--    runs @observe@ on @p@ applied to @i@, and returns @i@.
--
-- Previously this definition had no type signature.
assert :: Term -> Term
assert φ =
  φ >>>= lam p (
    getPP >>>= lam s (
      purePP (cg s) >>>= lam c (
        putPP (upd_CG (let' i c (let' _' (observe (p @@ i)) (Return i))) s))))
-- | Build the program that runs @prog@ and pushes its result onto the QUD.
--
-- It binds the result of @prog@ to @q@, reads the current state @s@ with
-- 'getPP', and 'putPP's the state @push_QUD q s@. Previously this definition
-- had no type signature. The parameter was renamed so that it no longer
-- shadows the type abbreviation @κ@.
ask :: Term -> Term
ask prog =
  prog >>>= Lam "q" (getPP >>>= Lam "s" (putPP (push_QUD (Var "q") (Var "s"))))
-- The constant @epi@ applied to a term.
epi = (sCon "epi" @@)

-- The constant @CG@ applied to a term.
cg = (sCon "CG" @@)
-- | The constant @upd_CG@ applied to a new CG term and then to a state term.
-- Previously this had no type signature.
upd_CG :: Term -> Term -> Term
upd_CG cg' st = sCon "upd_CG" @@ cg' @@ st

-- | The constant @pop_QUD@ applied to a term.
-- Previously this had no type signature.
pop_qud :: Term -> Term
pop_qud st = sCon "pop_QUD" @@ st

-- | The constant @push_QUD@ applied to one term and then another.
-- Previously this had no type signature.
push_QUD :: Term -> Term -> Term
push_QUD qu st = sCon "push_QUD" @@ qu @@ st
-- The constant @factor@ applied to a term.
factor x = sCon "factor" @@ x

-- | The constant @ling@ applied to a term.
-- Previously this had no type signature.
ling :: Term -> Term
ling i = sCon "ling" @@ i

-- | The constant @phil@ applied to a term.
-- Previously this had no type signature.
phil :: Term -> Term
phil i = sCon "phil" @@ i

-- The constant @max@ applied to a term. The parameter is no longer called
-- 'pred', which shadowed the Prelude function.
max' f = sCon "max" @@ f
-- 'normal' with its second argument fixed to the constant @σ@.
normalL = flip normal (sCon "σ")

-- The constant @observe@ applied to a term.
observe = (sCon "observe" @@)
-- @purePP val@ binds a variable named by the first element of
-- @fresh [val]@ and returns @val & <that variable>@. It presumably returns
-- @val@ and leaves the state unchanged.
purePP val = Lam st (Return (val & Var st))
  where
    st : _ = fresh [val]

-- @putPP new@ binds a variable named by the first element of @fresh [new]@,
-- ignores it, and returns @TT & new@.
putPP new =
  let st : _ = fresh [new]
  in  Lam st (Return (TT & new))
-- The constant @Pr@ applied to a term.
pr = (sCon "Pr" @@)

(>>>=), (<**>), (<$$>), lam, normal :: Term -> Term -> Term

-- Bind for state-passing programs:
--
-- > λst. let res = prog st in k' (π1 res) (π2 res)
--
-- The names @st@ and @res@ are taken from 'fresh' over both operands.
(>>>=) prog k' = lam st (let' res (prog @@ st) (k' @@ Pi1 res @@ Pi2 res))
  where
    st : res : _ = Var <$> fresh [prog, k']
-- | Sequence two programs with '>>>=', binding the first program's result to
-- the variable @"_"@ ('_'') and ignoring it. Previously this operator had no
-- type signature.
(>>>) :: Term -> Term -> Term
prog1 >>> prog2 = prog1 >>>= lam _' prog2
-- Run the function program and then the argument program (both with '>>>='),
-- and return the first result applied to the second with 'purePP'. The bound
-- names are taken from 'fresh' over both operands.
(<**>) fun arg = fun >>>= lam f' (arg >>>= lam x' (purePP (f' @@ x')))
  where
    f' : x' : _ = map Var (fresh [fun, arg])
-- Lift a plain term into a program with 'purePP', then apply it to a program
-- with this module's '<**>'. The qualified name is required because
-- Control.Applicative also exports a (<**>).
(<$$>) f arg = (Framework.Lambda.Convenience.<**>) (purePP f) arg
-- Build a 'Lam' from a variable term and a body. Only 'Var' binders are
-- meaningful. Any other binder previously failed with an anonymous
-- pattern-match failure; it now fails with a message naming the function.
lam (Var v) = Lam v
lam _       = error "Framework.Lambda.Convenience.lam: the binder must be a Var"
-- The constant @Normal@ applied to the '&'-combination of the two arguments.
normal t1 t2 = (@@) (sCon "Normal") (t1 & t2)

let', respond :: Term -> Term -> Term -> Term

-- Build a 'Let' from a variable term, a bound term, and a body. Only 'Var'
-- binders are meaningful. Any other binder previously failed with an
-- anonymous pattern-match failure; it now fails with a message naming the
-- function.
let' (Var v) = Let v
let' _       = error "Framework.Lambda.Convenience.let': the binder must be a Var"
-- @respond f bg m@ builds the following term over variables @st@, @st'@,
-- @ix@ and @xv@:
--
-- > let st  = bg
-- > let st' = m st
-- > let ix  = cg (π2 st')
-- > f (max' (λxv. π1 (pop_qud (π2 st')) xv ix))
--
-- Bug fix: @f@ occurs under the binders for @st@, @st'@ and @ix@, so its free
-- variables must be avoided when choosing those names. Previously only @bg@
-- and @m@ were passed to 'fresh', so a free variable of @f@ could be captured.
-- (This assumes 'fresh' returns names occurring free in none of its argument
-- terms.)
respond f bg m = let' st bg body
  where
    body =
      let' st' (m @@ st)
        (let' ix (cg (Pi2 st'))
          (f @@ max' (lam xv (Pi1 (pop_qud (Pi2 st')) @@ xv @@ ix))))
    st : st' : ix : xv : _ = map Var (fresh [f, bg, m])

-- | 'Num' instance for 'Term', just as a notational convenience.
--
-- * Addition and multiplication of two numeric constants ('DCon') are folded.
--   Otherwise they build applications of the @add@/@mult@ constants.
-- * 'negate' always builds an application of the @neg@ constant.
-- * 'abs' and 'signum' are only defined on numeric constants. On any other
--   term they now fail with an explicit message rather than an anonymous
--   pattern-match failure.
instance Num Term where
  t * u = case (t, u) of
    (DCon x, DCon y) -> DCon (x * y)
    _                -> sCon "mult"  @@ (t & u)
  t + u = case (t, u) of
    (DCon x, DCon y) -> DCon (x + y)
    _                -> sCon "add" @@ (t & u)
  negate t      = sCon "neg"   @@ t
  fromInteger x = dCon (fromInteger x)
  signum (DCon x) = DCon (signum x)
  signum _        = error "Num Term: signum is only defined on numeric constants (DCon)"
  abs (DCon x) = DCon (abs x)
  abs _        = error "Num Term: abs is only defined on numeric constants (DCon)"

-- *** Generic functions

-- | Compute entailments. 'And' and 'Or' are read as lattice meet and join;
-- every other term is compared by syntactic ('==') equality.
--
-- Rule order matters. The rules that can never lose a proof are applied
-- first: a conjunction on the right, and a disjunction on the left. Only then
-- are the two remaining choices considered, and /both/ are tried: splitting a
-- conjunction on the left, and splitting a disjunction on the right.
--
-- The earlier clause order committed to splitting a right-hand disjunction
-- first, which missed valid entailments. For example, @Or a b@ did not entail
-- @Or b a@, and @And (Or a b) c@ did not entail @Or a b@. This ordering is
-- Whitman's decision procedure for free lattices, and it accepts everything
-- the old clauses accepted.
entails :: Term -> Term -> Bool
entails p q
  | p == q = True
entails p (And q r) = entails p q && entails p r
entails (Or p q) r  = entails p r && entails q r
entails p q         = splitLeftAnd p || splitRightOr q
  where
    -- p1 ∧ p2 entails q if either conjunct alone does.
    splitLeftAnd (And p1 p2) = entails p1 q || entails p2 q
    splitLeftAnd _           = False
    -- p entails q1 ∨ q2 if it entails either disjunct alone.
    splitRightOr (Or q1 q2)  = entails p q1 || entails p q2
    splitRightOr _           = False

-- | Collect the constants appearing in a term, in left-to-right order and
-- with repetitions.
cons :: Term -> [Constant]
cons term = collect term []
  where
    -- Accumulator-passing traversal: @collect t rest@ is the constants of
    -- @t@ followed by @rest@.
    collect :: Term -> [Constant] -> [Constant]
    collect (Var _)            rest = rest
    collect (Con con)          rest = con : rest
    collect (Lam _ body)       rest = collect body rest
    collect (App fun arg)      rest = collect fun (collect arg rest)
    collect TT                 rest = rest
    collect (Pair l r')        rest = collect l (collect r' rest)
    collect (Pi1 body)         rest = collect body rest
    collect (Pi2 body)         rest = collect body rest
    collect (Let _ bound body) rest = collect bound (collect body rest)
    collect (Return body)      rest = collect body rest

-- | True of probabilistic programs that only sample, i.e., do not perform
-- inference. These are exactly the 'Bern', 'Normal', 'LogitNormal' and
-- 'Truncate' forms.
--
-- NOTE(review): 'Beta' is not included here. Confirm whether it should be.
sampleOnly :: Term -> Bool
sampleOnly (Bern _)          = True
sampleOnly (Normal _ _)      = True
sampleOnly (LogitNormal _ _) = True
sampleOnly (Truncate _ _ _)  = True
sampleOnly _                 = False


-- | Combine two 'Alternative'-valued functions, e.g. signatures and rules:
-- on an input, the result is the first function's result '<|>' the second's.
(<||>) :: Alternative m => (a -> m b) -> (a -> m b) -> a -> m b
(f <||> g) x = f x <|> g x

-- | Overwrite one state or index onto another, given a list of relevant
-- parameters. For each name, from first to last, the result is wrapped in
-- @Upd name (LkUp name src)@. The innermost term is @dst@.
overwrite :: [String] -> Term -> Term -> Term
overwrite names src dst =
  foldr (\name rest -> Upd name (LkUp name src) rest) dst names