diff --git a/tests/pipelines/cosmos/test_cosmos.py b/tests/pipelines/cosmos/test_cosmos.py index 0c1024a9a9f8..4d3202f78508 100644 --- a/tests/pipelines/cosmos/test_cosmos.py +++ b/tests/pipelines/cosmos/test_cosmos.py @@ -153,11 +153,15 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.0, 0.9686, 0.8549, 0.8078, 0.0, 0.8431, 1.0, 0.4863, 0.7098, 0.1098, 0.8157, 0.4235, 0.6353, 0.2549, 0.5137, 0.5333]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_callback_inputs(self): sig = inspect.signature(self.pipeline_class.__call__) diff --git a/tests/pipelines/cosmos/test_cosmos2_text2image.py b/tests/pipelines/cosmos/test_cosmos2_text2image.py index 386bf161a095..cc2fcec64175 100644 --- a/tests/pipelines/cosmos/test_cosmos2_text2image.py +++ b/tests/pipelines/cosmos/test_cosmos2_text2image.py @@ -140,11 +140,15 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images generated_image = image[0] - self.assertEqual(generated_image.shape, (3, 32, 32)) - expected_video = torch.randn(3, 32, 32) - max_diff = np.abs(generated_image - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.4784, 0.4784, 0.4784, 0.4784, 0.4784, 0.4902, 0.4588, 0.5333]) + # fmt: on + + generated_slice = generated_image.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_callback_inputs(self): sig = inspect.signature(self.pipeline_class.__call__) diff --git a/tests/pipelines/cosmos/test_cosmos2_video2world.py b/tests/pipelines/cosmos/test_cosmos2_video2world.py index 421e3a1ad343..b23c8aed1734 100644 --- a/tests/pipelines/cosmos/test_cosmos2_video2world.py +++ b/tests/pipelines/cosmos/test_cosmos2_video2world.py @@ -147,11 +147,15 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.5098, 0.5137, 0.5176, 0.5098, 0.5255, 0.5412, 0.5098, 0.5059]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_components_function(self): init_components = self.get_dummy_components() diff --git a/tests/pipelines/cosmos/test_cosmos_video2world.py b/tests/pipelines/cosmos/test_cosmos_video2world.py index 2b893e99700e..d0dba5575bb7 100644 --- a/tests/pipelines/cosmos/test_cosmos_video2world.py +++ b/tests/pipelines/cosmos/test_cosmos_video2world.py @@ -159,11 +159,15 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.0, 0.8275, 0.7529, 0.7294, 0.0, 0.6, 1.0, 0.3804, 0.6667, 0.0863, 0.8784, 0.5922, 0.6627, 0.2784, 0.5725, 0.7765]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_components_function(self): init_components = self.get_dummy_components()