diff --git a/tests/pipelines/hidream_image/test_pipeline_hidream.py b/tests/pipelines/hidream_image/test_pipeline_hidream.py index ada4a11d1608..1c5f30e8704f 100644 --- a/tests/pipelines/hidream_image/test_pipeline_hidream.py +++ b/tests/pipelines/hidream_image/test_pipeline_hidream.py @@ -146,11 +146,15 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) image = pipe(**inputs)[0] generated_image = image[0] - self.assertEqual(generated_image.shape, (128, 128, 3)) - expected_image = torch.randn(128, 128, 3).numpy() - max_diff = np.abs(generated_image - expected_image).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = np.array([0.4507, 0.5256, 0.4205, 0.5791, 0.4848, 0.4831, 0.4443, 0.5107, 0.6586, 0.3163, 0.7318, 0.5933, 0.6252, 0.5512, 0.5357, 0.5983]) + # fmt: on + + generated_slice = generated_image.flatten() + generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(np.allclose(generated_slice, expected_slice, atol=1e-3)) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-4)