diff --git a/hls-graph/src/Development/IDE/Graph/Database.hs b/hls-graph/src/Development/IDE/Graph/Database.hs index 0c06a1766c..961a10b98c 100644 --- a/hls-graph/src/Development/IDE/Graph/Database.hs +++ b/hls-graph/src/Development/IDE/Graph/Database.hs @@ -24,8 +24,7 @@ import Control.Concurrent.Extra (Barrier, newBarrier, waitBarrierMaybe) import Control.Concurrent.STM.Stats (atomically, atomicallyNamed, - readTVar, readTVarIO, - writeTVar) + readTVarIO) import Control.Exception (SomeException, try) import Control.Monad (join, unless, void) import Control.Monad.IO.Class (liftIO) diff --git a/hls-graph/src/Development/IDE/Graph/Internal/Database.hs b/hls-graph/src/Development/IDE/Graph/Internal/Database.hs index 56482fedb1..950289459a 100644 --- a/hls-graph/src/Development/IDE/Graph/Internal/Database.hs +++ b/hls-graph/src/Development/IDE/Graph/Internal/Database.hs @@ -17,7 +17,8 @@ import Control.Concurrent.STM.Stats (STM, atomicallyNamed, modifyTVar', newTQueueIO, newTVarIO, readTVar, - readTVarIO, retry) + readTVarIO, retry, writeTVar) +import Control.Concurrent.Async (mapConcurrently) import Control.Exception import Control.Monad import Control.Monad.IO.Class (MonadIO (liftIO)) @@ -480,32 +481,55 @@ transitiveDirtyListBottomUp database seeds = do void $ State.runStateT (traverse_ go seeds) mempty readIORef acc --- the lefts are keys that are no longer affected, we can try to mark them clean --- the rights are new affected keys, we need to mark them dirty +-- | A concurrent variant of 'transitiveDirtyListBottomUp' that computes the difference +-- between two sets of affected keys. +-- +-- Returns: +-- * Right keys: newly affected keys that need to be marked dirty +-- * Left keys: previously affected keys that are no longer affected (can be marked clean) +-- +-- The function traverses the reverse-dependency graph concurrently, processing independent +-- branches in parallel while maintaining bottom-up ordering (dependencies before dependents). +-- This improves performance on large dependency graphs by utilizing multiple cores. +-- +-- Thread-safety is ensured by: +-- * Using TVar for shared state (visited set and accumulator) +-- * Atomic check-and-mark for the visited set +-- * mapConcurrently for parallel traversal of independent branches transitiveDirtyListBottomUpDiff :: Database -> [Key] -> [Key] -> IO [Either Key Key] transitiveDirtyListBottomUpDiff database seeds lastSeeds = do - acc <- newIORef [] - let go1 x = do - seen <- State.get - if x `memberKeySet` seen - then pure () - else do - State.put (insertKeySet x seen) - mnext <- lift $ atomically $ getRunTimeRDeps database x - traverse_ go1 (maybe mempty toListKeySet mnext) - lift $ modifyIORef' acc (Right x :) - let go2 x = do - seen <- State.get - if x `memberKeySet` seen - then pure () - else do - State.put (insertKeySet x seen) - mnext <- lift $ atomically $ getRunTimeRDeps database x - traverse_ go2 (maybe mempty toListKeySet mnext) - lift $ modifyIORef' acc (Left x :) - -- traverse all seeds - void $ State.runStateT (do traverse_ go1 seeds; traverse_ go2 lastSeeds) mempty - readIORef acc + -- Use TVars for thread-safe concurrent access + accTVar <- newTVarIO [] + seenTVar <- newTVarIO mempty + + let -- Process a key and its dependencies concurrently + go :: (Key -> Either Key Key) -> Key -> IO () + go wrapper x = do + alreadySeen <- atomically $ do + seen <- readTVar seenTVar + if x `memberKeySet` seen + then pure True + else do + writeTVar seenTVar (insertKeySet x seen) + pure False + + unless alreadySeen $ do + -- Fetch dependencies + mnext <- atomically $ getRunTimeRDeps database x + let deps = maybe [] toListKeySet mnext + + -- Process dependencies concurrently + unless (null deps) $ do + void $ mapConcurrently (go wrapper) deps + + -- Add this key to accumulator after all dependencies are processed + atomically $ modifyTVar' accTVar (wrapper x :) + + -- Process new seeds (Right) and old seeds (Left) concurrently + void $ mapConcurrently (go Right) seeds + void $ mapConcurrently (go Left) lastSeeds + + readTVarIO accTVar -- | Original spawnRefresh using the general pattern