{-# LANGUAGE RecordWildCards #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Internal
  ( -- * Create new constraints

    TcPlugin.newWanted
  , newGiven
    -- * Creating evidence

  , evByFiat
    -- * Lookup

  , lookupModule
  , lookupName
    -- * Trace state of the plugin

  , tracePlugin
    -- * Substitutions

  , flattenGivens
  , mkSubst
  , mkSubst'
  , substType
  , substCt
  )
where

import GHC.Driver.Config.Finder (initFinderOpts)
import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
  (newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginSolveResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)

import GhcApi.Constraint (Ct(..))
import GhcApi.GhcPlugins

import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)

-- | Find a module.
--
-- The lookup is first attempted with 'findPluginModule', using the finder
-- cache, finder options, unit state and (optional) home unit of the current
-- 'HscEnv'. If that does not yield 'Found', the module is looked up again with
-- 'TcPlugin.findImportedModule', qualified with the home unit's 'UnitId' when
-- a home unit exists and unqualified otherwise. If neither lookup succeeds this
-- panics with @"Couldn't find module"@.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package containing the module.
                           -- NOTE: This value is ignored on ghc>=8.0.
             -> TcPluginM Module
lookupModule modName _pkg = do
  env <- TcPlugin.getTopEnv
  let homeUnit   = hsc_home_unit_maybe env
      finderOpts = initFinderOpts (hsc_dflags env)
  pluginResult <- TcPlugin.tcPluginIO $
    findPluginModule (hsc_FC env) finderOpts (hsc_units env) homeUnit modName
  case pluginResult of
    Found _ m -> pure m
    _ -> do
      -- Fallback: ordinary import lookup, scoped to the home unit if any.
      importResult <- TcPlugin.findImportedModule modName (qualifierFor homeUnit)
      case importResult of
        Found _ m -> pure m
        _         -> panicDoc "Couldn't find module" (ppr modName)
  where
    -- Package qualifier used by the fallback lookup.
    qualifierFor mHome = maybe NoPkgQual (ThisPkg . homeUnitId) mHome

-- | Find a 'Name' in a 'Module' given an 'OccName'.
--
-- A thin wrapper around 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName md occ = lookupOrig md occ

-- | Print out extra information about the initialisation, stop, and every
-- solver run of the plugin when @-ddump-tc-trace@ is enabled.
--
-- Every trace heading is suffixed with the given string. Before each solver
-- call the given and wanted constraints are printed; afterwards the solved,
-- new and/or insoluble constraints of the result are printed. The plugin's
-- rewriters ('tcPluginRewrite') are passed through unchanged and not traced.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin { tcPluginInit    = initP
                       , tcPluginSolve   = solveP
                       , tcPluginRewrite = rewriteP
                       , tcPluginStop    = stopP
                       } =
  TcPlugin { tcPluginInit    = tracedInit
           , tcPluginSolve   = tracedSolve
           , tcPluginRewrite = rewriteP
           , tcPluginStop    = tracedStop
           }
  where
    -- Emit a trace whose heading is the label followed by the tag string.
    report label = tcPluginTrace (label ++ s)

    -- A labelled, pretty-printed value on one line.
    field name xs = text name <+> ppr xs

    tracedInit = do
      report "tcPluginInit " empty
      initP

    tracedStop st = do
      report "tcPluginStop " empty
      stopP st

    tracedSolve st ev given wanted = do
      report "tcPluginSolve start "
        (field "given   =" given $$ field "wanted  =" wanted)
      res <- solveP st ev given wanted
      -- Order matters: the first two alternatives are pattern synonyms over
      -- the underlying 'TcPluginSolveResult' record.
      case res of
        TcPluginOk solved new ->
          report "tcPluginSolve ok "
            (field "solved =" solved $$ field "new    =" new)
        TcPluginContradiction bad ->
          report "tcPluginSolve contradiction "
            (field "bad =" bad)
        TcPluginSolveResult bad solved new ->
          report "tcPluginSolveResult "
            (field "solved =" solved
              $$ field "bad    =" bad
              $$ field "new    =" new)
      pure res

-- | Flattens evidence of constraints by substituting each other's equalities.
--
-- The equalities produced by 'mkSubst'' are sorted and grouped by the type
-- variable they bind. Variables bound by exactly one equality form a
-- substitution that is applied (via 'substCt') to the evidence of every given.
-- Each group binding the same variable two or more times is passed to
-- 'flatToCt', and the constraints it returns are prepended to the result.
--
-- __NB:__ Should only be used on /[G]iven/ constraints!
--
-- __NB:__ Doesn't flatten under binders
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt multiBound ++ map (substCt singleSubst) givens
  where
    -- Equalities grouped so that every group shares a single type variable.
    byTyVar = groupBy ((==) `on` boundVar) (sortOn boundVar (mkSubst' givens))

    boundVar ((tv, _), _) = tv

    (multiBound, singleBound) = partition boundTwiceOrMore byTyVar

    boundTwiceOrMore (_:_:_) = True
    boundTwiceOrMore _       = False

    -- The substitution formed by variables bound exactly once.
    singleSubst = [ binding | grp <- singleBound, (binding, _) <- grp ]

-- | Create flattened substitutions from type equalities, i.e. the substitutions
-- have been applied to each other's right-hand sides.
--
-- Equalities are extracted with 'mkSubst' and folded from the right. Each
-- binding's right-hand side is rewritten with the (already processed)
-- bindings that follow it. Its original right-hand side is then substituted
-- into the right-hand sides of those following bindings.
--
-- NOTE(review): the right-hand side pushed into later bindings is the
-- original, not the already-substituted one, so chains of three or more
-- equalities may not be fully flattened — confirm this is intended.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' cts = foldr pushBinding [] (mapMaybe mkSubst cts)
  where
    pushBinding ((tv, rhs), ct) later =
        ((tv, substType (map fst later) rhs), ct)
      : [ ((tv', substType [(tv, rhs)] rhs'), ct')
        | ((tv', rhs'), ct') <- later
        ]

-- | Apply a substitution to the evidence of a 'Ct'.
--
-- The substitution is applied with 'substType' to the predicate type of the
-- constraint's evidence, via 'overEvidencePredType'.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt subst ct = overEvidencePredType applySubst ct
  where
    applySubst = substType subst