@@ -1025,136 +1025,79 @@ def test_multi_images(self):
1025
1025
class TruncateWithProtectedTokensTester(TrlTestCase):
    """Tests for `truncate_with_protected_tokens`.

    The helper truncates a token-id list down to a target length while
    guaranteeing that every id listed in `protected_tokens` survives the
    truncation; non-protected tokens are dropped from the left (the tail
    is kept), and the original relative order is preserved.
    """

    def test_basic_example(self):
        """Test the basic example from the problem description."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = [2, 3]
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        # Protected ids 2 and 3 are kept; the remaining slot goes to the
        # last non-protected token (5).
        expected_ids = [2, 3, 5]
        self.assertEqual(new_ids, expected_ids)

    def test_no_truncation_needed(self):
        """Test when target length equals current length."""
        prompt_ids = [1, 2, 3]
        protected_tokens = [2]
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        # Nothing to drop, so the input comes back unchanged.
        self.assertEqual(new_ids, prompt_ids)

    def test_no_protected_tokens(self):
        """Test truncation with no protected tokens (normal right truncation)."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = []
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        expected_ids = [3, 4, 5]  # Last 3 tokens
        self.assertEqual(new_ids, expected_ids)

    def test_all_tokens_protected(self):
        """Test when all remaining tokens are protected."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = [3, 4, 5]
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        # The protected ids exactly fill the target length.
        expected_ids = [3, 4, 5]
        self.assertEqual(new_ids, expected_ids)

    def test_too_many_protected_tokens(self):
        """Test error when too many protected tokens for target length."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = [1, 2, 3, 4]
        target_length = 3

        # Four protected ids cannot fit in a length-3 result.
        with self.assertRaises(ValueError):
            truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

    def test_single_batch_single_token(self):
        """Test edge case with single batch and single token."""
        prompt_ids = [5]
        protected_tokens = [5]
        target_length = 1

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        self.assertEqual(new_ids, prompt_ids)

    def test_order_preservation(self):
        """Test that relative order is preserved."""
        prompt_ids = [10, 2, 20, 3, 30, 40]
        protected_tokens = [2, 3]
        target_length = 4

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
        # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
        expected_ids = [2, 3, 30, 40]

        self.assertEqual(new_ids, expected_ids)
class UnsplitPixelValuesByGridTester (TrlTestCase ):
0 commit comments