{-# LANGUAGE LambdaCase #-}
module Framework.Target.Stan where
import Analysis.Adjectives.Adjectives
import Analysis.Factivity.Factivity
import Control.Monad.Writer
import Control.Monad.State
import Data.Char (toLower)
import Framework.Lambda
import Framework.Grammar
import Theory.Signature
-- | Textual form of the right-hand side of a generated statement
-- (the second component of each entry in a 'Model').
type Distr = String
-- | Name placed on the left-hand side of a generated statement
-- (the first component of each entry in a 'Model').
type VarName = String
-- | An ordered list of generated statements, each pairing a variable name
-- with the rendered text of its distribution. The order is significant: the
-- 'Show' instance treats the final entry differently from the others.
data Model = Model
  { statements :: [(VarName, Distr)]  -- ^ statements, in emission order
  } deriving (Eq)
-- | Renders a 'Model' as the text of a @model { ... }@ block.
--
-- Every entry except the last is emitted as @name ~ distr;@. The last entry
-- is emitted under a @// LIKELIHOOD@ heading as @target += distr;@ and its
-- name is ignored. An empty statement list yields only the header and the
-- closing brace.
instance Show Model where
  show (Model m) = "model {\n // FIXED EFFECTS\n" ++ render m ++ "}"
    where
      render [] = ""
      -- The singleton case must come before the general cons case so that
      -- the final statement is rendered as the likelihood contribution.
      render [(_, d)] =
        " \n // LIKELIHOOD\n " ++ "target += " ++ d ++ ";\n"
      render ((v, d) : s) =
        " " ++ v ++ " ~ " ++ d ++ ";\n" ++ render s
-- | Errors that can be reported while translating a term.
-- 'TypeError' is the only one; its message is given by the 'Show' instance.
data Error = TypeError deriving (Eq)
-- | Human-readable message for each 'Error'.
instance Show Error where
  show TypeError = "Error: Term does not have type P r!"
-- | Render an arithmetic 'Term' as expression text for the generated program.
--
-- Handled forms:
--
-- * 'Var' and 'DCon' are rendered with their 'Show' instance.
-- * @NormalCDF x y z@ becomes @normal_cdf(z, x, y)@ (note the argument order).
-- * @Add x (Neg y)@ becomes @x - y@; any other @Add x y@ becomes @x + y@.
-- * @Neg x@ becomes @-x@.
--
-- An operand of a negation or subtraction that is itself an 'Add' is wrapped
-- in parentheses, so that e.g. @Neg (Add a b)@ is emitted as @-(a + b)@ and
-- not as @-a + b@. Output for non-compound operands is unchanged.
--
-- The function is partial: any other constructor of 'Term' is not matched.
stanShow :: Term -> String
stanShow = go
  where
    go v@(Var _) = show v
    go x@(DCon _) = show x
    go (NormalCDF x y z) =
      "normal_cdf(" ++ go z ++ ", " ++ go x ++ ", " ++ go y ++ ")"
    -- Must precede the general 'Add' case so that adding a negation is
    -- printed as a subtraction.
    go (Add x (Neg y)) = go x ++ " - " ++ negated y
    go (Add x y) = go x ++ " + " ++ go y
    go (Neg x) = "-" ++ negated x

    -- Operand appearing directly under a minus sign: a sum has to be
    -- parenthesized, otherwise the minus would apply to its first summand
    -- only.
    negated t@(Add _ _) = "(" ++ go t ++ ")"
    negated t = go t
-- | Render a distribution 'Term' as a log-density expression over the
-- variable named @v@.
--
-- * @Truncate (Normal x y) z w@ becomes
--   @truncated_normal_lpdf(v | x, y, z, w)@.
-- * @Normal x y@ becomes @normal_lpdf(v | x, y)@.
-- * @Disj x y z@ becomes @log_mix(x, \<y\>, \<z\>)@, where @\<y\>@ and @\<z\>@
--   are the recursive renderings of the two branches for the same @v@.
--
-- The function is partial: any other 'Term' is not matched.
lRender :: VarName -> Term -> String
-- NOTE(review): this branch renders its parameters with 'show', whereas the
-- other branches use 'stanShow'. The two agree on 'Var' and 'DCon' (where
-- 'stanShow' itself delegates to 'show'); whether the difference is intended
-- for compound terms cannot be told from here -- confirm.
lRender v (Truncate (Normal x y) z w) =
  "truncated_normal_lpdf(" ++ v ++
  " | " ++ show x ++ ", " ++ show y ++
  ", " ++ show z ++ ", " ++ show w ++ ")"
lRender v (Normal x y) =
  "normal_lpdf(" ++ v ++
  " | " ++ stanShow x ++ ", " ++ stanShow y ++ ")"
lRender v (Disj x y z) =
  "log_mix(" ++ stanShow x ++ ", " ++ lRender v y ++ ", " ++ lRender v z ++ ")"
-- | Render a distribution 'Term' as the right-hand side of a sampling
-- statement (the text that follows @~@ in the 'Show' instance of 'Model').
--
-- * @Normal x y@ becomes @normal(x, y)@.
-- * @LogitNormal x y@ becomes @logit_normal(x, y)@.
-- * @Truncate m x y@ becomes the rendering of @m@ followed by @ T[x, y]@.
--
-- Parameters are rendered with 'stanShow'. The function is partial: any other
-- 'Term' is not matched.
pRender :: Term -> String
pRender (Normal x y) =
  "normal(" ++ stanShow x ++ ", " ++ stanShow y ++ ")"
pRender (LogitNormal x y) =
  "logit_normal(" ++ stanShow x ++ ", " ++ stanShow y ++ ")"
pRender (Truncate m x y) =
  pRender m ++ " T[" ++ stanShow x ++ ", " ++ stanShow y ++ "]"
-- | Translate a 'Term' into a 'Model', logging problems in the 'Writer'.
--
-- * If the term's type under @tau0@ is not @P (Atom "r")@, a 'TypeError' is
--   logged and an empty 'Model' is returned.
-- * A @Let@ term is unfolded into a list of statements by the local helper
--   @toStan'@.
-- * Any other term becomes a single statement named @"y"@ whose text is
--   @lRender "y"@ of the term.
toStan :: Term -> Writer [Error] Model
toStan = \case
  -- The type check comes first; when the guard fails, matching falls through
  -- to the alternatives below.
  t | typeOf (ty tau0 t) /= Just (P (Atom "r")) -> do
        tell [TypeError]
        pure (Model [])
  t@(Let _ _ _) -> toStan' t
    where
      toStan' :: Term -> Writer [Error] Model
      -- @Let x y z@: render the bound term @y@ (without the type check of
      -- 'toStan'), re-label its first statement with @x@, keep its remaining
      -- statements, and append whatever 'toStan' yields for the body @z@.
      toStan' (Let x y z) = do
        yResult <- toStan' y
        -- Both equations of toStan' return a non-empty statement list, so
        -- this single alternative always matches.
        case yResult of
          Model ((_, distr) : ys) -> do
            Model zs <- toStan z
            pure $ Model ((x, distr) : ys ++ zs)
      -- Anything else is a prior: one statement with an empty name, which
      -- the enclosing @Let@ case replaces with the bound variable.
      toStan' result = pure $ Model [("", pRender result)]
  result -> pure (Model [("y", lRender "y" result)])