From f1d2dd5004633b5758bf44f7de48f6553fd9862e Mon Sep 17 00:00:00 2001 From: Alan Edwardes Date: Tue, 21 May 2024 21:35:15 +0100 Subject: [PATCH] Adds DALLE2 back --- .../Responders/DalleResponder.cs | 52 ++++++++++++++----- .../Responders/FeatureFlagResponder.cs | 15 ++++++ src/Runner.Discord/Responders/GptResponder.cs | 2 +- .../Responders/IFeatureFlags.cs | 3 ++ 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/src/Runner.Discord/Responders/DalleResponder.cs b/src/Runner.Discord/Responders/DalleResponder.cs index 5236b7d..831d9e4 100644 --- a/src/Runner.Discord/Responders/DalleResponder.cs +++ b/src/Runner.Discord/Responders/DalleResponder.cs @@ -2,6 +2,7 @@ using Estranged.Automation.Runner.Discord.Events; using OpenAI_API; using OpenAI_API.Images; +using OpenAI_API.Models; using System; using System.Linq; using System.Net.Http; @@ -30,36 +31,59 @@ public async Task ProcessMessage(IMessage message, CancellationToken token) return; } - const string trigger = "dalle"; - if (!message.Content.StartsWith(trigger, StringComparison.InvariantCultureIgnoreCase)) - { - return; - } - if (_featureFlags.ShouldResetDalleAttempts()) { // Refresh the bucket since time moved on _featureFlags.ResetDalleAttempts(); } - int dalleLimit = 10; + if (_featureFlags.ShouldResetDalleHqAttempts()) + { + // Refresh the bucket since time moved on + _featureFlags.ResetDalleHqAttempts(); + } - if (_featureFlags.DalleAttempts.Count >= dalleLimit) + const string dalle2Trigger = "dalle "; + if (message.Content.StartsWith(dalle2Trigger, StringComparison.InvariantCultureIgnoreCase)) { - await message.Channel.SendMessageAsync("wait until the next day", options: token.ToRequestOptions()); + const int dalle2Limit = 10; + if (_featureFlags.DalleAttempts.Count >= dalle2Limit) + { + await message.Channel.SendMessageAsync("wait until the next day", options: token.ToRequestOptions()); + return; + } + + _featureFlags.DalleAttempts.Count++; + await RequestImage(message, _featureFlags.DalleAttempts.Count, dalle2Limit, dalle2Trigger.Length, Model.DALLE2, ImageSize._256, token); return; } - var prompt = message.Content[trigger.Length..].Trim(); + const string dalle3Trigger = "dalle3 "; + if (message.Content.StartsWith(dalle3Trigger, StringComparison.InvariantCultureIgnoreCase)) + { + const int dalle3Limit = 1; + if (_featureFlags.DalleHqAttempts.Count >= dalle3Limit) + { + await message.Channel.SendMessageAsync("wait until the next day", options: token.ToRequestOptions()); + return; + } + + _featureFlags.DalleHqAttempts.Count++; + await RequestImage(message, _featureFlags.DalleHqAttempts.Count, dalle3Limit, dalle3Trigger.Length, Model.DALLE3, ImageSize._1024, token); + return; + } + } - _featureFlags.DalleAttempts.Count++; + private async Task RequestImage(IMessage message, int dalleAttempts, int dalleLimit, int initialMessagePrefixLength, Model model, ImageSize size, CancellationToken token) + { + var prompt = message.Content[initialMessagePrefixLength..].Trim(); using (message.Channel.EnterTypingState()) { var response = await _openAi.ImageGenerations.CreateImageAsync(new ImageGenerationRequest { - Model = "dall-e-3", - Size = ImageSize._1024, + Size = size, + Model = model, NumOfImages = 1, ResponseFormat = ImageResponseFormat.Url, Prompt = prompt @@ -70,7 +94,7 @@ public async Task ProcessMessage(IMessage message, CancellationToken token) using var httpClient = _httpClientFactory.CreateClient(DiscordHttpClientConstants.RESPONDER_CLIENT); using var image = await httpClient.GetStreamAsync(result.Url); - await message.Channel.SendFileAsync(image, $"{Guid.NewGuid()}.png", $"{_featureFlags.DalleAttempts.Count}/{dalleLimit}", messageReference: new MessageReference(message.Id), options: token.ToRequestOptions()); + await message.Channel.SendFileAsync(image, $"{Guid.NewGuid()}.png", $"{dalleAttempts}/{dalleLimit}", messageReference: new MessageReference(message.Id), options: token.ToRequestOptions()); } } } diff --git a/src/Runner.Discord/Responders/FeatureFlagResponder.cs b/src/Runner.Discord/Responders/FeatureFlagResponder.cs index 979d537..dd97f78 100644 --- a/src/Runner.Discord/Responders/FeatureFlagResponder.cs +++ b/src/Runner.Discord/Responders/FeatureFlagResponder.cs @@ -18,6 +18,7 @@ public sealed class AttemptsBucket public FeatureFlagResponder() { + ResetDalleHqAttempts(); ResetDalleAttempts(); ResetGptAttempts(); } @@ -46,8 +47,21 @@ private DateTime CurrentGptBucket public void ResetGptAttempts() => GptAttempts = new AttemptsBucket(CurrentGptBucket); public AttemptsBucket GptAttempts { get; private set; } + private DateTime CurrentDalleHqBucket + { + get + { + var now = DateTime.UtcNow; + return new DateTime(now.Year, now.Month, now.Day, 0, 0, 0, DateTimeKind.Utc); + } + } + public bool ShouldResetDalleHqAttempts() => DalleHqAttempts.Bucket != CurrentDalleHqBucket; + public void ResetDalleHqAttempts() => DalleHqAttempts = new AttemptsBucket(CurrentDalleHqBucket); + public AttemptsBucket DalleHqAttempts { get; private set; } + public bool IsAiEnabled { get; private set; } = true; + public Task ProcessMessage(IMessage message, CancellationToken token) { if (message.Author.Id != 269883106792701952) @@ -65,6 +79,7 @@ public Task ProcessMessage(IMessage message, CancellationToken token) if (message.Content == "ff dalle reset") { ResetDalleAttempts(); + ResetDalleHqAttempts(); message.Channel.SendMessageAsync($"Reset dalle attempts"); return Task.CompletedTask; } diff --git a/src/Runner.Discord/Responders/GptResponder.cs b/src/Runner.Discord/Responders/GptResponder.cs index 9ebc3a7..0d4de7d 100644 --- a/src/Runner.Discord/Responders/GptResponder.cs +++ b/src/Runner.Discord/Responders/GptResponder.cs @@ -75,7 +75,7 @@ public async Task ProcessMessage(IMessage originalMessage, CancellationToken tok if (_featureFlags.GptAttempts.Count >= 100) { - await initialMessage.Channel.SendMessageAsync("wait until the next hour", options: token.ToRequestOptions()); + // Ensure only 100 attempts per hour return; } diff --git a/src/Runner.Discord/Responders/IFeatureFlags.cs b/src/Runner.Discord/Responders/IFeatureFlags.cs index 29acbf3..6c255c8 100644 --- a/src/Runner.Discord/Responders/IFeatureFlags.cs +++ b/src/Runner.Discord/Responders/IFeatureFlags.cs @@ -2,12 +2,15 @@ { internal interface IFeatureFlags { + FeatureFlagResponder.AttemptsBucket DalleHqAttempts { get; } FeatureFlagResponder.AttemptsBucket DalleAttempts { get; } FeatureFlagResponder.AttemptsBucket GptAttempts { get; } bool IsAiEnabled { get; } + void ResetDalleHqAttempts(); void ResetDalleAttempts(); void ResetGptAttempts(); + bool ShouldResetDalleHqAttempts(); bool ShouldResetDalleAttempts(); bool ShouldResetGptAttempts(); }