{-# LANGUAGE LambdaCase #-}

{-|
Module      : Framework.Target.Stan
Description : Exports probabilistic programs as Stan code.
Copyright   : (c) Julian Grove and Aaron Steven White, 2025
License     : MIT
Maintainer  : julian.grove@gmail.com

Probabilistic programs encoded as λ-terms are translated into Stan code.
-}

module Framework.Target.Stan where

import Analysis.Adjectives.Adjectives
import Analysis.Factivity.Factivity
import Control.Monad.Writer
import Control.Monad.State
import Data.Char                      (toLower)
import Framework.Lambda
import Framework.Grammar
import Theory.Signature

-- | A Stan distribution expression, kept as source text.
type Distr   = String
-- | A Stan variable name.
type VarName = String

-- | The contents of a Stan @model@ block, as an ordered list of
-- @(variable, distribution)@ statements.
--
-- When shown, every statement except the last is printed as a sampling
-- statement @name ~ distr;@. The last one is printed as a likelihood
-- increment @target += distr;@, and its name is ignored.
data Model = Model { statements :: [(VarName, Distr)] }
  deriving (Eq)

instance Show Model where
  show model =
    "model {\n  // FIXED EFFECTS\n" ++ renderStatements (statements model) ++ "}"
    where
      renderStatements :: [(VarName, Distr)] -> String
      renderStatements stmts = case stmts of
        [] -> ""
        -- The final statement is the likelihood; its variable name is unused.
        [(_, likelihood)] ->
          " \n  // LIKELIHOOD\n  " ++ "target += " ++ likelihood ++ ";\n"
        (name, prior) : more ->
          "  " ++ name ++ " ~ " ++ prior ++ ";\n" ++ renderStatements more

-- | Errors reported while translating a term to Stan.
data Error
  = TypeError -- ^ The term given to 'toStan' does not have type @P r@.
  deriving (Eq)

instance Show Error where
  show err = case err of
    TypeError -> "Error: Term does not have type P r!"

-- | Render an arithmetic 'Term' as a Stan expression.
--
-- * Variables and numeric constants are rendered with their 'Show' instances.
-- * @NormalCDF x y z@ is rendered as Stan's @normal_cdf@, with @z@ as the
--   first argument and @x@, @y@ as the second and third.
-- * @Add x (Neg y)@ is rendered as subtraction, other @Add@s as addition,
--   and @Neg@ as unary minus.
--
-- Any other constructor is a pattern-match failure.
stanShow :: Term -> String
stanShow term = case term of
  Var _  -> show term
  DCon _ -> show term
  -- NOTE(review): the Stan Functions Reference spells this
  -- @normal_cdf(y | mu, sigma)@; the comma-separated form is the older
  -- syntax. Confirm which form the targeted Stan version accepts.
  NormalCDF x y z ->
    "normal_cdf(" ++ stanShow z ++ ", " ++ stanShow x ++ ", " ++ stanShow y ++ ")"
  Add x (Neg y) -> stanShow x ++ " - " ++ operand y
  Add x y       -> stanShow x ++ " + " ++ stanShow y
  Neg x         -> "-" ++ operand x
  where
    -- Operand of a subtraction or negation. A sum must be parenthesised
    -- there: without parentheses, x - (a + b) would print as "x - a + b",
    -- which Stan parses as (x - a) + b. Likewise -(a + b) would print as
    -- "-a + b".
    operand t@(Add _ _) = "(" ++ stanShow t ++ ")"
    operand t           = stanShow t

-- | Render the log density of variable @v@ under a distribution 'Term'
-- as a Stan expression.
--
-- 'toStan' uses this for the final (likelihood) statement of a model,
-- which the 'Show' instance of 'Model' emits as @target += ...;@.
--
-- * @Truncate (Normal x y) z w@ becomes
--   @truncated_normal_lpdf(v | x, y, z, w)@.
-- * @Normal x y@ becomes @normal_lpdf(v | x, y)@.
-- * @Disj x y z@ becomes @log_mix(x, <y>, <z>)@, rendering @y@ and @z@
--   recursively.
--
-- Other terms are not handled (pattern-match failure).
lRender :: VarName -> Term -> String
lRender v (Truncate (Normal x y) z w) =
  -- All arguments go through 'stanShow', as in the 'Normal' case below and
  -- in 'pRender's truncation bounds, so compound arguments are rendered as
  -- Stan expressions rather than via 'show'.
  -- NOTE(review): truncated_normal_lpdf does not appear to be a Stan
  -- built-in; presumably the generated program defines it. Confirm.
  "truncated_normal_lpdf(" ++ v ++
  " | " ++ stanShow x ++ ", " ++ stanShow y ++
  ", " ++ stanShow z ++ ", " ++ stanShow w ++ ")"
lRender v (Normal x y) =
  "normal_lpdf(" ++ v ++
  " | " ++ stanShow x ++ ", " ++ stanShow y ++ ")"
lRender v (Disj x y z) =
  "log_mix(" ++ stanShow x ++ ", " ++ lRender v y ++ ", " ++ lRender v z ++ ")"

-- | Render a distribution 'Term' as the right-hand side of a Stan sampling
-- statement (the @...@ in @x ~ ...;@).
--
-- Supports @Normal@, @LogitNormal@ and @Truncate@. @Truncate@ appends
-- Stan's @T[lower, upper]@ suffix to the rendered inner distribution. Any
-- other constructor is a pattern-match failure.
--
-- NOTE(review): @logit_normal@ does not appear to be a Stan built-in
-- distribution; presumably the generated program supplies it. Confirm.
pRender :: Term -> String
pRender dist = case dist of
  Normal mu sigma      -> twoArg "normal(" mu sigma
  LogitNormal mu sigma -> twoArg "logit_normal(" mu sigma
  Truncate inner lo hi ->
    pRender inner ++ " T[" ++ stanShow lo ++ ", " ++ stanShow hi ++ "]"
  where
    -- Render "<prefix><a>, <b>)"; @prefix@ already ends with the opening
    -- parenthesis.
    twoArg prefix a b = prefix ++ stanShow a ++ ", " ++ stanShow b ++ ")"

-- | Translate a probabilistic program into a Stan 'Model'.
--
-- The term must have type @P r@, checked with 'typeOf' on @ty tau0 t@.
-- Otherwise a 'TypeError' is logged and an empty 'Model' is returned.
--
-- * @Let x y z@ contributes a statement named @x@ whose distribution is
--   rendered from @y@ (with 'pRender' when @y@ is not itself a 'Let'). It is
--   followed by any statements produced for @y@ and then by the translation
--   of @z@.
-- * Any other term becomes the single likelihood statement for the
--   variable @"y"@, rendered with 'lRender'.
toStan :: Term -> Writer [Error] Model
toStan = \case
  t | typeOf (ty tau0 t) /= Just (P (Atom "r")) -> do
        tell [TypeError]
        pure (Model [])
  t@(Let _ _ _) -> do
      (v, d, rest) <- bindings t
      pure (Model ((v, d) : rest))
    where
      -- Translate a (possibly nested) 'Let' into its first statement
      -- (name, distribution) plus the statements that follow it.
      -- Returning the first statement separately keeps the result
      -- non-empty by construction, so no partial match is needed when a
      -- nested binding's distribution is reused.
      bindings :: Term -> Writer [Error] (VarName, Distr, [(VarName, Distr)])
      bindings (Let x y z) = do
        (_, d, ys) <- bindings y
        Model zs   <- toStan z
        pure (x, d, ys ++ zs)
      bindings other = pure ("", pRender other, [])
  result -> pure (Model [("y", lRender "y" result)])