2222import static io .lettuce .core .ClientOptions .DEFAULT_JSON_PARSER ;
2323import static io .lettuce .core .protocol .CommandType .*;
2424
25+ import java .nio .CharBuffer ;
2526import java .time .Duration ;
2627import java .util .ArrayList ;
2728import java .util .Collection ;
2829import java .util .List ;
30+ import java .util .concurrent .atomic .AtomicBoolean ;
31+ import java .util .concurrent .atomic .AtomicReference ;
32+ import java .util .concurrent .locks .ReentrantLock ;
2933import java .util .function .Consumer ;
3034import java .util .stream .Collectors ;
3135
3741import io .lettuce .core .cluster .api .sync .RedisClusterCommands ;
3842import io .lettuce .core .codec .RedisCodec ;
3943import io .lettuce .core .codec .StringCodec ;
44+ import io .lettuce .core .event .connection .ReauthenticateEvent ;
45+ import io .lettuce .core .event .connection .ReauthenticateFailedEvent ;
4046import io .lettuce .core .json .JsonParser ;
4147import io .lettuce .core .output .MultiOutput ;
4248import io .lettuce .core .output .StatusOutput ;
4349import io .lettuce .core .protocol .*;
50+ import io .netty .util .internal .logging .InternalLogger ;
51+ import io .netty .util .internal .logging .InternalLoggerFactory ;
4452import reactor .core .publisher .Mono ;
4553
4654/**
5563 */
5664public class StatefulRedisConnectionImpl <K , V > extends RedisChannelHandler <K , V > implements StatefulRedisConnection <K , V > {
5765
66+ private static final InternalLogger logger = InternalLoggerFactory .getInstance (StatefulRedisConnectionImpl .class );
67+
5868 protected final RedisCodec <K , V > codec ;
5969
6070 protected final RedisCommands <K , V > sync ;
@@ -71,6 +81,14 @@ public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V>
7181
7282 protected MultiOutput <K , V > multi ;
7383
84+ private RedisAuthenticationHandler authHandler ;
85+
86+ private AtomicReference <RedisCredentials > credentialsRef = new AtomicReference <>();
87+
88+ private final ReentrantLock reAuthSafety = new ReentrantLock ();
89+
90+ private AtomicBoolean inTransaction = new AtomicBoolean (false );
91+
7492 /**
7593 * Initialize a new connection.
7694 *
@@ -181,20 +199,38 @@ public boolean isMulti() {
181199 public <T > RedisCommand <K , V , T > dispatch (RedisCommand <K , V , T > command ) {
182200
183201 RedisCommand <K , V , T > toSend = preProcessCommand (command );
184- return super .dispatch (toSend );
202+ RedisCommand <K , V , T > result = super .dispatch (toSend );
203+ if (toSend .getType () == EXEC || toSend .getType () == DISCARD ) {
204+ inTransaction .set (false );
205+ setCredentials (credentialsRef .getAndSet (null ));
206+ }
207+
208+ return result ;
185209 }
186210
187211 @ Override
188212 public Collection <RedisCommand <K , V , ?>> dispatch (Collection <? extends RedisCommand <K , V , ?>> commands ) {
189213
190214 List <RedisCommand <K , V , ?>> sentCommands = new ArrayList <>(commands .size ());
191215
192- commands .forEach (o -> {
216+ boolean transactionComplete = false ;
217+ for (RedisCommand <K , V , ?> o : commands ) {
193218 RedisCommand <K , V , ?> command = preProcessCommand (o );
194219 sentCommands .add (command );
195- });
220+ if (command .getType () == EXEC ) {
221+ transactionComplete = true ;
222+ }
223+ if (command .getType () == MULTI || command .getType () == DISCARD ) {
224+ transactionComplete = false ;
225+ }
226+ }
196227
197- return super .dispatch (sentCommands );
228+ Collection <RedisCommand <K , V , ?>> result = super .dispatch (sentCommands );
229+ if (transactionComplete ) {
230+ inTransaction .set (false );
231+ setCredentials (credentialsRef .getAndSet (null ));
232+ }
233+ return result ;
198234 }
199235
200236 // TODO [tihomir.mateev] Refactor to include as part of the Command interface
@@ -273,12 +309,20 @@ protected <T> RedisCommand<K, V, T> preProcessCommand(RedisCommand<K, V, T> comm
273309
274310 if (commandType .equals (MULTI .name ())) {
275311
312+ reAuthSafety .lock ();
313+ try {
314+ inTransaction .set (true );
315+ } finally {
316+ reAuthSafety .unlock ();
317+ }
276318 multi = (multi == null ? new MultiOutput <>(codec ) : multi );
277319
278320 if (command instanceof CompleteableCommand ) {
279321 ((CompleteableCommand <?>) command ).onComplete ((ignored , e ) -> {
280322 if (e != null ) {
281323 multi = null ;
324+ inTransaction .set (false );
325+ setCredentials (credentialsRef .getAndSet (null ));
282326 }
283327 });
284328 }
@@ -318,11 +362,72 @@ public ConnectionState getConnectionState() {
318362 @ Override
319363 public void activated () {
320364 super .activated ();
365+ if (authHandler != null ) {
366+ authHandler .subscribe ();
367+ }
321368 }
322369
323370 @ Override
324371 public void deactivated () {
372+ if (authHandler != null ) {
373+ authHandler .unsubscribe ();
374+ }
325375 super .deactivated ();
326376 }
327377
378+ public void setAuthenticationHandler (RedisAuthenticationHandler handler ) {
379+ authHandler = handler ;
380+ }
381+
382+ public void setCredentials (RedisCredentials credentials ) {
383+ if (credentials == null ) {
384+ return ;
385+ }
386+ reAuthSafety .lock ();
387+ try {
388+ credentialsRef .set (credentials );
389+ if (!inTransaction .get ()) {
390+ dispatchAuthCommand (credentialsRef .getAndSet (null ));
391+ }
392+ } finally {
393+ reAuthSafety .unlock ();
394+ }
395+ }
396+
397+ private void dispatchAuthCommand (RedisCredentials credentials ) {
398+ if (credentials == null ) {
399+ return ;
400+ }
401+
402+ RedisFuture <String > auth ;
403+ if (credentials .getUsername () != null ) {
404+ auth = async ().auth (credentials .getUsername (), CharBuffer .wrap (credentials .getPassword ()));
405+ } else {
406+ auth = async ().auth (CharBuffer .wrap (credentials .getPassword ()));
407+ }
408+ auth .thenRun (() -> {
409+ publishReauthEvent ();
410+ logger .info ("Re-authentication succeeded {}." , getEpid ());
411+ }).exceptionally (throwable -> {
412+ publishReauthFailedEvent (throwable );
413+ logger .error ("Re-authentication failed {}." , getEpid (), throwable );
414+ return null ;
415+ });
416+ }
417+
418+ private void publishReauthEvent () {
419+ getResources ().eventBus ().publish (new ReauthenticateEvent (getEpid ()));
420+ }
421+
422+ private void publishReauthFailedEvent (Throwable throwable ) {
423+ getResources ().eventBus ().publish (new ReauthenticateFailedEvent (getEpid (), throwable ));
424+ }
425+
426+ private String getEpid () {
427+ if (getChannelWriter () instanceof Endpoint ) {
428+ return ((Endpoint ) getChannelWriter ()).getId ();
429+ }
430+ return "" ;
431+ }
432+
328433}
0 commit comments