A Stable Diffusion Discord bot.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

190 lines
9.1 KiB

# Happy Trees Discord Bot
# A Discord bot to access a local Stable Diffusion model
import os
import sys
import logging
import asyncio
import time
import shlex
import re
import random
import discord
class HappyTreesBot(discord.Client):
    """Discord client that turns chat requests into Stable Diffusion images.

    Messages starting with ``!happytree`` or mentioning the bot are parsed into a
    prompt plus optional ``--seed``, ``--n_samples`` and ``--ddim_steps`` options
    (written either ``--name N`` or ``--name=N``).  Valid requests are queued on
    ``self.commissions`` and rendered one at a time by :meth:`painting`, which runs
    optimizedSD's ``optimized_txt2img.py`` inside ``self.sdpath`` and replies with
    the resulting images.
    """

    # Per-option limits: name -> (minimum, maximum, reply used when above maximum).
    _OPTION_LIMITS = {
        "seed": (0, 1000000, 'Invalid seed value. Seed must be <= 1000000.'),
        "n_samples": (1, 4, 'Invalid n_samples value. This GPU is old and cannot do more than 4 at a time.'),
        "ddim_steps": (1, 80, 'Invalid ddim_steps value. Limited to 80 for time considerations.'),
    }
    _BAD_VALUE_MSG = 'Invalid option value. Please use a positive integer.'
    _BAD_NAME_MSG = 'Invalid option name: {}. Valid options are "--ddim_steps", "--n_samples", and "--seed".'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger('discord.HappyTreesBot')
        # Pending requests: (enqueue time, message, prompt, samples, seed, steps).
        self.commissions = asyncio.Queue()
        # Worker tasks started by bobross(); the caller cancels them on shutdown.
        self.paintings = []
        self.outpath = os.path.abspath(os.path.expanduser('~/happytreesbot/outputs'))
        self.sdpath = os.path.abspath(os.path.expanduser('~/stable-diffusion'))

    async def on_ready(self):
        """Log the login and print an invite link for this application."""
        self.logger.info(f'Logged in to Discord as {self.user}!')
        print(f'Invite me to your server at https://discord.com/oauth2/authorize?client_id={self.application_id}&scope=bot&permissions=35904')

    async def on_message(self, message):
        """Route ``!happytree`` commands and mentions to take_commission()."""
        if message.author == self.user:
            return
        # A user mention is written <@id>, or <@!id> when sent via a nickname. It
        # carries the bot's *user* id, which for older applications may differ
        # from application_id.
        mentions = (f'<@{self.user.id}>', f'<@!{self.user.id}>')
        if message.content.startswith('!happytree'):
            self.logger.info(f'Command message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        elif message.content.startswith(mentions):
            self.logger.info(f'Mention message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        else:
            self.logger.debug(f'Non-command message from {message.author} ({message.author.id}): {message.content}')

    @classmethod
    def _check_option(cls, name, raw):
        """Convert *raw* to an int and check it against the limits for option *name*.

        Raises ValueError carrying the reply to send back to the requester.
        """
        minimum, maximum, too_big_msg = cls._OPTION_LIMITS[name]
        try:
            value = int(raw)
        except ValueError:
            raise ValueError(cls._BAD_VALUE_MSG) from None
        if value < minimum:
            raise ValueError(cls._BAD_VALUE_MSG)
        if value > maximum:
            raise ValueError(too_big_msg)
        return value

    @classmethod
    def _parse_commission(cls, content):
        """Split a request message into ``(prompt, samples, seed, steps)``.

        The first token (the ``!happytree`` trigger or the mention) is discarded.
        Raises ValueError carrying the reply to send back to the requester.
        """
        try:
            tokens = shlex.split(content)[1:]
        except ValueError:
            # Unbalanced quotes (e.g. the apostrophe in "don't") make shlex fail;
            # fall back to plain whitespace splitting so such prompts still work.
            tokens = content.split()[1:]
        values = {
            "seed": random.randint(0, cls._OPTION_LIMITS["seed"][1]),
            "n_samples": 4,
            "ddim_steps": 50,
        }
        words = []
        pending = None  # option named by the previous token, still awaiting its value
        for token in tokens:
            if pending is not None:
                values[pending] = cls._check_option(pending, token)
                pending = None
            elif token.startswith('--'):
                name, has_value, raw = token[2:].partition('=')
                if name not in cls._OPTION_LIMITS:
                    raise ValueError(cls._BAD_NAME_MSG.format(name))
                if has_value:
                    values[name] = cls._check_option(name, raw)
                else:
                    pending = name
            elif token:
                words.append(token)
        if pending is not None:
            raise ValueError(f'Missing value for option --{pending}.')
        prompt = " ".join(words)
        if not prompt:
            raise ValueError('Please include a prompt describing what to paint.')
        return prompt, values["n_samples"], values["seed"], values["ddim_steps"]

    async def take_commission(self, message):
        """Parse *message*, then queue the request or reply explaining what was wrong."""
        t = time.perf_counter()
        try:
            prompt, samples, seed, steps = self._parse_commission(message.content)
        except ValueError as err:
            await message.reply(str(err))
            return
        self.logger.info(f'Queueing request for {message.author}. Prompt: "{prompt}"; Samples: {samples}; Seed: {seed}; Steps: {steps}.')
        self.commissions.put_nowait((t, message, prompt, samples, seed, steps))
        position = self.commissions.qsize()
        waittime = position * 6
        await message.reply(f'{message.author} you are number {position} in the Happy Trees processing queue. Estimated wait time is {waittime} minutes.')

    async def _paint(self, t, message, prompt, samples, seed, steps):
        """Render one commission with optimized_txt2img.py and reply with the images."""
        self.logger.info(f'Bob Ross is painting "{prompt}" for {message.author}.')
        # Mirror the output-directory naming from optimizedSD's optimized_txt2img.py
        # so we know where the script saves its images.
        # NOTE(review): the prompt is untrusted; one containing "/" or ".." makes this
        # path (and the script's output) land outside self.outpath.
        os.makedirs(self.outpath, exist_ok=True)
        sample_path = os.path.join(self.outpath, "_".join(re.split(":| ", prompt)))[:150]
        os.makedirs(sample_path, exist_ok=True)
        # Snapshot the directory so the newly written images can be identified
        # afterwards, without re-deriving the script's per-sample file names.
        before = set(os.listdir(sample_path))
        self.logger.debug(f'Output directory will be: {sample_path}')
        start = time.perf_counter()
        self.logger.info(f'About to start the subprocess...')
        proc = await asyncio.create_subprocess_exec(
            '/usr/bin/python3', 'optimizedSD/optimized_txt2img.py',
            '--H', '448', '--W', '448', '--precision', 'full',
            '--outdir', self.outpath, '--n_iter', '1',
            '--n_samples', str(samples), '--ddim_steps', str(steps),
            '--seed', str(seed), '--prompt', str(prompt),
            stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE, cwd=self.sdpath)
        self.logger.info(f'Started SD subprocess, PID: {proc.pid}')
        try:
            procout, _ = await proc.communicate()
        except asyncio.CancelledError:
            # Shutting down: don't leave an orphaned GPU job running.
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            raise
        procoutput = procout.decode(errors='replace').strip()
        self.logger.info(f'Process output: {procoutput}')
        if proc.returncode != 0:
            self.logger.warning(f'SD subprocess exited with code {proc.returncode}.')
        complete = time.perf_counter()
        self.logger.info(f'Painting of "{prompt}" complete for {message.author}. Wait time: {start-t:0.5f}; Paint time: {complete-start:0.5f}; Total time: {complete-t:0.5f}.')
        outfiles = sorted(
            os.path.join(sample_path, name)
            for name in set(os.listdir(sample_path)) - before
            if name.endswith('.png')
        )
        if not outfiles:
            self.logger.error(f'No output images found in {sample_path} for "{prompt}".')
            await message.reply(f'An error occurred. Output file not found.')
            return
        description = f'Prompt: "{prompt}"; Starting Seed: {seed}; Steps: {steps}'
        await message.reply(files=[discord.File(path, description=description) for path in outfiles])
        self.logger.debug(f'Reply sent')

    async def painting(self):
        """Worker loop: take commissions off the queue and paint them one at a time."""
        while True:
            self.logger.debug(f'Bob Ross is ready to paint!')
            commission = await self.commissions.get()
            try:
                await self._paint(*commission)
            finally:
                # Mark the item done even if painting failed or was cancelled.
                self.commissions.task_done()

    async def bobross(self):
        """Supervise the painting worker, starting a fresh one whenever it stops.

        painting() only returns by raising; the exception is logged and a new
        worker is created.
        """
        while True:
            self.logger.debug(f'Bob Ross is preparing a new canvas.')
            self.paintings.append(asyncio.create_task(self.painting()))
            self.logger.debug(f'Bob Ross is waiting on a commission.')
            complaints = await asyncio.gather(*self.paintings, return_exceptions=True)
            if complaints:
                self.logger.warning(f'Bob had some complaints: {complaints}')
            self.logger.info(f'All paintings complete.')
            self.paintings = []
async def main():
    """Create the bot, run it alongside the painting supervisor, and clean up on exit.

    The bot token is read from the HAPPYTREES_TOKEN environment variable. If that is
    unset or empty, it is read from a ``.token`` file in the current working directory.
    """
    # Root logger; the same object the script configures before calling main().
    log = logging.getLogger()
    intents = discord.Intents.default()
    # Required to read the text of !happytree commands.
    intents.message_content = True
    bot = HappyTreesBot(intents=intents)
    try:
        token = os.getenv("HAPPYTREES_TOKEN")
        if not token:
            log.debug(f'Getting token from file.')
            with open(".token", encoding="utf-8") as f:
                # strip() also drops a trailing "\r" or stray spaces, not just "\n".
                token = f.read().strip()
        # `async with` makes sure bot.close() runs, so the gateway connection and
        # HTTP session are shut down even when start() is cancelled or fails.
        async with bot:
            await asyncio.gather(bot.start(token), bot.bobross())
    finally:
        if bot.paintings:
            log.info(f'Cancelling in-progress paintings.')
        for painting in bot.paintings:
            painting.cancel()
        await asyncio.gather(*bot.paintings, return_exceptions=True)
        log.info(f'All paintings have cancelled.')
if __name__ == "__main__":
    # Only configure logging and start the bot when run as a script. Importing the
    # module must not truncate happytrees.log or connect to Discord.
    # Send every record (discord.py's included) through the root logger to a log
    # file that is truncated on each start.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(filename='happytrees.log',
                                  encoding='utf-8', mode='w')
    formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}',
                                  '%Y-%m-%d %H:%M:%S', style='{')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info(f'Keyboard interrupt received, Happy Trees Bot shutting down.')
        sys.exit(0)