Skip to content

Commit

Permalink
Add backend support
Browse files Browse the repository at this point in the history
  • Loading branch information
RealStillkill committed May 30, 2024
1 parent 28da4b9 commit 1b2af2e
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 6 deletions.
22 changes: 22 additions & 0 deletions NovelAIBot/Extensions/WebsocketExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Net.WebSockets;
using System.Text;

namespace NovelAIBot.Extensions
{
public static class WebSocketExtensions
{
public static async Task SendTextMessageAsync(this WebSocket client, string text, Encoding encoding)
{
byte[] buffer = encoding.GetBytes(text);
await client.SendAsync(buffer, WebSocketMessageType.Text, true, CancellationToken.None);
}

public static async Task<string> ReceiveTextMessageAsync(this WebSocket client, int bufferSize, Encoding encoding)
{
byte[] buffer = new byte[bufferSize];
var result = await client.ReceiveAsync(buffer, CancellationToken.None);
Array.Resize(ref buffer, result.Count);
return encoding.GetString(buffer, 0, buffer.Length) ?? string.Empty;
}
}
}
29 changes: 29 additions & 0 deletions NovelAIBot/Models/BackendQueueStatus.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace NovelAIBot.Models
{
public class BackendQueueStatus
{
public Guid Id { get; set; }
public int QueuePosition { get; set; }

public NaiQueueState State { get; set; }

public BackendQueueStatus()
{
}

public BackendQueueStatus(Guid? id, int queuePosition, NaiQueueState state = NaiQueueState.Enqueued)
{
Id = id ?? Guid.Empty;
QueuePosition = queuePosition;
State = state;
}
}

public enum NaiQueueState { Enqueued, Processing, CompletedSuccess, CompletedError }
}
42 changes: 42 additions & 0 deletions NovelAIBot/Models/BackendRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using Discord.Interactions;
using System.Text.Json.Serialization;


namespace NovelAIBot.Models
{
internal class BackendRequest : INaiRequest
{
public Guid Id { get; set; }
public string Prompt { get; private set; }
public string NegativePrompt { get; private set; }
public string AuthKey { get; private set; }

[JsonIgnore]
public SocketInteractionContext Context { get; set; }

public int Height { get; set; }

public int Width { get; set; }

public BackendRequest(string prompt, string negativePrompt, string authKey, int height, int width)
{
Id = Guid.Empty;
Prompt = prompt;
NegativePrompt = negativePrompt;
AuthKey = authKey;
Height = height;
Width = width;
}

public BackendRequest(string prompt, string negativePrompt, string authKey, SocketInteractionContext context, int height, int width)
{
Id = Guid.Empty;
Prompt = prompt;
NegativePrompt = negativePrompt;
AuthKey = authKey;
Context = context;
Height = height;
Width = width;
}
}
}
20 changes: 17 additions & 3 deletions NovelAIBot/Modules/PromptModule.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Discord;
using Discord.Interactions;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using NovelAIBot.Enums;
using NovelAIBot.Models;
Expand All @@ -15,13 +16,17 @@ namespace NovelAIBot.Modules
{
internal class PromptModule : InteractionModuleBase<SocketInteractionContext>
{
private string AuthKey { get => _configuration.GetRequiredSection("GenerationApi")["ApiKey"]; }

private readonly ILogger<PromptModule> _logger;
private readonly QueueService _queueService;
private readonly IConfiguration _configuration;

public PromptModule(ILogger<PromptModule> logger, QueueService queueService)
public PromptModule(ILogger<PromptModule> logger, QueueService queueService, IConfiguration configuration)
{
_logger = logger;
_queueService = queueService;
_configuration = configuration;
}

[SlashCommand("prompt", "Generates an image based on a prompt")]
Expand Down Expand Up @@ -67,8 +72,17 @@ public async Task Prompt(
height = 1216;
break;
}
NaiRequest request = new NaiRequest(prompt, negativePrompt, height, width, Context);
await _queueService.AddPromptToQueueAsync(request);

if (_configuration.GetRequiredSection("GenerationApi")["Mode"] == "Contained")
{
NaiRequest request = new NaiRequest(prompt, negativePrompt, height, width, Context);
await _queueService.AddPromptToQueueAsync(request);
}
else
{
BackendRequest request = new BackendRequest(prompt, negativePrompt, AuthKey, Context, height, width);
await _queueService.AddPromptToQueueAsync(request);
}
}

[ComponentInteraction("delete-image")]
Expand Down
1 change: 1 addition & 0 deletions NovelAIBot/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static async Task Main(string[] args)
}));
builder.Services.AddSingleton<QueueService>();
builder.Services.AddKeyedScoped<IGenerationService, NaiService>("Contained");
builder.Services.AddKeyedScoped<IGenerationService, BackendService>("Backend");


builder.Services.AddHostedService<DiscordService>();
Expand Down
68 changes: 68 additions & 0 deletions NovelAIBot/Services/BackendService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using NovelAIBot.Extensions;
using NovelAIBot.Models;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;

namespace NovelAIBot.Services
{
internal class BackendService : IGenerationService
{
public event EventHandler<BackendQueueStatus> BackendQueueStatusChanged;

private string WebSocketBaseAddress { get { return _configuration.GetSection("GenerationApi")["BackendUrl"]; } }
private string ImageBaseAddress { get { return _configuration.GetSection("GenerationApi")["BackendImageUrl"]; } }



private readonly IConfiguration _configuration;
private readonly ILogger<BackendService> _logger;
private readonly HttpClient _httpClient;

public BackendService(IConfiguration configuration, ILogger<BackendService> logger)
{
_configuration = configuration;
_logger = logger;
_httpClient = new HttpClient();
_httpClient.BaseAddress = new Uri(ImageBaseAddress);
}


public async Task<byte[]> GetImageBytesAsync(INaiRequest request)
{
using ClientWebSocket client = new ClientWebSocket();
await client.ConnectAsync(new Uri(WebSocketBaseAddress), CancellationToken.None);
string json = JsonSerializer.Serialize(request as BackendRequest);

await client.SendTextMessageAsync(json, Encoding.UTF8);
json = await client.ReceiveTextMessageAsync(1024 * 20, Encoding.UTF8);
BackendQueueStatus status = JsonSerializer.Deserialize<BackendQueueStatus>(json);

do
{
json = await client.ReceiveTextMessageAsync(1024 * 20, Encoding.UTF8);
BackendQueueStatus newStatus = JsonSerializer.Deserialize<BackendQueueStatus>(json);
if (newStatus.QueuePosition != status.QueuePosition || newStatus.State != status.State)
BackendQueueStatusChanged?.Invoke(this, newStatus);
status = newStatus;
} while (status.State == NaiQueueState.Enqueued || status.State == NaiQueueState.Processing);

if (status.State == NaiQueueState.CompletedError)
{
await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "Request complete.", CancellationToken.None);
throw new Exception("Image generation completed in an error state");
}

await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "Request complete.", CancellationToken.None);

return await _httpClient.GetByteArrayAsync($"/api/nai/getimage/{status.Id}");
}
}
}
53 changes: 50 additions & 3 deletions NovelAIBot/Services/QueueService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public async Task AddPromptToQueueAsync(INaiRequest request)
else
{
Queue.Enqueue(request);
await request.Context.Interaction.FollowupAsync($"Prompt job queued. {Queue.Count} prompts ahead. ~10 seconds per prompt\n{request.Prompt}");
await request.Context.Interaction.FollowupAsync($"Prompt job queued. {Queue.Count} prompts ahead.\n**Prompt:**{request.Prompt}");
}
}

Expand All @@ -82,14 +82,61 @@ private async Task StartJob(INaiRequest request)
this.IsBusy = true;
if (request is NaiRequest)
await StartContainedJob(request);
// if (request is BackendRequest)
// await StartBackendJob
if (request is BackendRequest)
await StartBackendJob(request);


this.IsBusy = false;
this.JobCompleted?.Invoke(this, request);
}

private async Task StartBackendJob(INaiRequest request)
{
try
{
using IServiceScope scope = _serviceProvider.CreateScope();
IGenerationService naiService = scope.ServiceProvider.GetRequiredKeyedService<IGenerationService>("Backend");

byte[] image = await naiService.GetImageBytesAsync(request);

FileAttachment attachment;
using (MemoryStream ms = new MemoryStream(image))
{
attachment = new FileAttachment(ms, "image.png");
EmbedBuilder embedBuilder = new EmbedBuilder()
.WithTitle("Text2Image Generation")
.WithAuthor(request.Context.User)
.WithCurrentTimestamp()
.WithImageUrl("attachment://image.png")
.WithFooter("nai-diffusion-v3")
.AddField("Prompt", request.Prompt);

if (!string.IsNullOrEmpty(request.NegativePrompt))
embedBuilder.AddField("Negative Prompt", request.NegativePrompt ?? "<No negative prompt>");

embedBuilder.AddField("Size", $"{request.Width}x{request.Height}");

await request.Context.Interaction.ModifyOriginalResponseAsync(x =>
{
x.Content = "";
x.Attachments = new List<FileAttachment> { attachment };
x.Embed = embedBuilder.Build();
x.Components = GetMessageButtons();
});
}
scope.Dispose();
}
catch (Exception ex)
{
_logger.LogError(ex, "An error occurred while generating an image");
await request.Context.Interaction.ModifyOriginalResponseAsync(x =>
{
x.Content = $"An error has occurred while processing your request:\n{ex.Message}";
});
}
}


private async Task StartContainedJob(INaiRequest request)
{
try
Expand Down

0 comments on commit 1b2af2e

Please sign in to comment.