Skip to content

Commit

Permalink
Adjusting "wait for disconnect" logic to accommodate for gracefully c…
Browse files Browse the repository at this point in the history
…losing the connection.
  • Loading branch information
tpeczek committed Dec 25, 2023
1 parent d98ba76 commit 51da5f0
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Security.Claims;
using System.Collections.Generic;
using Microsoft.Extensions.Options;
using BenchmarkDotNet.Attributes;
Expand Down Expand Up @@ -33,12 +32,12 @@ public class ServerSentEventsServiceBenchmarks
#region Constructor
public ServerSentEventsServiceBenchmarks()
{
_serverSentEventsClient = new ServerSentEventsClient(Guid.NewGuid(), new ClaimsPrincipal(), new NoOpHttpResponse(), false);
_serverSentEventsClient = new ServerSentEventsClient(Guid.NewGuid(), new NoOpHttpContext(), false);

_serverSentEventsService = new ServerSentEventsService(Options.Create<ServerSentEventsServiceOptions<ServerSentEventsService>>(null));
for (int i = 0; i < MULTIPLE_CLIENTS_COUNT; i++)
{
_serverSentEventsService.AddClient(new ServerSentEventsClient(Guid.NewGuid(), new ClaimsPrincipal(), new NoOpHttpResponse(), false));
_serverSentEventsService.AddClient(new ServerSentEventsClient(Guid.NewGuid(), new NoOpHttpContext(), false));
}
}
#endregion
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System;
using System.Threading;
using System.Security.Claims;
using System.Collections.Generic;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;

namespace Benchmark.AspNetCore.ServerSentEvents.Infrastructure
{
internal class NoOpHttpContext : HttpContext
{
public override IFeatureCollection Features => throw new NotImplementedException();

public override HttpRequest Request => throw new NotImplementedException();

public override HttpResponse Response { get; } = new NoOpHttpResponse();

public override ConnectionInfo Connection => throw new NotImplementedException();

public override WebSocketManager WebSockets => throw new NotImplementedException();

public override ClaimsPrincipal User { get; set; } = new ClaimsPrincipal();

public override IDictionary<object, object> Items { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public override IServiceProvider RequestServices { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public override CancellationToken RequestAborted { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public override string TraceIdentifier { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public override ISession Session { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public override void Abort()
{
throw new NotImplementedException();
}
}
}
23 changes: 0 additions & 23 deletions Lib.AspNetCore.ServerSentEvents/Internals/TaskHelper.cs

This file was deleted.

27 changes: 22 additions & 5 deletions Lib.AspNetCore.ServerSentEvents/ServerSentEventsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public sealed class ServerSentEventsClient : IServerSentEventsClient
#region Fields
private readonly HttpResponse _response;
private readonly bool _clientDisconnectServicesAvailable;
private readonly TaskCompletionSource<bool> _disconnectTaskCompletionSource = new TaskCompletionSource<bool>();
private readonly ConcurrentDictionary<string, object> _properties = new ConcurrentDictionary<string, object>();
#endregion

Expand All @@ -40,13 +41,20 @@ public sealed class ServerSentEventsClient : IServerSentEventsClient
#endregion

#region Constructor
internal ServerSentEventsClient(Guid id, ClaimsPrincipal user, HttpResponse response, bool clientDisconnectServicesAvailable)
internal ServerSentEventsClient(Guid id, HttpContext context, bool clientDisconnectServicesAvailable)
{
if (context is null)
{
throw new ArgumentNullException(nameof(context));
}

Id = id;
User = user ?? throw new ArgumentNullException(nameof(user));
User = context.User;

_response = response ?? throw new ArgumentNullException(nameof(response));
_response = context.Response;
context.RequestAborted.Register(RequestAbortedCallback, _disconnectTaskCompletionSource);
_clientDisconnectServicesAvailable = clientDisconnectServicesAvailable;

IsConnected = true;
}
#endregion
Expand Down Expand Up @@ -103,12 +111,11 @@ public async Task DisconnectAsync()
{
IsConnected = false;

await _response.Body.FlushAsync();

#if NET461
_response.HttpContext.Abort();
#else
await _response.CompleteAsync();
_disconnectTaskCompletionSource.TrySetResult(true);
#endif
}
}
Expand Down Expand Up @@ -185,6 +192,16 @@ internal Task ChangeReconnectIntervalAsync(uint reconnectInterval, CancellationT
return SendAsync(ServerSentEventsHelper.GetReconnectIntervalBytes(reconnectInterval), cancellationToken);
}

internal Task WaitForDisconnectAsync()
{
return _disconnectTaskCompletionSource.Task;
}

private static void RequestAbortedCallback(object taskCompletionSource)
{
((TaskCompletionSource<bool>)taskCompletionSource).TrySetResult(true);
}

private void CheckIsConnected()
{
if (!IsConnected)
Expand Down
4 changes: 2 additions & 2 deletions Lib.AspNetCore.ServerSentEvents/ServerSentEventsMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public async Task Invoke(HttpContext context, IPolicyEvaluator policyEvaluator)

await context.Response.AcceptAsync(_serverSentEventsOptions.OnPrepareAccept);

ServerSentEventsClient client = new ServerSentEventsClient(clientId, context.User, context.Response, _clientDisconnectServicesAvailable);
ServerSentEventsClient client = new ServerSentEventsClient(clientId, context, _clientDisconnectServicesAvailable);

if (_serverSentEventsService.ReconnectInterval.HasValue)
{
Expand All @@ -105,7 +105,7 @@ public async Task Invoke(HttpContext context, IPolicyEvaluator policyEvaluator)

await ConnectClientAsync(context.Request, client);

await context.RequestAborted.WaitAsync();
await client.WaitForDisconnectAsync();

await DisconnectClientAsync(context.Request, client);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class ServerSentEventsClientTests
private static ServerSentEventsClient PrepareServerSentEventsClient(HttpContext context = null, bool clientDisconnectServicesAvailable = false)
{
context = context ?? new DefaultHttpContext();
return new ServerSentEventsClient(Guid.NewGuid(), new ClaimsPrincipal(), context.Response, clientDisconnectServicesAvailable);
return new ServerSentEventsClient(Guid.NewGuid(), context, clientDisconnectServicesAvailable);
}
#endregion

Expand Down Expand Up @@ -168,8 +168,13 @@ public async Task Disconnect_ClientDisconnectServicesAvailable_Disconnects()
// ARRANGE
HttpContext context = new DefaultHttpContext();

#if NET461
Mock<IHttpRequestLifetimeFeature> httpRequestLifetimeFeatureMock = new Mock<IHttpRequestLifetimeFeature>();
context.Features.Set(httpRequestLifetimeFeatureMock.Object);
#else
Mock<IHttpResponseBodyFeature> httpResponseBodyFeatureMock = new Mock<IHttpResponseBodyFeature>();
context.Features.Set(httpResponseBodyFeatureMock.Object);
#endif

var client = PrepareServerSentEventsClient(context: context, clientDisconnectServicesAvailable: true);

Expand All @@ -178,7 +183,13 @@ public async Task Disconnect_ClientDisconnectServicesAvailable_Disconnects()

// ASSERT
Assert.False(client.IsConnected);
Assert.True(client.DisconnectAsync().IsCompleted);

#if NET461
httpRequestLifetimeFeatureMock.Verify(o => o.Abort(), Times.Once);
#else
httpResponseBodyFeatureMock.Verify(o => o.CompleteAsync(), Times.Once);
#endif
}
#endregion
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class ServerSentEventsServiceTests
private static async Task<ServerSentEventsClient> PrepareAndAddServerSentEventsClientAsync(ServerSentEventsService serverSentEventsService)
{
HttpContext context = new DefaultHttpContext();
ServerSentEventsClient serverSentEventsClient = new ServerSentEventsClient(Guid.NewGuid(), new ClaimsPrincipal(), context.Response, false);
ServerSentEventsClient serverSentEventsClient = new ServerSentEventsClient(Guid.NewGuid(), context, false);

await serverSentEventsService.OnConnectAsync(context.Request, serverSentEventsClient);

Expand Down

0 comments on commit 51da5f0

Please sign in to comment.