Skip to content

Commit

Permalink
Adds DALLE2 back
Browse files Browse the repository at this point in the history
  • Loading branch information
alanedwardes committed May 21, 2024
1 parent f400e5d commit f1d2dd5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 15 deletions.
52 changes: 38 additions & 14 deletions src/Runner.Discord/Responders/DalleResponder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Estranged.Automation.Runner.Discord.Events;
using OpenAI_API;
using OpenAI_API.Images;
using OpenAI_API.Models;
using System;
using System.Linq;
using System.Net.Http;
Expand Down Expand Up @@ -30,36 +31,59 @@ public async Task ProcessMessage(IMessage message, CancellationToken token)
return;
}

const string trigger = "dalle";
if (!message.Content.StartsWith(trigger, StringComparison.InvariantCultureIgnoreCase))
{
return;
}

if (_featureFlags.ShouldResetDalleAttempts())
{
// Refresh the bucket since time moved on
_featureFlags.ResetDalleAttempts();
}

int dalleLimit = 10;
if (_featureFlags.ShouldResetDalleHqAttempts())
{
// Refresh the bucket since time moved on
_featureFlags.ResetDalleHqAttempts();
}

if (_featureFlags.DalleAttempts.Count >= dalleLimit)
const string dalle2Trigger = "dalle ";
if (message.Content.StartsWith(dalle2Trigger, StringComparison.InvariantCultureIgnoreCase))
{
await message.Channel.SendMessageAsync("wait until the next day", options: token.ToRequestOptions());
const int dalle2Limit = 10;
if (_featureFlags.DalleAttempts.Count >= dalle2Limit)
{
await message.Channel.SendMessageAsync("wait until the next day", options: token.ToRequestOptions());
return;
}

_featureFlags.DalleAttempts.Count++;
await RequestImage(message, _featureFlags.DalleAttempts.Count, dalle2Limit, dalle2Trigger.Length, Model.DALLE2, ImageSize._256, token);
return;
}

var prompt = message.Content[trigger.Length..].Trim();
const string dalle3Trigger = "dalle3 ";
if (message.Content.StartsWith(dalle3Trigger, StringComparison.InvariantCultureIgnoreCase))
{
const int dalle3Limit = 1;
if (_featureFlags.DalleHqAttempts.Count >= dalle3Limit)
{
await message.Channel.SendMessageAsync("wait until the next day", options: token.ToRequestOptions());
return;
}

_featureFlags.DalleHqAttempts.Count++;
await RequestImage(message, _featureFlags.DalleHqAttempts.Count, dalle3Limit, dalle3Trigger.Length, Model.DALLE3, ImageSize._1024, token);
return;
}
}

_featureFlags.DalleAttempts.Count++;
private async Task RequestImage(IMessage message, int dalleAttempts, int dalleLimit, int initialMessagePrefixLength, Model model, ImageSize size, CancellationToken token)
{
var prompt = message.Content[initialMessagePrefixLength..].Trim();

using (message.Channel.EnterTypingState())
{
var response = await _openAi.ImageGenerations.CreateImageAsync(new ImageGenerationRequest
{
Model = "dall-e-3",
Size = ImageSize._1024,
Size = size,
Model = model,
NumOfImages = 1,
ResponseFormat = ImageResponseFormat.Url,
Prompt = prompt
Expand All @@ -70,7 +94,7 @@ public async Task ProcessMessage(IMessage message, CancellationToken token)
using var httpClient = _httpClientFactory.CreateClient(DiscordHttpClientConstants.RESPONDER_CLIENT);
using var image = await httpClient.GetStreamAsync(result.Url);

await message.Channel.SendFileAsync(image, $"{Guid.NewGuid()}.png", $"{_featureFlags.DalleAttempts.Count}/{dalleLimit}", messageReference: new MessageReference(message.Id), options: token.ToRequestOptions());
await message.Channel.SendFileAsync(image, $"{Guid.NewGuid()}.png", $"{dalleAttempts}/{dalleLimit}", messageReference: new MessageReference(message.Id), options: token.ToRequestOptions());
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/Runner.Discord/Responders/FeatureFlagResponder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public sealed class AttemptsBucket

public FeatureFlagResponder()
{
ResetDalleHqAttempts();
ResetDalleAttempts();
ResetGptAttempts();
}
Expand Down Expand Up @@ -46,8 +47,21 @@ private DateTime CurrentGptBucket
public void ResetGptAttempts() => GptAttempts = new AttemptsBucket(CurrentGptBucket);
public AttemptsBucket GptAttempts { get; private set; }

private DateTime CurrentDalleHqBucket
{
get
{
var now = DateTime.UtcNow;
return new DateTime(now.Year, now.Month, now.Day, 0, 0, 0, DateTimeKind.Utc);
}
}
public bool ShouldResetDalleHqAttempts() => DalleHqAttempts.Bucket != CurrentDalleHqBucket;
public void ResetDalleHqAttempts() => DalleHqAttempts = new AttemptsBucket(CurrentDalleHqBucket);
public AttemptsBucket DalleHqAttempts { get; private set; }

public bool IsAiEnabled { get; private set; } = true;


public Task ProcessMessage(IMessage message, CancellationToken token)
{
if (message.Author.Id != 269883106792701952)
Expand All @@ -65,6 +79,7 @@ public Task ProcessMessage(IMessage message, CancellationToken token)
if (message.Content == "ff dalle reset")
{
ResetDalleAttempts();
ResetDalleHqAttempts();
message.Channel.SendMessageAsync($"Reset dalle attempts");
return Task.CompletedTask;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Runner.Discord/Responders/GptResponder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public async Task ProcessMessage(IMessage originalMessage, CancellationToken tok

if (_featureFlags.GptAttempts.Count >= 100)
{
await initialMessage.Channel.SendMessageAsync("wait until the next hour", options: token.ToRequestOptions());
// Ensure only 100 attempts per hour
return;
}

Expand Down
3 changes: 3 additions & 0 deletions src/Runner.Discord/Responders/IFeatureFlags.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
{
internal interface IFeatureFlags
{
FeatureFlagResponder.AttemptsBucket DalleHqAttempts { get; }
FeatureFlagResponder.AttemptsBucket DalleAttempts { get; }
FeatureFlagResponder.AttemptsBucket GptAttempts { get; }
bool IsAiEnabled { get; }

void ResetDalleHqAttempts();
void ResetDalleAttempts();
void ResetGptAttempts();
bool ShouldResetDalleHqAttempts();
bool ShouldResetDalleAttempts();
bool ShouldResetGptAttempts();
}
Expand Down

0 comments on commit f1d2dd5

Please sign in to comment.