diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index b8a9c00ff..c95849ab5 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -17,6 +17,7 @@ package io.rsocket.core; import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; +import static io.rsocket.keepalive.KeepAliveSupport.ServerKeepAliveSupport; import io.netty.buffer.ByteBuf; import io.netty.util.collection.IntObjectMap; @@ -85,6 +86,38 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { @Nullable RequesterLeaseTracker requesterLeaseTracker, Sinks.Empty onThisSideClosedSink, Mono onAllClosed) { + this( + connection, + payloadDecoder, + streamIdSupplier, + mtu, + maxFrameLength, + maxInboundPayloadSize, + keepAliveTickPeriod, + keepAliveAckTimeout, + keepAliveHandler, + requestInterceptorFunction, + requesterLeaseTracker, + onThisSideClosedSink, + onAllClosed, + false); + } + + RSocketRequester( + DuplexConnection connection, + PayloadDecoder payloadDecoder, + StreamIdSupplier streamIdSupplier, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + int keepAliveTickPeriod, + int keepAliveAckTimeout, + @Nullable KeepAliveHandler keepAliveHandler, + Function requestInterceptorFunction, + @Nullable RequesterLeaseTracker requesterLeaseTracker, + Sinks.Empty onThisSideClosedSink, + Mono onAllClosed, + boolean serverSide) { super( mtu, maxFrameLength, @@ -105,7 +138,11 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { KeepAliveSupport keepAliveSupport = - new ClientKeepAliveSupport(this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); + serverSide + ? new ServerKeepAliveSupport( + this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout) + : new ClientKeepAliveSupport( + this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); this.keepAliveFramesAcceptor = keepAliveHandler.start( keepAliveSupport, diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index e969c39d2..62ca55d60 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -456,7 +456,8 @@ private Mono acceptSetup( requesterLeaseTracker, requesterOnAllClosedSink, Mono.whenDelayError( - responderOnAllClosedSink.asMono(), requesterOnAllClosedSink.asMono())); + responderOnAllClosedSink.asMono(), requesterOnAllClosedSink.asMono()), + true); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java index 4fd18d041..1887d9dcc 100644 --- a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -181,6 +181,19 @@ void onIntervalTick() { } } + public static final class ServerKeepAliveSupport extends KeepAliveSupport { + + public ServerKeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + super(allocator, keepAliveInterval, keepAliveTimeout); + } + + @Override + void onIntervalTick() { + tryTimeout(); + } + } + public static final class KeepAlive { private final Duration tickPeriod; private final Duration timeoutMillis; diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index 5be59235c..a0ee13080 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -89,6 +89,30 @@ static RSocketState requester(int tickPeriod, int timeout) { return new RSocketState(rSocket, allocator, connection, empty); } + static RSocketState serverRequester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + Sinks.Empty empty = Sinks.empty(); + RSocketRequester rSocket = + new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.serverSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + tickPeriod, + timeout, + new DefaultKeepAliveHandler(), + r -> null, + null, + empty, + empty.asMono(), + true); + return new RSocketState(rSocket, allocator, connection, empty); + } + static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @@ -339,6 +363,105 @@ void resumableRSocketsNotDisposedOnMissingKeepAlives() throws InterruptedExcepti resumableRequesterState.allocator.assertHasNoLeaks(); } + @Test + void serverDoesNotSendProactiveKeepAlives() { + RSocketState serverState = serverRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + TestDuplexConnection connection = serverState.connection(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL * 3)); + + // Server should not have sent any keepalive frames + Assertions.assertThat(connection.pollFrame()).isNull(); + + serverState.rSocket.dispose(); + FrameAssert.assertThat(connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + serverState.connection.dispose(); + + serverState.allocator.assertHasNoLeaks(); + } + + @Test + void serverRespondsToClientKeepAlives() { + RSocketState serverState = serverRequester(100_000, 100_000); + TestDuplexConnection connection = serverState.connection(); + + Duration duration = Duration.ofMillis(100); + Mono.delay(duration) + .subscribe( + l -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + serverState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); + + virtualTimeScheduler.advanceTimeBy(duration); + FrameAssert.assertThat(connection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .matches(this::keepAliveFrameWithoutRespondFlag); + + serverState.rSocket.dispose(); + FrameAssert.assertThat(serverState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + serverState.connection.dispose(); + + serverState.allocator.assertHasNoLeaks(); + } + + @Test + void serverDisposedOnMissingKeepAlives() { + RSocketState serverState = serverRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + RSocket rSocket = serverState.rSocket(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_TIMEOUT * 2)); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + rSocket + .onClose() + .as(StepVerifier::create) + .expectError(ConnectionErrorException.class) + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(serverState.connection.getSent()).allMatch(ByteBuf::release); + + serverState.allocator.assertHasNoLeaks(); + } + + @Test + void serverNotDisposedOnPresentKeepAlives() { + RSocketState serverState = serverRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + TestDuplexConnection connection = serverState.connection(); + + Disposable disposable = + Flux.interval(Duration.ofMillis(KEEP_ALIVE_INTERVAL)) + .subscribe( + n -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + serverState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_TIMEOUT * 2)); + + RSocket rSocket = serverState.rSocket(); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + + disposable.dispose(); + + serverState.connection.dispose(); + serverState.rSocket.dispose(); + + Assertions.assertThat(serverState.connection.getSent()).allMatch(ByteBuf::release); + + serverState.allocator.assertHasNoLeaks(); + } + private boolean keepAliveFrame(ByteBuf frame) { return FrameHeaderCodec.frameType(frame) == FrameType.KEEPALIVE; }