------------------------------------------------------
-- Functional language with side effects "ML-style"
------------------------------------------------------

-- | Variables are just names.
type Var = String


-- | Memory locations are just numbers (addresses).
type Location = Int


-- | Runtime values: integers, functions, references to memory locations,
-- and error values carrying a message.
data Value = ValInt Integer
           | ValFunc (Value -> Cmpt Value)  -- argument -> computation of the result
           | ValRef Location
           | ValError String

-- | Renders a runtime value for display. Functions are opaque, so they are
-- shown as a fixed placeholder.
instance Show Value where
  show v = case v of
    ValInt n -> show n
    ValFunc _ -> "function"
    ValRef loc -> "ref " ++ show loc
    ValError msg -> "error: " ++ msg



-- | An environment maps variables to values. It is represented directly
-- as a lookup function.
type Env = Var -> Value

-- | The environment with no bindings: every lookup yields an
-- "undef var" error value naming the variable.
emptyEnv :: Env
emptyEnv var = ValError $ "undef var " ++ var


-- | A memory maps locations to values. It is represented directly as a
-- lookup function.
type Mem = Location -> Value

-- | The memory with nothing stored: reading any location yields an
-- "undef loc" error value naming that location.
emptyMem :: Mem
emptyMem = \loc -> ValError ("undef loc " ++ show loc)


-- | Truthiness used by conditionals and loops: only the integer 0 is
-- false; every other value (including functions, references and errors)
-- counts as true.
isTrue :: Value -> Bool
isTrue v = case v of
  ValInt 0 -> False
  _ -> True

-- | The answer a whole computation ultimately produces.
type Result = Value

-- | A continuation: receives the current memory and an intermediate value
-- of type @a@, and produces the final 'Result'.
type K a = Mem -> a -> Result

-- | A computation yielding an @a@ (the "side-effect" of an expression):
-- given the current memory and a continuation, it produces the final
-- 'Result'.
type Cmpt a = Mem -> K a -> Result

-- | Lifts a plain value into a computation that leaves memory untouched
-- and hands the value straight to its continuation.
unit :: a -> Cmpt a
unit v mem k = k mem v

-- | Sequences two computations: runs @ca@, then feeds its result, together
-- with the memory it left behind, into @f@.
bind :: Cmpt a -> (a -> Cmpt b) -> Cmpt b
bind ca f mem k = ca mem afterFirst
  where afterFirst mem' a = f a mem' k


-- | A failing computation: it ignores both the memory and the
-- continuation, so the rest of the program is skipped and the error value
-- becomes the final result.
cerror :: String -> Cmpt a
cerror s _ _ = ValError s


-- | Looks up @key@ in a function-represented map (an environment or a
-- memory). With this representation, lookup is plain application.
query :: (a -> b) -> a -> b
query table key = table $ key

-- | Returns a map that sends @key@ to @val@ and defers every other key to
-- the previous map @old@. The new binding therefore shadows any older one.
update :: Eq a => a -> b -> (a -> b) -> (a -> b)
update key val old probe
  | probe == key = val
  | otherwise = old probe


-- | Reads the value stored at @loc@, leaving memory unchanged.
queryM :: Location -> Cmpt Value
queryM loc mem k = k mem stored
  where stored = query mem loc


-- | Stores @v@ at @loc@ and yields @v@ itself as the result, continuing
-- with the updated memory.
updateM :: Location -> Value -> Cmpt Value
updateM loc v mem k = k mem' v
  where mem' = update loc v mem


-- | Yields the lowest location (scanning upward from 0) whose cell holds a
-- 'ValError'; memory is passed on unchanged. Unwritten cells read as
-- errors, so this finds an unused location.
-- NOTE(review): a cell that was explicitly written with a 'ValError'
-- value is also treated as free.
freeM :: Cmpt Location
freeM mem k = k mem (until isFree (+ 1) 0)
  where isFree loc = case mem loc of
          ValError _ -> True
          _ -> False


-- | Extends an environment by binding @name@ to @value@. The new binding
-- shadows any earlier binding of the same name.
bindVar :: Var -> Value -> Env -> Env
bindVar name value = update name value


-- | Runs two computations left to right, then combines their results
-- with @op@, which may itself be effectful.
op2 :: (a -> b -> Cmpt c) -> Cmpt a -> Cmpt b -> Cmpt c
op2 op ca cb = ca `bind` \a -> cb `bind` op a


-- | Lifts a binary integer operation to computations. Both operands are
-- evaluated left to right. If either result is not an integer, the
-- program aborts with an error.
arith :: (Integer -> Integer -> Integer) ->
             Cmpt Value -> Cmpt Value -> Cmpt Value
arith op = op2 combine
  where
    combine v1 v2 = case (v1, v2) of
      (ValInt i1, ValInt i2) -> unit (ValInt (i1 `op` i2))
      _ -> cerror "binary operation over non-int value"


--------------------------------------------------------------------
-- | Abstract syntax tree for expressions. The concrete syntax of each form
-- is noted next to its constructor.
data Exp = ExpK Integer          -- constants
         | ExpVar Var            -- variables
         | ExpAdd Exp Exp        -- e1 + e2
         | ExpSub Exp Exp        -- e1 - e2
         | ExpMul Exp Exp        -- e1 * e2
         | ExpDiv Exp Exp        -- e1 / e2  (integer division)
         | ExpAssg Exp Exp       -- e1 := e2
         | ExpDeref Exp          -- !e
         | ExpNewref Exp         -- ref e
         | ExpIf Exp Exp Exp     -- if e1 then e2 else e3
         | ExpApp Exp Exp        -- e1 e2
         | ExpLambda Var Exp     -- \x -> e
         | ExpLetrec Var Var Exp Exp    -- letrec f = \x -> e1 in e2
         | ExpSeq Exp Exp        -- e1; e2
         | ExpWhile Exp Exp Exp  -- while e1 do e2 else e3
             deriving Show


-- | Builds a function value from parameter @v@, body @e@ and the defining
-- environment @env@, which is captured at creation time. Each call
-- evaluates @e@ with @v@ bound to the argument.
closure :: Var -> Exp -> Env -> Value
closure v e env = ValFunc (\arg -> evalExp e (bindVar v arg env))


-- | Evaluates an expression in a given environment.
--
-- The result is a computation ('Cmpt'). Given the current memory and a
-- continuation, it threads the memory through sub-expressions left to
-- right and passes the final value on. Runtime errors abort the whole
-- program via 'cerror': unbound variables, non-int arithmetic, division
-- by zero, and misuse of references or functions.
evalExp :: Exp -> Env -> Cmpt Value

evalExp (ExpK i) _ = unit (ValInt i)

-- An unbound variable aborts evaluation instead of flowing on as a
-- 'ValError' value. Otherwise such a value could be stored by @ref@, and
-- because 'freeM' treats any cell holding a 'ValError' as free, that
-- location would later be handed out again and aliased.
evalExp (ExpVar var) env = case query env var of
  ValError msg -> cerror msg
  v -> unit v

-- !e: read the cell that a reference points to.
evalExp (ExpDeref e) env = bind (evalExp e env) deref
  where deref (ValRef loc) = queryM loc
        deref v = cerror ("unref a non-reference value:" ++ show v)

-- e1 := e2: store e2's value in the cell e1 refers to; yields that value.
evalExp (ExpAssg e1 e2) env = op2 assign (evalExp e1 env) (evalExp e2 env)
  where assign (ValRef loc) v = updateM loc v
        assign v _ = cerror ("assignment to a non-reference value:" ++ show v)

evalExp (ExpAdd e1 e2) env = arith (+) (evalExp e1 env) (evalExp e2 env)
evalExp (ExpSub e1 e2) env = arith (-) (evalExp e1 env) (evalExp e2 env)
evalExp (ExpMul e1 e2) env = arith (*) (evalExp e1 env) (evalExp e2 env)

-- Integer division. A zero divisor is reported as a language error, rather
-- than letting Haskell's 'div' throw DivideByZero when the result is
-- eventually forced.
evalExp (ExpDiv e1 e2) env = op2 divide (evalExp e1 env) (evalExp e2 env)
  where divide (ValInt _) (ValInt 0) = cerror "division by zero"
        divide v1 v2 = arith div (unit v1) (unit v2)

-- The condition's truthiness is decided by 'isTrue'.
evalExp (ExpIf e1 e2 e3) env =
  bind (evalExp e1 env) (\b -> if isTrue b then evalExp e2 env
                                           else evalExp e3 env)

evalExp (ExpApp e1 e2) env = op2 app (evalExp e1 env) (evalExp e2 env)
  where app (ValFunc f) vp = f vp
        app _ _ = cerror "attempt to call a non-function value"

evalExp (ExpLambda v e) env = unit (closure v e env)

-- letrec v = \v' -> e' in e. The closure captures env', an environment
-- that already contains v itself, which is what makes recursion possible.
-- This knot-tying relies on laziness: env' is not forced while being built.
evalExp (ExpLetrec v v' e' e) env = evalExp e env'
  where env' = bindVar v (closure v' e' env') env

-- e1; e2: run both for their effects and keep only e2's value.
evalExp (ExpSeq e1 e2) env =
  op2 (\_ v2 -> unit v2) (evalExp e1 env) (evalExp e2 env)

-- while e1 do e2 else e3: e1 is re-tested after each run of e2. Once e1 is
-- false, the loop's value is the value of e3.
evalExp (ExpWhile e1 e2 e3) env = loop
  where loop = bind (evalExp e1 env) (\b ->
                 if isTrue b then bind (evalExp e2 env) (\_ -> loop)
                             else evalExp e3 env)

-- ref e: evaluate e, store it in a free cell, yield a reference to it.
evalExp (ExpNewref e) env =
  bind (evalExp e env) (\v ->
    bind freeM (\loc ->
      bind (updateM loc v) (\_ -> unit (ValRef loc))))


-- | Syntax sugar: @let var = def in body@ is encoded as applying a lambda
-- over @var@ (with body @body@) to @def@.
expLet :: Var -> Exp -> Exp -> Exp
expLet var def body = ExpApp fn def
  where fn = ExpLambda var body

--------------------------------------------------------------------------
-------------------------------------------------------------------

-- examples

-- | Example: 10! computed with a recursive function.
--
-- > letrec f = (\x -> if x then x * f(x - 1) else 1) in f 10
fact10 :: Exp
fact10 = ExpLetrec "f" "x"
      (ExpIf (ExpVar "x")
             (ExpMul (ExpVar "x")
                     (ExpApp (ExpVar "f") (ExpSub (ExpVar "x") (ExpK 1))))
             (ExpK 1))
      (ExpApp (ExpVar "f") (ExpK 10))


-- | @getRef s@ builds the expression @!s@, which dereferences the
-- reference held in variable @s@.
getRef :: Var -> Exp
getRef s = ExpDeref (ExpVar s)

-- | @setRef s v@ builds the expression @s := v@, which assigns @v@ to the
-- reference held in variable @s@.
setRef :: Var -> Exp -> Exp
setRef s v = ExpAssg (ExpVar s) v

{-
-- compute 100!
let n = ref 100 in
let r = ref 1 in
  while !n do (
    r := !r * !n;
    n := !n - 1 )
  else !r
-}

-- | Example: 100! computed imperatively. Two references, @n@ (counting
-- down from 100) and @r@ (the accumulator), are updated in a while loop,
-- and the loop yields @!r@ once @!n@ reaches 0.
fact100 :: Exp
fact100 =
  expLet "n" (ExpNewref (ExpK 100))
 (expLet "r" (ExpNewref (ExpK 1))
    (ExpWhile (getRef "n")
              (ExpSeq
                (setRef "r" (ExpMul (getRef "r") (getRef "n")))
                (setRef "n" (ExpSub (getRef "n") (ExpK 1))))
              (getRef "r")))


-- | Runs a whole program: evaluates @e@ with no bindings and an empty
-- memory. The top-level continuation discards the final memory and
-- returns the value.
finalResult :: Exp -> Value
finalResult e = evalExp e emptyEnv emptyMem (\_ v -> v)

-- | Entry point: prints the value computed by the 'fact100' example.
main :: IO ()
main = print $ finalResult fact100