Skip to content

Commit

Permalink
feat: support SSE (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvarendorff2 authored May 7, 2024
1 parent eebee43 commit 1ce74f0
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/Fluss.HotChocolate/AddExtensionMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ public async ValueTask InvokeAsync(IRequestContext context)
return;
}

if (true != context.Services.GetRequiredService<IHttpContextAccessor>().HttpContext?.WebSockets
.IsWebSocketRequest)
var httpContext = context.Services.GetRequiredService<IHttpContextAccessor>().HttpContext;
var isWebsocket = true == httpContext?.WebSockets.IsWebSocketRequest;
var isSse = httpContext?.Request.Headers.Accept.ToString() == "text/event-stream";
if (!isWebsocket && !isSse)
{
return;
}
Expand Down Expand Up @@ -76,16 +78,18 @@ private async IAsyncEnumerable<IQueryResult> LiveResults(IReadOnlyDictionary<str
yield break;
}

contextData.TryGetValue(nameof(HttpContext), out var httpContext);
var isWebsocket = (httpContext as HttpContext)?.WebSockets.IsWebSocketRequest ?? false;
var foundSocketSession = contextData.TryGetValue(nameof(ISocketSession), out var contextSocketSession); // as ISocketSession
var foundOperationId = contextData.TryGetValue("HotChocolate.Execution.Transport.OperationSessionId", out var operationId); // as string

if (!foundSocketSession || !foundOperationId)
if (isWebsocket && (!foundSocketSession || !foundOperationId))
{
_logger.LogWarning("Trying to add live results but {SocketSession} or {OperationId} is not present in context!", nameof(contextSocketSession), nameof(operationId));
yield break;
}

if (contextSocketSession is not ISocketSession socketSession)
if (isWebsocket && contextSocketSession is not ISocketSession)
{
_logger.LogWarning("{ContextSocketSession} key present in context but not an {ISocketSession}!", contextSocketSession?.GetType().FullName, nameof(ISocketSession));
yield break;
Expand All @@ -108,7 +112,7 @@ private async IAsyncEnumerable<IQueryResult> LiveResults(IReadOnlyDictionary<str
unitOfWork.ReadModels
);

if (socketSession.Operations.All(operationSession => operationSession.Id != operationId?.ToString()))
if (isWebsocket && contextSocketSession is ISocketSession socketSession && socketSession.Operations.All(operationSession => operationSession.Id != operationId?.ToString()))
{
break;
}
Expand Down

0 comments on commit 1ce74f0

Please sign in to comment.