using System.Net.WebSockets;
using System.Text;
using System.Web;
using JNPF.Common.Net;
using JNPF.Common.Security;
using JNPF.DataEncryption;
using JNPF.Extras.WebSockets.Models;
using Microsoft.AspNetCore.Http;
using Microsoft.IdentityModel.JsonWebTokens;
namespace JNPF.WebSockets;
///
/// WebSocket 中间件.
///
public class WebSocketMiddleware
{
///
/// 请求委托.
///
private readonly RequestDelegate _next;
///
/// webSocket 处理程序.
///
private WebSocketHandler _webSocketHandler { get; set; }
///
/// 初始化一个类型的新实例.
///
///
///
public WebSocketMiddleware(
RequestDelegate next,
WebSocketHandler webSocketHandler)
{
_next = next;
_webSocketHandler = webSocketHandler;
}
///
/// 异步调用.
///
public async Task InvokeAsync(HttpContext context)
{
if (!context.WebSockets.IsWebSocketRequest)
{
await _next.Invoke(context);
return;
}
WebSocket? socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false);
var token = new JsonWebToken(context.Request.Path.ToString().TrimStart('/').Replace("Bearer%20", string.Empty).Replace("bearer%20", string.Empty));
var httpContext = (DefaultHttpContext)context;
httpContext.Request.Headers["Authorization"] = HttpUtility.UrlDecode(context.Request.Path.ToString().TrimStart('/'), Encoding.UTF8);
UserAgent userAgent = new UserAgent(httpContext);
if (!JWTEncryption.ValidateJwtBearerToken(httpContext, out token))
{
await _webSocketHandler.OnDisconnected(socket);
}
else
{
var connectionId = Guid.NewGuid().ToString("N");
var wsClient = new WebSocketClient
{
ConnectionId = connectionId,
WebSocket = socket,
LoginIpAddress = NetHelper.Ip,
LoginPlatForm = string.Format("{0}-{1}", userAgent.OS.ToString(), userAgent.RawValue)
};
await _webSocketHandler.OnConnected(connectionId, wsClient).ConfigureAwait(false);
await Receive(wsClient, async (result, serializedMessage) =>
{
switch (result.MessageType)
{
case WebSocketMessageType.Text:
await _webSocketHandler.ReceiveAsync(wsClient, result, serializedMessage).ConfigureAwait(false);
break;
case WebSocketMessageType.Close:
await _webSocketHandler.OnDisconnected(socket);
break;
case WebSocketMessageType.Binary:
break;
default:
throw new ArgumentOutOfRangeException();
}
});
}
}
///
/// 接收数据.
///
///
///
///
private async Task Receive(WebSocketClient client, Action handleMessage)
{
while (client.WebSocket.State == WebSocketState.Open)
{
ArraySegment buffer = new ArraySegment(new byte[1024 * 4]);
string message = string.Empty;
WebSocketReceiveResult result = null;
try
{
using (var ms = new MemoryStream())
{
do
{
result = await client.WebSocket.ReceiveAsync(buffer, CancellationToken.None);
ms.Write(buffer.Array, buffer.Offset, result.Count);
}
while (!result.EndOfMessage);
ms.Seek(0, SeekOrigin.Begin);
using (var reader = new StreamReader(ms, Encoding.UTF8))
{
message = await reader.ReadToEndAsync().ConfigureAwait(false);
}
}
handleMessage(result, message);
}
catch (WebSocketException e)
{
if (e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
client.WebSocket.Abort();
}
}
}
await _webSocketHandler.OnDisconnected(client.WebSocket);
}
}