Skip to content

Commit a12451e

Browse files
committed
Add torch.manual_seed for test_fault_tolerance
1 parent f151e23 commit a12451e

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/test_allreduce_fault_tolerance.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from enum import Enum, auto
45

56
import pytest
7+
import torch
68

79
import hivemind
8-
from hivemind.averaging.averager import *
10+
from hivemind.averaging.averager import AllReduceRunner, AveragingMode, GatheredData
911
from hivemind.averaging.group_info import GroupInfo
1012
from hivemind.averaging.load_balancing import load_balance_peers
1113
from hivemind.averaging.matchmaking import MatchmakingException
1214
from hivemind.proto import averaging_pb2
13-
from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
15+
from hivemind.utils.asyncio import AsyncIterator, aenumerate, as_aiter, azip, enter_asynchronously
1416
from hivemind.utils.logging import get_logger
1517

1618
logger = get_logger(__name__)
@@ -138,6 +140,8 @@ async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[avera
138140
],
139141
)
140142
def test_fault_tolerance(fault0: Fault, fault1: Fault):
143+
torch.manual_seed(0)
144+
141145
def _make_tensors():
142146
return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
143147

0 commit comments

Comments
 (0)