module Drasil.GOOL.Helpers (angles, doubleQuotedText, hicat, vicat, vibcat, 
  vmap, vimap, emptyIfEmpty, emptyIfNull, toCode, toState, onCodeValue, 
  onStateValue, on2CodeValues, on2StateValues, on3CodeValues, on3StateValues, 
  onCodeList, onStateList, on2StateLists, getInnerType, on2StateWrapped,
  getNestDegree
) where

import Utils.Drasil (blank)

import qualified Drasil.GOOL.CodeType as C (CodeType(..))

import Prelude hiding ((<>))
import Control.Applicative (liftA3)
import Control.Monad (liftM2, liftM3)
import Control.Monad.State (State)
import Data.List (intersperse)
import Text.PrettyPrint.HughesPJ (Doc, vcat, hcat, text, char, doubleQuotes, 
  (<>), empty, isEmpty)

-- | Encloses a document in angle brackets, producing @\<d\>@.
angles :: Doc -> Doc
angles d = char '<' <> d <> char '>'

-- | Turns a string into a document wrapped in double quotes.
doubleQuotedText :: String -> Doc
doubleQuotedText = doubleQuotes . text

-- | Concatenates documents horizontally and places the separator @c@
-- between each adjacent pair.
--
-- Unlike 'vicat', empty documents are NOT filtered out first, so an empty
-- element still gets separators on both sides.
hicat :: Doc -> [Doc] -> Doc
hicat c l = hcat (intersperse c l)

-- | Stacks the non-empty documents vertically and places the separator @c@
-- between each adjacent pair.
--
-- Empty documents are removed before the separators are inserted, so they
-- never produce a dangling or doubled separator.
vicat :: Doc -> [Doc] -> Doc
vicat c = vcat . intersperse c . filter (not . isEmpty)

-- | 'vicat' with 'blank' (imported from "Utils.Drasil") as the separator.
-- Empty documents are dropped before separating.
-- NOTE(review): 'blank' is presumably a blank-line document; confirm in
-- Utils.Drasil.
vibcat :: [Doc] -> Doc
vibcat = vicat blank

-- | Renders each element with @f@ and stacks the results vertically,
-- with no separator between them.
vmap :: (a -> Doc) -> [a] -> Doc
vmap f = vcat . map f

-- | Renders each element with @f@ and joins the results with 'vicat' using
-- separator @c@. Elements that render to an empty document are dropped.
vimap :: Doc -> (a -> Doc) -> [a] -> Doc
vimap c f = vicat c . map f

-- | Returns 'empty' when @ifDoc@ is empty, and otherwise returns @elseDoc@.
-- The contents of @ifDoc@ are only inspected, never included in the result.
emptyIfEmpty :: Doc -> Doc -> Doc
emptyIfEmpty ifDoc elseDoc = if isEmpty ifDoc then empty else elseDoc

-- | Returns 'empty' when the list is null, and otherwise returns @elseDoc@.
emptyIfNull :: [a] -> Doc -> Doc
emptyIfNull lst elseDoc = if null lst then empty else elseDoc

-- | Injects a plain value into the monad @r@. This is 'pure', which is
-- preferred over the legacy 'return'.
toCode :: (Monad r) => a -> r a
toCode = pure

-- | Injects a plain value into 'State'. The state is left unchanged.
toState :: a -> State s a
toState = pure

-- | Applies a pure function to the value inside any functor (a synonym for
-- 'fmap').
onCodeValue :: (Functor r) => (a -> b) -> r a -> r b
onCodeValue = fmap

-- | 'fmap' specialised to 'State'. Transforms the result and leaves the
-- state effect untouched.
onStateValue :: (a -> b) -> State s a -> State s b
onStateValue = fmap

-- | Combines two applicative values with a binary function, evaluating the
-- effects left to right.
--
-- This is written with '<$>' and '<*>' (both in the Prelude) rather than
-- 'liftA2'. This module imports only 'liftA3' from "Control.Applicative",
-- and the Prelude exports 'liftA2' only from base-4.18 onward, so this form
-- compiles on every supported base.
on2CodeValues :: (Applicative r) => (a -> b -> c) -> r a -> r b -> r c
on2CodeValues f x y = f <$> x <*> y

-- | Runs two 'State' computations in order (first, then second) and
-- combines their results with a pure binary function.
on2StateValues :: (a -> b -> c) -> State s a -> State s b -> State s c
on2StateValues = liftM2

-- | Combines three applicative values with a ternary function, evaluating
-- the effects left to right (a synonym for 'liftA3').
on3CodeValues :: (Applicative r) => (a -> b -> c -> d) -> r a -> r b
  -> r c -> r d
on3CodeValues = liftA3

-- | Runs three 'State' computations in order and combines their results
-- with a pure ternary function.
on3StateValues :: (a -> b -> c -> d) -> State s a -> State s b -> State s c
  -> State s d
on3StateValues = liftM3

-- | Runs every action in the list from left to right, collects their
-- results, and applies @f@ to the resulting list.
onCodeList :: Monad m => ([a] -> b) -> [m a] -> m b
onCodeList f as = f <$> sequence as

-- | Runs every 'State' computation in the list from left to right, collects
-- their results, and applies @f@ to the resulting list.
onStateList :: ([a] -> b) -> [State s a] -> State s b
onStateList f as = f <$> sequence as

-- | Runs all computations in @as@, then all computations in @bs@, each list
-- from left to right. Applies @f@ to the two result lists.
on2StateLists :: ([a] -> [b] -> c) -> [State s a] -> [State s b] -> State s c
on2StateLists f as bs = liftM2 f (sequence as) (sequence bs)

-- | Runs @a'@ and then @b'@, and passes both results to the monadic function
-- @f@, whose action runs last.
--
-- Lifting @f@ with 'liftM2' would return a nested @m (m c)@. This instead
-- runs the inner action, so the result is a single @m c@.
on2StateWrapped :: (Monad m) => (a -> b -> m c) -> m a -> m b -> m c
on2StateWrapped f a' b' = do
  a <- a'
  b <- b'
  f a b

-- | Returns the element type of a 'C.List', 'C.Array' or 'C.Set' type.
--
-- This function is partial: for any other 'C.CodeType' it calls 'error'.
getInnerType :: C.CodeType -> C.CodeType
getInnerType (C.List innerT)  = innerT
getInnerType (C.Array innerT) = innerT
getInnerType (C.Set innerT)   = innerT
getInnerType _ = error "Attempt to extract inner type from a non-nested type"

-- | Adds to @n@ the number of 'C.List' layers that directly wrap the given
-- type. For example, @getNestDegree 0 (C.List (C.List t))@ is @2@ when @t@
-- is not itself a 'C.List'.
--
-- Only 'C.List' layers are counted. A 'C.Array' or 'C.Set' layer ends the
-- count, even though 'getInnerType' also treats those as nested types.
-- NOTE(review): confirm that the List-only behaviour is intended.
getNestDegree :: Integer -> C.CodeType -> Integer
getNestDegree n (C.List t) = getNestDegree (n + 1) t
getNestDegree n _          = n