Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Util (log2)

import Control.Monad.Freer (req)
import Data.Bifunctor
import Data.Foldable (foldrM)
import Data.Type.Equality (TestEquality(..), (:~:)(..))
import qualified Data.Map as M
import Prelude hiding (last)
Expand Down Expand Up @@ -247,22 +248,19 @@ getThunks :: Modey m
,Overs m UVerb
)
getThunks _ [] = pure ([], [], [])
getThunks Braty row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Braty (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
-- These shouldn't happen
(_, VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row)
v -> typeErr $ "Force called on non-thunk: " ++ show v
getThunks Kerny row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Kerny (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
(_, VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row)
v -> typeErr $ "Force called on non-(kernel)-thunk: " ++ show v
getThunks Braty ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Braty (src, ty)
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Kerny ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Kerny (src,ty)
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Braty ((src, Left (Star args)):rest) = do
(node, unders, overs) <- case bwdStack (B0 <>< args) of
Some (_ :* stk) -> do
Expand All @@ -274,15 +272,15 @@ getThunks Braty ((src, Left (Star args)):rest) = do
getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro)

-- The type given here should be normalised
vecLayers :: Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,Some (Modey :* Flip CTy Z) -- The function type at the end
)
vecLayers (TVec ty (VNum n)) = do
vecLayers :: Modey m -> Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,CTy m Z -- The function type at the end
)
vecLayers my (TVec ty (VNum n)) = do
src <- mkStaticNum n
(layers, fun) <- vecLayers ty
pure ((src, n):layers, fun)
vecLayers (VFun my cty) = pure ([], Some (my :* Flip cty))
vecLayers ty = typeErr $ "Expected a function or vector of functions, got " ++ show ty
first ((src, n):) <$> vecLayers my ty
vecLayers Braty (VFun Braty cty) = pure ([], cty)
vecLayers Kerny (VFun Kerny cty) = pure ([], cty)
vecLayers my ty = typeErr $ "Expected a " ++ showMode my ++ "function or vector of functions, got " ++ show ty

mkStaticNum :: NumVal (VVar Z) -> Checking Src
mkStaticNum n@(NumValue c gro) = do
Expand Down Expand Up @@ -330,27 +328,29 @@ mkStaticNum n@(NumValue c gro) = do
wire (oneSrc, TNat, rhs)
pure src

vectorise :: (Src, Val Z) -> Checking (Src, Val Z)
vectorise (src, ty) = do
(layers, Some (my :* Flip cty)) <- vecLayers ty
modily my $ mkMapFuns (src, VFun my cty) layers
vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, CTy m Z)
vectorise my (src, ty) = do
(layers, cty) <- vecLayers my ty
modily my $ foldrM mkMapFun (src, cty) layers
where
mkMapFuns :: (Src, Val Z) -- The input to the mapfun
-> [(Src, NumVal (VVar Z))] -- Remaining layers
-> Checking (Src, Val Z)
mkMapFuns over [] = pure over
mkMapFuns (valSrc, ty) ((lenSrc, len):layers) = do
(valSrc, ty@(VFun my cty)) <- mkMapFuns (valSrc, ty) layers
mkMapFun :: (Src, NumVal (VVar Z)) -- Layer to apply
-> (Src, CTy m Z) -- The input to this level of mapfun
-> Checking (Src, CTy m Z)
mkMapFun (lenSrc, len) (valSrc, cty) = do
let weak1 = changeVar (Thinning (ThDrop ThNull))
vecFun <- vectorisedFun len my cty
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right vecTy)], _) <-
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right (VFun my' cty))], _) <-
next "" MapFun (S0, Some (Zy :* S0))
(REx ("len", Nat) (RPr ("value", weak1 ty) R0))
(RPr ("vector", weak1 vecFun) R0)
defineTgt lenTgt (VNum len)
wire (lenSrc, kindType Nat, lenTgt)
wire (valSrc, ty, valTgt)
pure (vectorSrc, vecTy)
let vecCTy = case (my,my',cty) of
(Braty,Braty,cty) -> cty
(Kerny,Kerny,cty) -> cty
_ -> error "next returned wrong mode of computation type to that passed in"
pure (vectorSrc, vecCTy)

vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z)
vectorisedFun nv my (ss :->> ts) = do
Expand Down
4 changes: 1 addition & 3 deletions brat/test/golden/kernel/kernel_application.brat.golden
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@ Error in test/golden/kernel/kernel_application.brat on line 16:
rotate = { q => maybeRotate(true) }
^^^^^^^^^^^

Expected function to be a (kernel) thunk, but found:
(thunk :: { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) })

Type error: Expected a (kernel) function or vector of functions, got { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) }