Skip to content

Commit 588c02e

Browse files
committed
Consider collection query methods returning Set in AOT repositories.
We now correctly adapt results to Set in AOT-generated repsoitory query methods. Closes #4094
1 parent e1fe724 commit 588c02e

File tree

3 files changed

+67
-15
lines changed

3 files changed

+67
-15
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaCodeBlocks.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Arrays;
2424
import java.util.Collection;
2525
import java.util.List;
26+
import java.util.Set;
2627
import java.util.function.LongSupplier;
2728

2829
import org.jspecify.annotations.Nullable;
@@ -715,6 +716,9 @@ public CodeBlock build() {
715716
"return ($1T) $2T.getSharedInstance().convert($3T.of($4L), $5T.valueOf($3T.class), $5T.valueOf($1T.class))",
716717
methodReturn.toClass(), DefaultConversionService.class, Streamable.class,
717718
context.localVariable("resultList"), TypeDescriptor.class);
719+
} else if (isSet(methodReturn)) {
720+
builder.addStatement("return ($T) convertOne($L, false, $T.class)", List.class,
721+
context.localVariable("resultList"), methodReturn.toClass());
718722
} else {
719723
builder.addStatement("return ($T) $L", List.class, context.localVariable("resultList"));
720724
}
@@ -750,8 +754,18 @@ public CodeBlock build() {
750754
}
751755

752756
if (queryMethod.isCollectionQuery()) {
753-
builder.addStatement("return ($T) convertMany($L.getResultList(), $L, $L)", methodReturn.getTypeName(),
754-
queryVariableName, aotQuery.isNative(), convertTo);
757+
758+
if (isStreamable(methodReturn)) {
759+
builder.addStatement("return ($1T) $1T.of(($2T) convertMany($3L.getResultList(), $4L, $5L))",
760+
Streamable.class, Iterable.class, queryVariableName, aotQuery.isNative(), convertTo);
761+
} else if (isSet(methodReturn)) {
762+
builder.addStatement("return ($T) convertOne(convertMany($L.getResultList(), $L, $L), false, $T.class)",
763+
methodReturn.getTypeName(), queryVariableName, aotQuery.isNative(), convertTo,
764+
methodReturn.toClass());
765+
} else {
766+
builder.addStatement("return ($T) convertMany($L.getResultList(), $L, $L)", methodReturn.getTypeName(),
767+
queryVariableName, aotQuery.isNative(), convertTo);
768+
}
755769
} else if (queryMethod.isStreamQuery()) {
756770
builder.addStatement("return ($T) convertMany($L.getResultStream(), $L, $L)", methodReturn.getTypeName(),
757771
queryVariableName, aotQuery.isNative(), convertTo);
@@ -791,6 +805,9 @@ public CodeBlock build() {
791805
"return ($1T) $2T.getSharedInstance().convert($3T.of($4L.getResultList()), $5T.valueOf($3T.class), $5T.valueOf($1T.class))",
792806
methodReturn.toClass(), DefaultConversionService.class, Streamable.class, queryVariableName,
793807
TypeDescriptor.class);
808+
} else if (isSet(methodReturn)) {
809+
builder.addStatement("return ($T) convertOne($L.getResultList(), false, $T.class)",
810+
methodReturn.getTypeName(), queryVariableName, methodReturn.toClass());
794811
} else {
795812
builder.addStatement("return ($T) $L.getResultList()", methodReturn.getTypeName(), queryVariableName);
796813
}
@@ -827,6 +844,10 @@ private boolean canConvert(Class<?> from, MethodReturn methodReturn) {
827844
return DefaultConversionService.getSharedInstance().canConvert(from, methodReturn.toClass());
828845
}
829846

847+
private static boolean isSet(MethodReturn methodReturn) {
848+
return Set.class.isAssignableFrom(methodReturn.toClass());
849+
}
850+
830851
private static boolean isStreamable(MethodReturn methodReturn) {
831852
return methodReturn.toClass().equals(Streamable.class);
832853
}

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/aot/JpaRepositoryContributorIntegrationTests.java

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.util.List;
2323
import java.util.Optional;
24+
import java.util.Set;
2425
import java.util.stream.Stream;
2526

2627
import org.hibernate.proxy.HibernateProxy;
@@ -105,12 +106,19 @@ void beforeEach() {
105106
}
106107

107108
@Test // GH-3830
108-
void testDerivedFinderWithoutArguments() {
109+
void testDerivedQueryWithoutArguments() {
109110

110111
List<User> users = fragment.findUserNoArgumentsBy();
111112
assertThat(users).hasSize(7).hasOnlyElementsOfType(User.class);
112113
}
113114

115+
@Test // GH-4094
116+
void testDerivedQueryAsSet() {
117+
118+
Set<User> users = fragment.findUserSetBy();
119+
assertThat(users).hasSize(7).hasOnlyElementsOfType(User.class);
120+
}
121+
114122
@Test // GH-3830
115123
void testFindDerivedQuerySingleEntity() {
116124

@@ -119,7 +127,7 @@ void testFindDerivedQuerySingleEntity() {
119127
}
120128

121129
@Test // GH-3830
122-
void testFindDerivedFinderOptionalEntity() {
130+
void testFindDerivedQueryOptionalEntity() {
123131

124132
Optional<User> user = fragment.findOptionalOneByEmailAddress("[email protected]");
125133
assertThat(user).isNotNull().containsInstanceOf(User.class)
@@ -141,7 +149,7 @@ void testDerivedExists() {
141149
}
142150

143151
@Test // GH-3830
144-
void testDerivedFinderReturningList() {
152+
void testDerivedQueryReturningList() {
145153

146154
List<User> users = fragment.findByLastnameStartingWith("S");
147155
assertThat(users).extracting(User::getEmailAddress).containsExactlyInAnyOrder("[email protected]", "[email protected]",
@@ -157,44 +165,44 @@ void shouldReturnStream() {
157165
}
158166

159167
@Test // GH-3830
160-
void testLimitedDerivedFinder() {
168+
void testLimitedDerivedQuery() {
161169

162170
List<User> users = fragment.findTop2ByLastnameStartingWith("S");
163171
assertThat(users).hasSize(2);
164172
}
165173

166174
@Test // GH-3830
167-
void testSortedDerivedFinder() {
175+
void testSortedDerivedQuery() {
168176

169177
List<User> users = fragment.findByLastnameStartingWithOrderByEmailAddress("S");
170178
assertThat(users).extracting(User::getEmailAddress).containsExactly("[email protected]", "[email protected]",
171179
172180
}
173181

174182
@Test // GH-3830
175-
void testDerivedFinderWithLimitArgument() {
183+
void testDerivedQueryWithLimitArgument() {
176184

177185
List<User> users = fragment.findByLastnameStartingWith("S", Limit.of(2));
178186
assertThat(users).hasSize(2);
179187
}
180188

181189
@Test // GH-3830
182-
void testDerivedFinderWithSort() {
190+
void testDerivedQueryWithSort() {
183191

184192
List<User> users = fragment.findByLastnameStartingWith("S", Sort.by("emailAddress"));
185193
assertThat(users).extracting(User::getEmailAddress).containsExactly("[email protected]", "[email protected]",
186194
187195
}
188196

189197
@Test // GH-3830
190-
void testDerivedFinderWithSortAndLimit() {
198+
void testDerivedQueryWithSortAndLimit() {
191199

192200
List<User> users = fragment.findByLastnameStartingWith("S", Sort.by("emailAddress"), Limit.of(2));
193201
assertThat(users).extracting(User::getEmailAddress).containsExactly("[email protected]", "[email protected]");
194202
}
195203

196204
@Test // GH-3830
197-
void testDerivedFinderReturningListWithPageable() {
205+
void testDerivedQueryReturningListWithPageable() {
198206

199207
List<User> users = fragment.findByLastnameStartingWith("S", PageRequest.of(0, 2, Sort.by("emailAddress")));
200208
assertThat(users).extracting(User::getEmailAddress).containsExactly("[email protected]", "[email protected]");
@@ -208,7 +216,7 @@ void testDerivedQueryMethodReturningStreamable() {
208216
}
209217

210218
@Test // GH-3830
211-
void testDerivedFinderReturningPage() {
219+
void testDerivedQueryReturningPage() {
212220

213221
Page<User> page = fragment.findPageOfUsersByLastnameStartingWith("S",
214222
PageRequest.of(0, 2, Sort.by("emailAddress")));
@@ -220,7 +228,7 @@ void testDerivedFinderReturningPage() {
220228
}
221229

222230
@Test // GH-3830
223-
void testDerivedFinderReturningSlice() {
231+
void testDerivedQueryReturningSlice() {
224232

225233
Slice<User> slice = fragment.findSliceOfUserByLastnameStartingWith("S",
226234
PageRequest.of(0, 2, Sort.by("emailAddress")));
@@ -374,15 +382,31 @@ void shouldEvaluateExpressionByPosition() {
374382
}
375383

376384
@Test // GH-3830
377-
void testDerivedFinderReturningListOfProjections() {
385+
void testDerivedQueryReturningListOfProjections() {
378386

379387
List<UserDtoProjection> users = fragment.findUserProjectionByLastnameStartingWith("S");
380388
assertThat(users).extracting(UserDtoProjection::getEmailAddress).containsExactlyInAnyOrder("[email protected]",
381389
382390
}
383391

392+
@Test // GH-4094
393+
void testDerivedQueryReturningSetOfProjections() {
394+
395+
Set<UserDtoProjection> users = fragment.findUserProjectionSetByLastnameStartingWith("S");
396+
assertThat(users).extracting(UserDtoProjection::getEmailAddress).containsExactlyInAnyOrder("[email protected]",
397+
398+
}
399+
400+
@Test // GH-4094
401+
void testDerivedQueryReturningStreamableOfProjections() {
402+
403+
Streamable<UserDtoProjection> users = fragment.findUserProjectionStreamableByLastnameStartingWith("S");
404+
assertThat(users).extracting(UserDtoProjection::getEmailAddress).containsExactlyInAnyOrder("[email protected]",
405+
406+
}
407+
384408
@Test // GH-3830
385-
void testDerivedFinderReturningPageOfProjections() {
409+
void testDerivedQueryReturningPageOfProjections() {
386410

387411
Page<UserDtoProjection> page = fragment.findUserProjectionByLastnameStartingWith("S",
388412
PageRequest.of(0, 2, Sort.by("emailAddress")));

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/aot/UserRepository.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.util.List;
2121
import java.util.Optional;
22+
import java.util.Set;
2223
import java.util.stream.Stream;
2324

2425
import org.springframework.data.domain.Limit;
@@ -46,6 +47,8 @@ interface UserRepository extends CrudRepository<User, Integer> {
4647

4748
List<User> findUserNoArgumentsBy();
4849

50+
Set<User> findUserSetBy();
51+
4952
User findOneByEmailAddress(String emailAddress);
5053

5154
Optional<User> findOptionalOneByEmailAddress(String emailAddress);
@@ -149,6 +152,10 @@ List<User> findWithParameterNameByLastnameStartingWithOrLastnameEndingWith(@Para
149152

150153
List<UserDtoProjection> findUserProjectionByLastnameStartingWith(String lastname);
151154

155+
Set<UserDtoProjection> findUserProjectionSetByLastnameStartingWith(String lastname);
156+
157+
Streamable<UserDtoProjection> findUserProjectionStreamableByLastnameStartingWith(String lastname);
158+
152159
Page<UserDtoProjection> findUserProjectionByLastnameStartingWith(String lastname, Pageable page);
153160

154161
Names findDtoByEmailAddress(String emailAddress);

0 commit comments

Comments
 (0)