git » repo » main » tree

[main] / services / spaces / src / WsHandler.cs

using System.Buffers;
using System.Collections.Concurrent;
using System.Net.WebSockets;
using System.Security.Cryptography;
using System.Text.Json;
using CommunityToolkit.HighPerformance.Buffers;

namespace spaces;

internal static class WsHandler
{
    /// <summary>
    /// Sends a normal-closure frame to every registered connection and blocks for at most 3 seconds
    /// while the close operations complete. Failures on individual sockets are ignored.
    /// </summary>
    /// <remarks>
    /// Synchronous (blocks on <c>Task.Wait</c>); presumably invoked from a synchronous shutdown hook — verify against caller.
    /// </remarks>
    public static void CloseConnections() => Task.WhenAll(Connections.Select(async conn =>
    {
        try { await conn.Key.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None); }
        catch { /* best effort: the socket may already be closed or aborted */ }
    })).Wait(TimeSpan.FromSeconds(3));

    /// <summary>
    /// Serves one WebSocket client: loads (or generates) the user's profile, registers the connection,
    /// restores the saved context if there is one, then processes incoming commands until the socket
    /// stops being open, a Close frame arrives, a receive fails, or <paramref name="cancel"/> fires.
    /// </summary>
    /// <param name="ws">The accepted WebSocket for this client.</param>
    /// <param name="userId">Id used to load and save this client's state in storage.</param>
    /// <param name="cancel">Stops the receive loop and is passed to all storage/send operations.</param>
    public static async Task MessageLoopAsync(WebSocket ws, Guid userId, CancellationToken cancel)
    {
        var conn = new Connection(ws, userId);

        var state = await Storage.TryLoadStateAsync(userId, cancel);
        if(state?.User == null)
        {
            state = new State(null, conn.GenUser());
            await conn.SendProfileGeneratedAsync(state.User, cancel);
        }

        conn.State = state;
        Connections[ws] = conn;

        try
        {
            await conn.JoinAsync(state, true, cancel);
            await conn.ReceiveLoopAsync(cancel);
        }
        finally
        {
            // Unregister even when joining or receiving throws, so broadcasts stop targeting this socket
            Connections.TryRemove(ws, out _);

            try { await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancel); }
            catch { /* ignored: the socket may already be closed/aborted or cancel may have fired */ }
        }
    }

    /// <summary>
    /// Receives messages until the socket leaves the Open state, a Close frame arrives, a receive fails,
    /// or <paramref name="cancel"/> fires. Each accepted message is handed to a background task for processing.
    /// </summary>
    private static async Task ReceiveLoopAsync(this Connection conn, CancellationToken cancel)
    {
        var ws = conn.Ws;
        while(!cancel.IsCancellationRequested && ws.State == WebSocketState.Open)
        {
            var owner = MemoryPool<byte>.Shared.Rent(MaxInputMessageSize);
            // Set once the buffer's ownership moves to the command task, which then disposes it
            var handedOff = false;
            try
            {
                // The pool may return a larger block; slice so MaxInputMessageSize is an actual cap
                var input = owner.Memory.Slice(0, MaxInputMessageSize);
                var rcv = await ws.ReceiveAsync(input, cancel);
                if(rcv.MessageType == WebSocketMessageType.Close)
                    break;

                // An oversized message arrives in several reads; discard the rest so the next read
                // starts on a message boundary instead of parsing a fragment as a new command
                var tooLarge = !rcv.EndOfMessage;
                if(tooLarge)
                    await DrainMessageAsync(ws, input, cancel);

                if(!conn.ProfileLimit.TryIncrement(Connection.TotalLimitRpm, out var crossed))
                {
                    if(crossed) await conn.TrySendErrorAsync("Limit exceeded, wait 1 min", cancel);
                    continue;
                }

                if(tooLarge)
                {
                    await conn.TrySendErrorAsync("Message too large", cancel);
                    continue;
                }

                ReadOnlyMemory<byte> payload = input.Slice(0, rcv.Count);
                // No token passed to Task.Run: the delegate must always run so its finally disposes the buffer
                _ = Task.Run(() => conn.ProcessCommandAsync(payload, owner, cancel));
                handedOff = true;
            }
            catch { break; }
            finally
            {
                if(!handedOff)
                    owner.Dispose();
            }
        }
    }

    /// <summary>Reads and discards the remaining fragments of the current message into <paramref name="scratch"/>.</summary>
    private static async Task DrainMessageAsync(WebSocket ws, Memory<byte> scratch, CancellationToken cancel)
    {
        ValueWebSocketReceiveResult rcv;
        do
        {
            rcv = await ws.ReceiveAsync(scratch, cancel);
        }
        while(!rcv.EndOfMessage);
    }

    /// <summary>
    /// Parses one received message as a <c>Command</c>, applies the per-type rate limit and executes it.
    /// Any failure is reported to the client as "Failed to process command".
    /// Takes ownership of <paramref name="owner"/> (the buffer backing <paramref name="payload"/>) and disposes it.
    /// </summary>
    private static async Task ProcessCommandAsync(this Connection conn, ReadOnlyMemory<byte> payload, IMemoryOwner<byte> owner, CancellationToken cancel)
    {
        try
        {
            var cmd = JsonSerializer.Deserialize<Command>(payload.Span, JsonHelper.SerializerOptions);
            if(cmd == default)
                return;
            if(!conn.TryIncrement(cmd.Type, out var crossed))
            {
                if(crossed) await conn.TrySendErrorAsync("Limit exceeded, wait 1 min", cancel);
                return;
            }
            await conn.ExecuteCommandAsync(cmd, cancel);
        }
        catch { await conn.TrySendErrorAsync("Failed to process command", cancel); }
        finally { owner.Dispose(); }
    }

    /// <summary>
    /// Dispatches a parsed command to its handler.
    /// </summary>
    /// <exception cref="NotSupportedException">The command type has no handler.</exception>
    private static Task ExecuteCommandAsync(this Connection conn, Command cmd, CancellationToken cancel) => cmd.Type switch
    {
        MsgType.Close => conn.CloseSpace(cancel),
        MsgType.Join => conn.JoinAsync(cmd.Data, cancel),
        MsgType.Room => conn.JoinRoomAsync(cmd.Data, cancel),
        MsgType.Msg => conn.SendMessage(cmd.Data, cancel),
        MsgType.Generate => conn.GenProfile(cancel),
        _ => throw new NotSupportedException($"Unsupported command type '{cmd.Type}'")
    };

    /// <summary>
    /// Closes the connection's current space in storage and tells every connection in that space.
    /// Does nothing when the connection is not in a space.
    /// </summary>
    private static async Task CloseSpace(this Connection conn, CancellationToken cancel)
    {
        var state = conn.State;

        var space = state?.Context?.Space ?? 0UL;
        if(space == 0UL)
            return;

        var key = space.ToBase58();
        await Storage.CloseSpace(key);

        // space != 0 implies state and its Context were non-null above
        await GetSpaceConnections(space).BroadcastAsync(state!, MsgType.Close, $"Closed space '{key}', new members can no longer join", cancel);
    }

    /// <summary>Creates a random user whose name is "{color} {humanoid}" and whose avatar is drawn from the connection's random source.</summary>
    private static User GenUser(this Connection conn)
    {
        var humanoid = Names.Humanoids[conn.Random.Value.Next(Names.Humanoids.Length)];
        var avatar = conn.Random.Value.CreateAvatar(out var color);
        return new($"{color} {humanoid}", avatar);
    }

    /// <summary>
    /// Replaces the connection's profile with a newly generated one and sends it to the client.
    /// Ignored while the connection is inside a space context. Does not write to storage.
    /// </summary>
    private static async Task GenProfile(this Connection conn, CancellationToken cancel)
    {
        var state = conn.State;
        if(state?.Context != null)
            return;

        var user = conn.GenUser();
        conn.State = new State(null, user);

        await conn.SendProfileGeneratedAsync(user, cancel);
    }

    /// <summary>Sends a <c>Generate</c> message carrying <paramref name="user"/>'s name and avatar to this connection only.</summary>
    private static async Task SendProfileGeneratedAsync(this Connection conn, User user, CancellationToken cancel)
    {
        await conn.TrySendMessageAsync(new Message
        {
            Type = MsgType.Generate,
            Text = "Generated new anonymous profile",
            Author = user.Name,
            Avatar = user.Avatar
        }, cancel);
    }

    /// <summary>
    /// Stores a chat message in the connection's current space/room and broadcasts it to every connection
    /// in the same context. Empty text, or a connection without a context, is ignored.
    /// </summary>
    private static async Task SendMessage(this Connection conn, string? text, CancellationToken cancel)
    {
        if(string.IsNullOrEmpty(text))
            return;

        var state = conn.State;
        if(state == null)
            return;

        var (ctx, user) = state;
        if(ctx == null)
            return;

        var msg = new Message
        {
            Type = MsgType.Msg,
            Author = user.Name,
            Avatar = user.Avatar,
            Text = text
        };

        await Storage.SaveMessageAsync(ctx.Space.ToBase58(), ctx.Room, msg, cancel);
        await GetSpaceConnections(ctx).BroadcastAsync(msg, cancel);
    }

    /// <summary>
    /// Joins an existing space by its encoded id, or creates a new space when <paramref name="value"/> is empty.
    /// The user is added to the space on first join and lands in the space root.
    /// </summary>
    /// <exception cref="InvalidOperationException">The connection has no user profile to register in the space.</exception>
    private static async Task JoinAsync(this Connection conn, string? value, CancellationToken cancel)
    {
        value = value?.Trim().NullIfEmpty();

        ulong space;
        if(value == null)
            Storage.CreateSpace((space = conn.RndSpace()).ToBase58());
        else if((space = ResolveJoinableSpace(value, conn.UserId)) == 0UL)
        {
            await conn.TrySendErrorAsync("Space not exists or invalid or closed", cancel);
            return;
        }

        var key = space.ToBase58();
        var user = await Storage.FindUserAsync(conn.UserId, key, cancel);
        if(user == null)
        {
            // User not yet recorded in this space: register the connection's current profile
            user = conn.State?.User ?? throw new InvalidOperationException("Connection has no user profile");
            await Storage.AddUserToSpaceAsync(key, conn.UserId, user, cancel);
        }

        var state = new State(new Context(space, null), user);
        await conn.JoinAsync(state, false, cancel);
    }

    /// <summary>
    /// Parses a client-supplied space id and returns it when it is non-zero, exists and is accessible to
    /// <paramref name="userId"/>; otherwise returns 0.
    /// </summary>
    /// <remarks>
    /// Storage is queried with the canonical <c>ToBase58()</c> form, which is the same key used when the space
    /// was created and when members are recorded, rather than with the raw client string.
    /// </remarks>
    private static ulong ResolveJoinableSpace(string value, Guid userId)
    {
        if(!ContextHelper.TryParseSpace(value, out var space) || space == 0UL)
            return 0UL;

        var key = space.ToBase58();
        return Storage.IsSpaceExists(key) && Storage.HasAccess(key, userId) ? space : 0UL;
    }

    /// <summary>
    /// Switches the connection to <paramref name="state"/>: persists the context if it changed, creates the room
    /// if needed, replays stored messages to this connection, then announces the join/room change.
    /// </summary>
    /// <param name="init">True on initial connect. Announcements that would otherwise be skipped are then sent to this connection only.</param>
    private static async Task JoinAsync(this Connection conn, State state, bool init, CancellationToken cancel)
    {
        var ctx = state.Context;
        if(ctx == null)
            return;

        var old = conn.State?.Context;
        if(old != ctx)
            await Storage.SaveContextAsync(conn.UserId, ctx, cancel);

        conn.State = state;

        var space = ctx.Space.ToBase58();
        if(ctx.Room != null)
            Storage.CreateRoom(space, ctx.Room);

        await foreach(var msg in Storage.TryReadMessages(space, ctx.Room, cancel))
            await conn.TrySendMessageAsync(msg, cancel);

        var spaceChanged = old?.Space != ctx.Space;
        var roomChanged = old?.Room != ctx.Room;

        // Join notice: everyone in the space when the space changed; otherwise only this connection on initial connect
        await GetSpaceConnections(ctx.Space)
            .Where(c => spaceChanged || (init && c == conn))
            .BroadcastAsync(state, MsgType.Join, $"Joined space '{space}'", cancel);

        // Room notice: everyone in the space when the room changed; otherwise only this connection on initial connect into a room.
        // NOTE(review): roomChanged is also true when moving from a room of another space to this space's root,
        // which announces "Returned back to space root" to the new space. Confirm that is intended.
        await GetSpaceConnections(ctx.Space)
            .Where(c => roomChanged || (init && ctx.Room != null && c == conn))
            .BroadcastAsync(state, MsgType.Room, string.IsNullOrEmpty(ctx.Room) ? "Returned back to space root" : $"Entered room '{ctx.Room}'", cancel);
    }

    /// <summary>
    /// Moves the connection to <paramref name="room"/> within its current space (null/empty means the space root).
    /// Room names are trimmed and lower-cased with invariant culture before validation.
    /// </summary>
    private static async Task JoinRoomAsync(this Connection conn, string? room, CancellationToken cancel)
    {
        var state = conn.State;
        if(state?.Context == null)
            return;

        // Invariant casing: culture-sensitive ToLower() maps 'I' to a non-ASCII letter under e.g. tr-TR
        room = room?.Trim().ToLowerInvariant().NullIfEmpty();
        if(!ContextHelper.IsRoomValid(room))
        {
            await conn.TrySendErrorAsync("Invalid room, only ascii letters allowed", cancel);
            return;
        }

        state = state with { Context = state.Context with { Room = room } };
        await conn.JoinAsync(state, false, cancel);
    }

    /// <summary>Broadcasts a system-style notice of <paramref name="type"/> authored by <paramref name="state"/>'s user and tagged with its context.</summary>
    private static Task BroadcastAsync(this IEnumerable<Connection> connections, State state, MsgType type, string text, CancellationToken cancel) => connections.BroadcastAsync(new Message
    {
        Context = state.Context?.ToStringValue(),
        Type = type,
        Text = text,
        Author = state.User.Name,
        Avatar = state.User.Avatar
    }, cancel);

    /// <summary>Connections whose current context equals <paramref name="ctx"/>; empty when <paramref name="ctx"/> is null.</summary>
    private static IEnumerable<Connection> GetSpaceConnections(Context? ctx)
        => ctx == null ? Enumerable.Empty<Connection>() : Connections.Select(pair => pair.Value).Where(c => c.State?.Context == ctx);

    /// <summary>Connections currently anywhere in <paramref name="space"/> (any room); empty for null or 0.</summary>
    private static IEnumerable<Connection> GetSpaceConnections(ulong? space)
        => space > 0UL ? Connections.Select(pair => pair.Value).Where(c => c.State?.Context?.Space == space) : Enumerable.Empty<Connection>();

    /// <summary>Sends an <c>Error</c> message authored by the system identity to this connection only.</summary>
    private static Task TrySendErrorAsync(this Connection conn, string text, CancellationToken cancel) => conn.TrySendMessageAsync(new Message
    {
        Type = MsgType.Error,
        Author = AvatarGen.SystemName,
        Avatar = AvatarGen.SystemAvatar,
        Text = text
    }, cancel);

    /// <summary>Sends <paramref name="msg"/> to this connection only.</summary>
    private static Task TrySendMessageAsync(this Connection conn, Message msg, CancellationToken cancel)
        => BroadcastAsync(EnumerableHelper.Yield(conn), msg, cancel);

    /// <summary>
    /// Serializes <paramref name="msg"/> once into a pooled buffer and sends the same bytes to every connection concurrently.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>MemoryBufferWriter</c> writes into a fixed-size buffer; a message that does not fit in
    /// MaxOutputMessageSize presumably fails to serialize. Confirm the largest possible message stays below it.
    /// </remarks>
    private static async Task BroadcastAsync(this IEnumerable<Connection> connections, Message msg, CancellationToken cancel)
    {
        using var buffer = MemoryPool<byte>.Shared.Rent(MaxOutputMessageSize);
        var serialized = SerializeMessage(msg, buffer.Memory);
        await Task.WhenAll(connections.Select(conn => TrySendMessageAsync(conn, serialized, cancel)));
    }

    /// <summary>Writes <paramref name="msg"/> as UTF-8 JSON into <paramref name="memory"/> and returns the written slice.</summary>
    private static ReadOnlyMemory<byte> SerializeMessage(Message msg, Memory<byte> memory)
    {
        var writer = new MemoryBufferWriter<byte>(memory);
        using var jsonWriter = new Utf8JsonWriter(writer);
        JsonSerializer.Serialize(jsonWriter, msg, JsonHelper.SerializerOptions);
        return writer.WrittenMemory;
    }

    /// <summary>
    /// Sends one text frame to <paramref name="conn"/>, serialized with the connection's other sends.
    /// Send failures and cancellation are swallowed; this method does not throw for either.
    /// </summary>
    private static async Task TrySendMessageAsync(Connection conn, ReadOnlyMemory<byte> memory, CancellationToken cancel)
    {
        // Acquire outside the send try/finally so Release() only runs after a successful wait
        try { await conn.WsSendSync.WaitAsync(cancel); }
        catch(OperationCanceledException) { return; }

        try { await conn.Ws.SendAsync(memory, WebSocketMessageType.Text, true, cancel); }
        catch { /* best effort: the peer may be gone or the send cancelled */ }
        finally { conn.WsSendSync.Release(); }
    }

    /// <summary>
    /// Generates a new non-zero space id from a cryptographically secure source.
    /// </summary>
    /// <remarks>
    /// Clients join a space by submitting its encoded id, so ids must not be predictable. <c>System.Random</c>
    /// is not a cryptographic generator. 0 is excluded because this class treats space 0 as "no space".
    /// </remarks>
    private static ulong RndSpace(this Connection conn)
    {
        Span<byte> bytes = stackalloc byte[sizeof(ulong)];
        ulong space;
        do
        {
            RandomNumberGenerator.Fill(bytes);
            space = BitConverter.ToUInt64(bytes);
        }
        while(space == 0UL);
        return space;
    }

    private static readonly ConcurrentDictionary<WebSocket, Connection> Connections = new();
    private const int MaxOutputMessageSize = 4096;
    private const int MaxInputMessageSize = 256;
}

/// <summary>
/// Per-socket state for one connected client: the socket, the user id, the current profile/context,
/// a send lock, a random source and the rate-limit counters.
/// </summary>
internal class Connection
{
    public Connection(WebSocket ws, Guid userId)
    {
        Ws = ws;
        UserId = userId;
    }

    /// <summary>
    /// Applies the rate limit that belongs to <paramref name="type"/>. <c>Msg</c> uses <see cref="MsgLimit"/>
    /// (max <see cref="MsgLimitRpm"/>). <c>Join</c>, <c>Room</c> and <c>Close</c> share <see cref="CtxLimit"/>
    /// (max <see cref="CtxLimitRpm"/>). Any other type is not limited here.
    /// </summary>
    /// <param name="type">Command type being executed.</param>
    /// <param name="crossed">Set by <c>RequestLimit.TryIncrement</c>; false when the type is unlimited.
    /// Presumably true only on the request that first exceeds the limit. TODO confirm.</param>
    /// <returns>True when the command may run.</returns>
    public bool TryIncrement(MsgType type, out bool crossed)
    {
        crossed = false;
        var (limit, max) = type switch
        {
            MsgType.Msg => (MsgLimit, MsgLimitRpm),
            MsgType.Join or MsgType.Room or MsgType.Close => (CtxLimit, CtxLimitRpm),
            _ => (null, int.MaxValue)
        };
        return limit == null || limit.TryIncrement(max, out crossed);
    }

    public readonly WebSocket Ws;
    public readonly Guid UserId;

    // Replaced wholesale (never mutated) and read from concurrent command tasks, hence volatile
    public volatile State? State;

    // Serializes sends on Ws
    public readonly SemaphoreSlim WsSendSync = new(1, 1);

    // Commands for one connection can run concurrently, and a System.Random instance is not thread-safe,
    // so this wraps the thread-safe shared generator. The Lazy<Random> type is kept for existing callers.
    // Fully qualified because the name Random binds to this field inside the class.
    public readonly Lazy<Random> Random = new(() => System.Random.Shared);

    public readonly RequestLimit CtxLimit = new();
    public readonly RequestLimit MsgLimit = new();
    public readonly RequestLimit ProfileLimit = new();

    // Presumably requests per minute (the client-facing error says "wait 1 min"). TODO confirm against RequestLimit.
    public const int CtxLimitRpm = 6;
    public const int MsgLimitRpm = 60;
    public const int TotalLimitRpm = 600;
}