2222import static io .lettuce .core .ClientOptions .DEFAULT_JSON_PARSER ;
2323import static io .lettuce .core .protocol .CommandType .*;
2424
25+ import java .nio .CharBuffer ;
2526import java .time .Duration ;
2627import java .util .ArrayList ;
2728import java .util .Collection ;
2829import java .util .List ;
30+ import java .util .concurrent .atomic .AtomicBoolean ;
31+ import java .util .concurrent .atomic .AtomicReference ;
32+ import java .util .concurrent .locks .ReentrantLock ;
2933import java .util .function .Consumer ;
3034import java .util .stream .Collectors ;
3135
3741import io .lettuce .core .cluster .api .sync .RedisClusterCommands ;
3842import io .lettuce .core .codec .RedisCodec ;
3943import io .lettuce .core .codec .StringCodec ;
44+ import io .lettuce .core .event .connection .ReauthenticateEvent ;
45+ import io .lettuce .core .event .connection .ReauthenticateFailedEvent ;
4046import io .lettuce .core .json .JsonParser ;
4147import io .lettuce .core .output .MultiOutput ;
4248import io .lettuce .core .output .StatusOutput ;
4349import io .lettuce .core .protocol .*;
50+ import io .netty .util .internal .logging .InternalLogger ;
51+ import io .netty .util .internal .logging .InternalLoggerFactory ;
4452import reactor .core .publisher .Mono ;
4553
4654/**
5563 */
5664public class StatefulRedisConnectionImpl <K , V > extends RedisChannelHandler <K , V > implements StatefulRedisConnection <K , V > {
5765
66+ private static final InternalLogger logger = InternalLoggerFactory .getInstance (StatefulRedisConnectionImpl .class );
67+
5868 protected final RedisCodec <K , V > codec ;
5969
6070 protected final RedisCommands <K , V > sync ;
@@ -71,6 +81,14 @@ public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V>
7181
7282 protected MultiOutput <K , V > multi ;
7383
84+ private RedisAuthenticationHandler authHandler ;
85+
86+ private AtomicReference <RedisCredentials > credentialsRef = new AtomicReference <>();
87+
88+ private final ReentrantLock reAuthSafety = new ReentrantLock ();
89+
90+ private AtomicBoolean inTransaction = new AtomicBoolean (false );
91+
7492 /**
7593 * Initialize a new connection.
7694 *
@@ -181,20 +199,38 @@ public boolean isMulti() {
181199 public <T > RedisCommand <K , V , T > dispatch (RedisCommand <K , V , T > command ) {
182200
183201 RedisCommand <K , V , T > toSend = preProcessCommand (command );
184- return super .dispatch (toSend );
202+ RedisCommand <K , V , T > result = super .dispatch (toSend );
203+ if (toSend .getType () == EXEC || toSend .getType () == DISCARD ) {
204+ inTransaction .set (false );
205+ setCredentials (credentialsRef .getAndSet (null ));
206+ }
207+
208+ return result ;
185209 }
186210
187211 @ Override
188212 public Collection <RedisCommand <K , V , ?>> dispatch (Collection <? extends RedisCommand <K , V , ?>> commands ) {
189213
190214 List <RedisCommand <K , V , ?>> sentCommands = new ArrayList <>(commands .size ());
191215
192- commands .forEach (o -> {
216+ boolean transactionComplete = false ;
217+ for (RedisCommand <K , V , ?> o : commands ) {
193218 RedisCommand <K , V , ?> command = preProcessCommand (o );
194219 sentCommands .add (command );
195- });
220+ if (command .getType () == EXEC ) {
221+ transactionComplete = true ;
222+ }
223+ if (command .getType () == MULTI || command .getType () == DISCARD ) {
224+ transactionComplete = false ;
225+ }
226+ }
196227
197- return super .dispatch (sentCommands );
228+ Collection <RedisCommand <K , V , ?>> result = super .dispatch (sentCommands );
229+ if (transactionComplete ) {
230+ inTransaction .set (false );
231+ setCredentials (credentialsRef .getAndSet (null ));
232+ }
233+ return result ;
198234 }
199235
200236 // TODO [tihomir.mateev] Refactor to include as part of the Command interface
@@ -273,12 +309,20 @@ protected <T> RedisCommand<K, V, T> preProcessCommand(RedisCommand<K, V, T> comm
273309
274310 if (commandType .equals (MULTI .name ())) {
275311
312+ reAuthSafety .lock ();
313+ try {
314+ inTransaction .set (true );
315+ } finally {
316+ reAuthSafety .unlock ();
317+ }
276318 multi = (multi == null ? new MultiOutput <>(codec ) : multi );
277319
278320 if (command instanceof CompleteableCommand ) {
279321 ((CompleteableCommand <?>) command ).onComplete ((ignored , e ) -> {
280322 if (e != null ) {
281323 multi = null ;
324+ inTransaction .set (false );
325+ setCredentials (credentialsRef .getAndSet (null ));
282326 }
283327 });
284328 }
@@ -318,11 +362,78 @@ public ConnectionState getConnectionState() {
318362 @ Override
319363 public void activated () {
320364 super .activated ();
365+ if (authHandler != null ) {
366+ authHandler .subscribe ();
367+ }
321368 }
322369
323370 @ Override
324371 public void deactivated () {
372+ if (authHandler != null ) {
373+ authHandler .unsubscribe ();
374+ }
325375 super .deactivated ();
326376 }
327377
378+ public void setAuthenticationHandler (RedisAuthenticationHandler handler ) {
379+ if (authHandler != null ) {
380+ authHandler .unsubscribe ();
381+ }
382+ authHandler = handler ;
383+ if (isOpen ()) {
384+ authHandler .subscribe ();
385+ }
386+ }
387+
388+ public void setCredentials (RedisCredentials credentials ) {
389+ if (credentials == null ) {
390+ return ;
391+ }
392+ reAuthSafety .lock ();
393+ try {
394+ credentialsRef .set (credentials );
395+ if (!inTransaction .get ()) {
396+ dispatchAuthCommand (credentialsRef .getAndSet (null ));
397+ }
398+ } finally {
399+ reAuthSafety .unlock ();
400+ }
401+ }
402+
403+ private void dispatchAuthCommand (RedisCredentials credentials ) {
404+ if (credentials == null ) {
405+ return ;
406+ }
407+
408+ RedisFuture <String > auth ;
409+ if (credentials .getUsername () != null ) {
410+ auth = async ().auth (credentials .getUsername (), CharBuffer .wrap (credentials .getPassword ()));
411+ } else {
412+ auth = async ().auth (CharBuffer .wrap (credentials .getPassword ()));
413+ }
414+ auth .thenRun (() -> {
415+ publishReauthEvent ();
416+ logger .info ("Re-authentication succeeded for endpoint {}." , getEpid ());
417+ }).exceptionally (throwable -> {
418+ publishReauthFailedEvent (throwable );
419+ logger .error ("Re-authentication failed for endpoint {}." , getEpid (), throwable );
420+ return null ;
421+ });
422+ }
423+
424+ private void publishReauthEvent () {
425+ getResources ().eventBus ().publish (new ReauthenticateEvent (getEpid ()));
426+ }
427+
428+ private void publishReauthFailedEvent (Throwable throwable ) {
429+ getResources ().eventBus ().publish (new ReauthenticateFailedEvent (getEpid (), throwable ));
430+ }
431+
432+ private String getEpid () {
433+ if (getChannelWriter () instanceof Endpoint ) {
434+ return ((Endpoint ) getChannelWriter ()).getId ();
435+ }
436+ return "" ;
437+ }
438+
328439}
0 commit comments