------------------------------------------------------
-- An ML-style functional language with side effects (mutable references)
------------------------------------------------------

-- | Variables are identified purely by their names.
type Var = String


-- | Memory locations are plain numbers (addresses): a location is an index
-- into 'Mem', and 'newref' hands out the next free index.
type Location = Int


-- | Runtime values of the object language: integers, functions,
-- references to memory locations, and error values.
data Value = ValInt Integer          -- ^ integer (also serves as a boolean, see 'isTrue')
           | ValFunc (Value -> Cmpt) -- ^ function: maps its argument to a computation
           | ValRef Location         -- ^ reference to a memory cell
           | ValError String         -- ^ runtime error carrying a message

-- | Human-readable rendering of a value; functions have no printable
-- structure, so they are shown only as a placeholder.
instance Show Value where
  show v = case v of
    ValInt n     -> show n
    ValFunc _    -> "function"
    ValRef loc   -> "ref " ++ show loc
    ValError msg -> "error: " ++ msg



-- | An environment maps variables to values. It is represented as a
-- function; lookup is done with 'query' and extension with 'update'.
type Env = Var -> Value

-- | The environment with no bindings: looking up any variable yields an
-- "undef var" error value naming that variable.
emptyEnv :: Env
emptyEnv = ValError . ("undef var " ++)


-- | Memory is a list of values; the cell for location @i@ is the @i@-th
-- element, counting from 0 (see 'get' and 'set').
type Mem = [Value]

-- | Memory before any reference has been allocated: no cells at all.
emptyMem :: Mem
emptyMem = mempty

-- | @set mem loc v@ returns a copy of @mem@ in which the cell at index @loc@
-- holds @v@; every other cell is unchanged.
--
-- Locations are handed out by 'newref', so an index outside the memory
-- indicates an interpreter bug. Fail with a message naming the bad location
-- instead of an anonymous non-exhaustive-pattern crash.
set :: Mem -> Int -> Value -> Mem
set mem loc y =
  case splitAt loc mem of
    -- the guard rejects negative indices, for which splitAt returns ([], mem)
    (before, _ : after) | loc >= 0 -> before ++ y : after
    _ -> error ("set: invalid location " ++ show loc)

-- | @get mem loc@ reads the cell at index @loc@ of @mem@.
--
-- Locations are handed out by 'newref', so an index outside the memory
-- indicates an interpreter bug. Fail with a message naming the bad location
-- instead of an anonymous non-exhaustive-pattern crash.
get :: Mem -> Int -> Value
get mem loc =
  case drop loc mem of
    -- the guard rejects negative indices, for which drop returns all of mem
    v : _ | loc >= 0 -> v
    _ -> error ("get: invalid location " ++ show loc)

-- | Allocate a fresh cell holding the given value at the end of memory.
-- Returns the new cell's location (the old memory size) and the grown memory.
newref :: Mem -> Value -> (Int, Mem)
newref mem v = (fresh, grown)
  where
    fresh = length mem
    grown = mem ++ [v]


-- | Truthiness of a value: an integer is true exactly when it is non-zero.
--
-- Any other value (function, reference, error) is not true. The previous
-- definition only matched 'ValInt' and crashed the whole interpreter on
-- anything else.
isTrue :: Value -> Bool
isTrue (ValInt i) = i /= 0
isTrue _ = False


-- | A computation, i.e. the "side-effect" of an expression: given the
-- current memory, it produces a value together with the updated memory.
type Cmpt = Mem -> (Value, Mem)


-- | A computation that fails with the given message, leaving memory as-is.
cerror :: String -> Cmpt
cerror msg mem = (ValError msg, mem)


-- | Look a key up in a map represented as a function (used for environments).
query :: (a -> b) -> a -> b
query = ($)

-- | @update k v m@ is the function-represented map @m@ extended so that @k@
-- maps to @v@; any earlier binding for @k@ is overridden.
update :: Eq a => a -> b -> (a -> b) -> (a -> b)
update k v m x
  | x == k    = v
  | otherwise = m x



-- | Run two computations left to right, threading memory through both, then
-- combine their results with a binary operation that itself receives the
-- memory left by the second computation.
op2 :: (Value -> Value -> Cmpt) -> Cmpt -> Cmpt -> Cmpt
op2 combine runA runB m0 = combine a b m2
  where
    (a, m1) = runA m0
    (b, m2) = runB m1


-- | Lift an integer operator to computations. Both operands are evaluated
-- left to right; if either result is not an integer, an error value is
-- produced instead.
arith :: (Integer -> Integer -> Integer) -> Cmpt -> Cmpt -> Cmpt
arith op = op2 combine
  where
    combine v1 v2 m = case (v1, v2) of
      (ValInt i1, ValInt i2) -> (ValInt (i1 `op` i2), m)
      _                      -> cerror "binary operation over non-int value" m


--------------------------------------------------------------------
-- | Abstract Syntax Tree for Expressions.
--
-- Integers double as booleans (non-zero is true) in 'ExpIf' and 'ExpWhile'.
data Exp = ExpK Integer          -- ^ integer constant
         | ExpVar Var            -- ^ variable
         | ExpAdd Exp Exp        -- ^ e1 + e2
         | ExpSub Exp Exp        -- ^ e1 - e2
         | ExpMul Exp Exp        -- ^ e1 * e2
         | ExpDiv Exp Exp        -- ^ e1 / e2 (integer division via Haskell's div)
         | ExpAssg Exp Exp       -- ^ e1 := e2 (e1 must be a reference; yields the assigned value)
         | ExpDeref Exp          -- ^ !e (read the cell that e refers to)
         | ExpNewref Exp         -- ^ ref e (allocate a new cell initialised to e)
         | ExpIf Exp Exp Exp     -- ^ if e1 then e2 else e3
         | ExpApp Exp Exp        -- ^ e1 e2 (function application)
         | ExpLambda Var Exp     -- ^ \x -> e
         | ExpSeq Exp Exp        -- ^ e1; e2 (yields the value of e2)
         | ExpWhile Exp Exp      -- ^ while e1 do e2 (yields 0 when the loop ends)
             deriving Show


-- | Build a function value from a lambda. When called, the body @e@ is
-- evaluated in the defining environment @env@, extended with the parameter
-- @v@ bound to the argument.
closure :: Var -> Exp -> Env -> Value
closure v e env = ValFunc (\arg -> evalExp e (update v arg env))


-- | Evaluate an expression in a given environment, producing a computation:
-- a function from the current memory to a value and the updated memory.
--
-- Sub-expressions are evaluated left to right. Runtime type errors and
-- division by zero do not abort the interpreter; they yield a 'ValError'.
evalExp :: Exp -> Env -> Cmpt

evalExp (ExpK i) _ m = (ValInt i, m)

evalExp (ExpVar var) env m = (query env var, m)

evalExp (ExpDeref e) env m =
  let (v, m') = evalExp e env m in
    case v of
      ValRef loc -> (get m' loc, m')
      _ -> cerror "unref a non-reference value" m'


-- The value of an assignment is the value that was assigned.
evalExp (ExpAssg e1 e2) env m = op2 assign (evalExp e1 env) (evalExp e2 env) m
  where assign (ValRef loc) v m' = (v, set m' loc v)
        assign _ _ m' = cerror "assignment to a non-reference value" m'


evalExp (ExpNewref e) env m =
  let (v, m') = evalExp e env m in
    let (ref, m'') = newref m' v in
      (ValRef ref, m'')


evalExp (ExpAdd e1 e2) env m = arith (+) (evalExp e1 env) (evalExp e2 env) m
evalExp (ExpSub e1 e2) env m = arith (-) (evalExp e1 env) (evalExp e2 env) m
evalExp (ExpMul e1 e2) env m = arith (*) (evalExp e1 env) (evalExp e2 env) m

-- Division by zero yields an error value. Using plain `arith div` here would
-- raise a Haskell exception and crash the whole interpreter.
evalExp (ExpDiv e1 e2) env m = op2 divide (evalExp e1 env) (evalExp e2 env) m
  where divide (ValInt _) (ValInt 0) m' = cerror "division by zero" m'
        divide (ValInt i1) (ValInt i2) m' = (ValInt (i1 `div` i2), m')
        divide _ _ m' = cerror "binary operation over non-int value" m'


-- A condition must be an integer (non-zero is true). An error value in the
-- condition is propagated unchanged; any other non-integer is reported.
evalExp (ExpIf e1 e2 e3) env m =
  case evalExp e1 env m of
    (v1@(ValInt _), m1)
      | isTrue v1 -> evalExp e2 env m1
      | otherwise -> evalExp e3 env m1
    (err@(ValError _), m1) -> (err, m1)
    (_, m1) -> cerror "condition is not an integer" m1


evalExp (ExpApp e1 e2) env m = op2 app (evalExp e1 env) (evalExp e2 env) m
  where app (ValFunc f) vp = f vp
        app _ _ = cerror "attempt to call a non-function value"


evalExp (ExpLambda v e) env m = (closure v e env, m)


-- The first expression runs only for its effect on memory.
evalExp (ExpSeq e1 e2) env m = op2 keepSecond (evalExp e1 env) (evalExp e2 env) m
  where keepSecond _ v2 m' = (v2, m')

-- A loop evaluates to 0 once its condition becomes false, and the body's
-- value is discarded. Non-integer conditions are handled as in ExpIf.
evalExp (ExpWhile e1 e2) env m = loop m
  where
    loop mem =
      case evalExp e1 env mem of
        (v1@(ValInt _), m1)
          | isTrue v1 -> let (_, m2) = evalExp e2 env m1 in loop m2
          | otherwise -> (ValInt 0, m1)
        (err@(ValError _), m1) -> (err, m1)
        (_, m1) -> cerror "condition is not an integer" m1


-- | Syntax sugar for @let var = def in body@. It is encoded as a
-- one-parameter lambda (binding @var@ inside @body@) applied to @def@, so the
-- core language needs no dedicated let construct.
expLet :: Var -> Exp -> Exp -> Exp
expLet var def body = ExpLambda var body `ExpApp` def



--------------------------------------------------------------------------
-------------------------------------------------------------------

-- examples


-- | @getRef s@ builds the expression @!s@: dereference the variable @s@.
getRef :: Var -> Exp
getRef s = ExpDeref (ExpVar s)

-- | @setRef s v@ builds the expression @s := v@: store @v@ in the cell that
-- the variable @s@ refers to.
setRef :: Var -> Exp -> Exp
setRef s v = ExpAssg (ExpVar s) v

{-
-- compute 100!
let n = ref 100 in
let r = ref 1 in
  while !n do (
    r := !r * !n;
    n := !n - 1 );
  !r
-}

-- | The program above as an 'Exp'. It computes 100! with a while loop over
-- two references: @n@ counts down from 100 and @r@ accumulates the product.
fact100 :: Exp
fact100 =
  expLet "n" (ExpNewref (ExpK 100))
 (expLet "r" (ExpNewref (ExpK 1))
    (ExpSeq
      (ExpWhile (getRef "n")
                (ExpSeq
                  (setRef "r" (ExpMul (getRef "r") (getRef "n")))
                  (setRef "n" (ExpSub (getRef "n") (ExpK 1)))))
      (getRef "r")))


-- | Run an expression from scratch (empty environment and empty memory) and
-- return its value, discarding the final memory.
finalResult :: Exp -> Value
finalResult e = fst (evalExp e emptyEnv emptyMem)


-- | Entry point: evaluate the 'fact100' example and print its final value.
main :: IO ()
main = print $ finalResult fact100