Skip to content

Commit f9fbd91

Browse files
authored
[CI] fix CI failure of transformer dev (#3457)
1 parent 54d4f6b commit f9fbd91

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/test_dpo_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,13 +1344,15 @@ def test_vdpo_trainer(self, model_id):
13441344
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
13451345
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
13461346
] and (
1347-
n.startswith("vision_tower.vision_model.encoder.layers.1")
1348-
or n == "vision_tower.vision_model.post_layernorm.weight"
1347+
"vision_tower.vision_model.encoder.layers.1" in n
1348+
or "vision_tower.vision_model.post_layernorm.weight" in n
13491349
):
13501350
# For some reason, these params are not updated. This is probably not related to TRL, but to
13511351
# the model itself. We should investigate this further, but for now we just skip these params.
13521352
continue
1353-
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
1353+
self.assertFalse(
1354+
torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
1355+
)
13541356

13551357

13561358
if __name__ == "__main__":

0 commit comments

Comments
 (0)