{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}

{-|
Module      : Framework.Lambda.Delta
Description : Delta rules.
Copyright   : (c) Julian Grove and Aaron Steven White, 2025
License     : MIT
Maintainer  : julian.grove@gmail.com

This module defines delta rules, which encode algebraic laws relating λ-terms
that feature constants.
-}

module Framework.Lambda.Delta where

import Data.List
import Framework.Lambda.Convenience
import Framework.Lambda.Terms

--------------------------------------------------------------------------------
-- * Delta rules

-- ** Example rules

-- | Performs some arithmetic simplifications:
--
-- * additive identity: @Zero + u@ becomes @u@ and @t + Zero@ becomes @t@;
-- * multiplicative identity: @One * u@ becomes @u@ and @t * One@ becomes @t@;
-- * multiplicative annihilation: @Zero * u@ and @t * Zero@ become @Zero@;
-- * two constants ('DCon') are combined with the 'Num' operations on 'Term';
-- * negation of a constant is computed directly.
--
-- Returns 'Nothing' when no simplification applies.
arithmetic :: DeltaRule
arithmetic = \case
  Add t u -> case t of
    Zero       -> Just u
    x@(DCon _) -> case u of
      Zero       -> Just x
      y@(DCon _) -> Just (x + y)
      _          -> Nothing
    t'         -> case u of
      Zero -> Just t'
      _    -> Nothing
  Mult t u -> case t of
    Zero       -> Just Zero
    One        -> Just u
    x@(DCon _) -> case u of
      Zero       -> Just Zero
      One        -> Just x
      y@(DCon _) -> Just (x * y)
      -- A constant times a non-constant term has no simplification. This
      -- fallback keeps the match exhaustive instead of crashing at runtime.
      _          -> Nothing
    t'         -> case u of
      Zero -> Just Zero
      One  -> Just t'
      _    -> Nothing
  Neg (DCon x) -> Just (dCon (-x))
  _            -> Nothing

-- | Gets rid of vacuous let-bindings: @let v = m in k@ is replaced by @k@
-- when @m@ satisfies 'sampleOnly' and @v@ does not occur free in @k@.
-- Every other term yields 'Nothing'.
cleanUp :: DeltaRule
cleanUp term = case term of
  Let v m k
    | sampleOnly m, v `notElem` freeVars k -> Just k
  _ -> Nothing

-- | Marginalizes out Bernoulli and disjunctive bindings, and simplifies
-- trivial disjunctions:
--
-- * @let b = Bern x in k@ becomes a 'Disj' on @x@ of @k@ with @b@ replaced by
--   'Tr' and of @k@ with @b@ replaced by 'Fa';
-- * @let v = Disj x m n in k@ distributes the binding over both branches;
-- * a 'Disj' whose two branches are syntactically equal collapses to one;
-- * a 'Disj' with an 'Undefined' branch collapses to the other branch.
--
-- Returns 'Nothing' for every other term.
disjunctions :: DeltaRule
disjunctions = \case
  Let b (Bern x) k ->
    let ifTrue  = subst b Tr k
        ifFalse = subst b Fa k
    in Just (Disj x ifTrue ifFalse)
  Let v (Disj x m n) k ->
    let left  = Let v m k
        right = Let v n k
    in Just (Disj x left right)
  Disj _ m n
    | m == n -> Just m
  Disj _ m Undefined -> Just m
  Disj _ Undefined n -> Just n
  _ -> Nothing

-- | Computes syntactic equalities: @Eq x y@ becomes 'Tr' when @x@ and @y@ are
-- syntactically identical. Syntactically distinct operands, and every other
-- term, yield 'Nothing'.
equality :: DeltaRule
equality (Eq x y) | x == y = Just Tr
equality _                 = Nothing

-- | Computes the indicator function on truth-value constants: @Indi Tr@ is
-- @1@ and @Indi Fa@ is @0@. Any other term yields 'Nothing'.
indicator :: DeltaRule
indicator term = case term of
  Indi b -> case b of
    Tr -> Just 1
    Fa -> Just 0
    _  -> Nothing
  _ -> Nothing

-- | Computes /if then else/ when the condition is a truth-value constant:
-- @ITE Tr x y@ reduces to @x@ and @ITE Fa x y@ reduces to @y@. Any other
-- term yields 'Nothing'.
ite :: DeltaRule
ite (ITE Tr thenBranch _) = Just thenBranch
ite (ITE Fa _ elseBranch) = Just elseBranch
ite _                     = Nothing

-- | Simplifies conjunctions and disjunctions that have a truth-value constant
-- as an operand: 'Tr' is the unit of 'And' and 'Fa' absorbs it; 'Fa' is the
-- unit of 'Or' and 'Tr' absorbs it. Any other term yields 'Nothing'.
logical :: DeltaRule
logical = \case
  And p q -> simplifyAnd p q
  Or  p q -> simplifyOr p q
  _       -> Nothing
  where
    -- Clauses are tried in the same order as the original rule table.
    simplifyAnd lhs Tr = Just lhs
    simplifyAnd Tr rhs = Just rhs
    simplifyAnd Fa _   = Just Fa
    simplifyAnd _  Fa  = Just Fa
    simplifyAnd _  _   = Nothing

    simplifyOr lhs Fa = Just lhs
    simplifyOr Fa rhs = Just rhs
    simplifyOr Tr _   = Just Tr
    simplifyOr _  Tr  = Just Tr
    simplifyOr _  _   = Nothing

-- | Computes the /max/ function: @Max (λy. x ≥ y)@ reduces to @x@, the
-- greatest @y@ satisfying @x ≥ y@.
--
-- The rewrite is only valid when @y@ does not occur free in @x@. Otherwise
-- @x@ is not a fixed bound, and the result would leak the bound variable @y@
-- out of its binder. In that case the rule does not apply. Any other term
-- yields 'Nothing'.
maxes :: DeltaRule
maxes = \case
  Max (Lam y (GE x (Var y')))
    | y' == y, y `notElem` freeVars x -> Just x
  _                                   -> Nothing

-- | Observing @Tr@ is trivial, while observing @Fa@ yields an undefined
-- probability distribution. Concretely, @let _ = Observe Tr in k@ becomes
-- @k@, and @let _ = Observe Fa in k@ becomes 'Undefined'. Any other term
-- yields 'Nothing'.
observations :: DeltaRule
observations (Let _ bound k) = case bound of
  Observe Tr -> Just k
  Observe Fa -> Just Undefined
  _          -> Nothing
observations _ = Nothing

-- | Computes probabilities for certain probabilistic programs:
--
-- * returning 'Tr' has probability @1@; returning 'Fa' has probability @0@;
-- * @Bern x@ has probability @x@;
-- * @Disj x t u@ has probability @x * Pr t + (1 - x) * Pr u@;
-- * for @v@ drawn from @Normal x y@, returning @t ≥ v@ has probability
--   @NormalCDF x y t@, and returning @v ≥ t@ has probability
--   @1 - NormalCDF x y t@.
--
-- The two 'Normal' cases require that @v@ not occur free in @t@. If it did,
-- @t@ would not be a fixed threshold: the closed form would be wrong (e.g.
-- for @v ≥ v@), and @v@ would be left unbound in the result. Any term not
-- covered above yields 'Nothing'.
probabilities :: DeltaRule
probabilities = \case
  Pr (Return Tr)  -> Just 1
  Pr (Return Fa)  -> Just 0
  Pr (Bern x)     -> Just x
  Pr (Disj x t u) -> Just (x * Pr t + (1 - x) * Pr u)
  Pr (Let v (Normal x y) (Return (GE t (Var v'))))
    | v' == v, fixedIn v t -> Just (NormalCDF x y t)
  Pr (Let v (Normal x y) (Return (GE (Var v') t)))
    | v' == v, fixedIn v t -> Just (1 - NormalCDF x y t)
  _               -> Nothing
  where
    -- The threshold @t@ must not depend on the sampled variable @v@.
    fixedIn v t = v `notElem` freeVars t

-- | Computes functions on indices and states. These include reading ('LkUp')
-- and writing ('Upd') locations of fixed type, as well as pushing to ('Push')
-- and popping from ('Pop') stacks, thus possibly modifying the type of the
-- state.
--
-- * A lookup of @c@ through an update of @c@ yields the written value.
--   Through an update of another location, or through any push, the lookup
--   moves on to the underlying state.
-- * Two nested updates of the same location keep only the outer one.
-- * Popping @c@ directly after a push onto @c@ yields @v & s@, where @v@ is
--   the pushed value and @s@ the underlying state.
-- * Popping @c@ through a push onto another stack, or through any update,
--   pops from the underlying state and re-applies that push or update to the
--   second projection of the result.
--
-- Any other term yields 'Nothing'.
states :: DeltaRule
states = \case
  LkUp c (Upd c' v s)
    | c' == c   -> Just v
    | otherwise -> Just (LkUp c s)
  Upd c v (Upd c' _ s)
    | c' == c   -> Just (Upd c v s)
  Pop c (Push c' v s)
    | c' == c   -> Just (v & s)
    | otherwise -> Just (Pi1 popped & Push c' v (Pi2 popped))
    where
      popped = Pop c s
  LkUp c (Push _ _ s) -> Just (LkUp c s)
  Pop c (Upd c' v s) ->
    let popped = Pop c s
    in Just (Pi1 popped & Upd c' v (Pi2 popped))
  _ -> Nothing