You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
9.1 KiB
190 lines
9.1 KiB
# Happy Trees Discord Bot
|
|
# A Discord bot to access a local Stable Diffusion model
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import asyncio
|
|
import time
|
|
import shlex
|
|
import re
|
|
import random
|
|
import discord
|
|
|
|
class HappyTreesBot(discord.Client):
    """Discord client that queues Stable Diffusion requests and replies with images.

    Messages starting with ``!happytree`` or with a mention of the bot become
    "commissions": a prompt plus optional ``--seed``, ``--n_samples`` and
    ``--ddim_steps`` values (written ``--name value`` or ``--name=value``).
    Commissions are queued on ``self.commissions``; ``bobross()`` keeps a
    ``painting()`` worker alive that renders them one at a time by running
    optimizedSD's ``optimized_txt2img.py`` in a subprocess.
    """

    class _CommissionError(Exception):
        """A request could not be parsed; ``str(err)`` is the reply for the user."""

    # option name -> (minimum, maximum, reply if too low, reply if too high)
    _OPTION_LIMITS = {
        "seed": (0, 1000000,
                 'Invalid seed value. Seed must be >= 0.',
                 'Invalid seed value. Seed must be <= 1000000.'),
        "n_samples": (1, 4,
                      'Invalid n_samples value. Must be at least 1.',
                      'Invalid n_samples value. This GPU is old and cannot do more than 4 at a time.'),
        "ddim_steps": (1, 80,
                       'Invalid ddim_steps value. Must be at least 1.',
                       'Invalid ddim_steps value. Limited to 80 for time considerations.'),
    }
    _BAD_OPTION_NAME = 'Invalid option name: {}. Valid options are "--ddim_steps", "--n_samples", and "--seed".'
    _BAD_OPTION_VALUE = 'Invalid option value. Please use a positive integer.'

    def __init__(self, *args, **kwargs):
        """Initialise the client; all arguments are forwarded to discord.Client."""
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger('discord.HappyTreesBot')
        # Pending requests as (queued_at, message, prompt, samples, seed, steps).
        self.commissions = asyncio.Queue()
        # Worker tasks started by bobross(); main() cancels them on shutdown.
        self.paintings = []
        # Passed to optimizedSD as --outdir; output files are predicted under it.
        self.outpath = os.path.abspath(os.path.expanduser('~/happytreesbot/outputs'))
        # Working directory for the Stable Diffusion subprocess.
        self.sdpath = os.path.abspath(os.path.expanduser('~/stable-diffusion'))

    async def on_ready(self):
        """Log the successful login and print an OAuth2 invite URL for this application."""
        self.logger.info(f'Logged in to Discord as {self.user}!')
        # NOTE(review): 35904 appears to be View Channel + Send Messages +
        # Attach Files + Add Reactions — confirm against Discord's permission bits.
        print(f'Invite me to your server at https://discord.com/oauth2/authorize?client_id={self.application_id}&scope=bot&permissions=35904')

    async def on_message(self, message):
        """Route ``!happytree`` commands and messages that start with a bot mention."""
        if message.author == self.user:
            return  # never react to our own replies
        # A mention is "<@user_id>" (or the legacy "<@!user_id>") and carries the
        # bot's *user* ID, which may differ from the application ID for older apps.
        mentions = (f'<@{self.user.id}>', f'<@!{self.user.id}>')
        if message.content.startswith('!happytree'):
            self.logger.info(f'Command message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        elif message.content.startswith(mentions):
            self.logger.info(f'Mention message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        else:
            self.logger.debug(f'Non-command message from {message.author} ({message.author.id}): {message.content}')

    async def take_commission(self, message):
        """Parse a request message and, if it is valid, queue it for painting.

        The first token (the ``!happytree`` command or the bot mention) is
        dropped; the rest is split into prompt words and options.  Invalid
        requests get an explanatory reply and are not queued.  Valid ones get
        a reply with their queue position and a rough wait estimate.
        """
        t = time.perf_counter()
        try:
            tokens = shlex.split(message.content)
        except ValueError:
            # shlex raises on unbalanced quotes, which an ordinary apostrophe
            # ("Bob Ross's trees") produces; fall back to whitespace splitting.
            self.logger.info(f'Could not shlex-split message from {message.author}; splitting on whitespace instead.')
            tokens = message.content.split()

        try:
            prompt, options = self._parse_tokens(tokens[1:])
        except self._CommissionError as err:
            await message.reply(str(err))
            return
        if not prompt:
            await message.reply('Please include a prompt describing what to paint.')
            return
        if self._escapes_outpath(prompt):
            await message.reply('Invalid prompt. It would be saved outside the output folder (avoid a leading "/" and ".." path segments).')
            return

        seed = options["seed"] if "seed" in options else random.randint(0, 1000000)
        samples = options.get("n_samples", 4)
        steps = options.get("ddim_steps", 50)

        self.logger.info(f'Queueing request for {message.author}. Prompt: "{prompt}"; Samples: {samples}; Seed: {seed}; Steps: {steps}.')
        self.commissions.put_nowait((t, message, prompt, samples, seed, steps))
        position = self.commissions.qsize()
        waittime = position * 6  # rough estimate: ~6 minutes per queued painting
        await message.reply(f'{message.author} you are number {position} in the Happy Trees processing queue. Estimated wait time is {waittime} minutes.')

    @classmethod
    def _parse_tokens(cls, tokens):
        """Split request tokens into ``(prompt, options)``.

        ``--name value`` requires the value as the very next token;
        ``--name=value`` carries it inline.  Every other non-empty token is a
        prompt word, and the prompt words are joined with single spaces.
        *options* maps each supplied option name to its validated int.  Raises
        ``_CommissionError`` (with the reply text) on a bad name or value.
        """
        words = []
        options = {}
        pending = None  # option name still waiting for its value
        for token in tokens:
            if pending is not None:
                options[pending] = cls._check_option(pending, token)
                pending = None
            elif token.startswith('--'):
                name, has_value, raw = token[2:].partition('=')
                if has_value:
                    options[name] = cls._check_option(name, raw)
                elif name in cls._OPTION_LIMITS:
                    pending = name
                else:
                    raise cls._CommissionError(cls._BAD_OPTION_NAME.format(name))
            elif token:
                words.append(token)
        if pending is not None:
            # A trailing "--seed" (etc.) with nothing after it.
            raise cls._CommissionError(cls._BAD_OPTION_VALUE)
        return " ".join(words), options

    @classmethod
    def _check_option(cls, name, raw):
        """Return *raw* as an int if it is in range for option *name*.

        Raises ``_CommissionError`` for an unknown name, a non-integer value,
        or a value outside the option's [minimum, maximum] range.
        """
        if name not in cls._OPTION_LIMITS:
            raise cls._CommissionError(cls._BAD_OPTION_NAME.format(name))
        try:
            value = int(raw)
        except ValueError:
            raise cls._CommissionError(cls._BAD_OPTION_VALUE) from None
        low, high, too_low, too_high = cls._OPTION_LIMITS[name]
        if value < low:
            raise cls._CommissionError(too_low)
        if value > high:
            raise cls._CommissionError(too_high)
        return value

    def _escapes_outpath(self, prompt):
        """Return True if the directory derived from *prompt* lies outside ``self.outpath``.

        Uses the same path expression as ``_paint`` (which mirrors optimizedSD),
        so a prompt such as "/etc" or "../x" is caught before it reaches the
        subprocess.  The check is lexical (``normpath``) and does not follow symlinks.
        """
        sample_path = os.path.join(self.outpath, "_".join(re.split(":| ", prompt)))[:150]
        resolved = os.path.normpath(sample_path)
        return os.path.commonpath([self.outpath, resolved]) != self.outpath

    async def painting(self):
        """Worker: take commissions off the queue forever and paint each one.

        An exception escaping ``_paint`` ends this worker; ``bobross()`` then
        logs it and starts a replacement.
        """
        while True:
            self.logger.debug('Bob Ross is ready to paint!')
            commission = await self.commissions.get()
            try:
                await self._paint(*commission)
            finally:
                # Balance the get() above even when painting fails.
                self.commissions.task_done()

    async def _paint(self, t, message, prompt, samples, seed, steps):
        """Render one commission with optimizedSD and reply with the image.

        *t* is the ``perf_counter()`` time the request was queued.  *samples*
        is only logged: the subprocess is always run with ``--n_samples 1``
        and a single output file is sent back.
        """
        self.logger.info(f'Bob Ross is painting "{prompt}" for {message.author}.')

        # This section copies the filename generation code from
        # optimized_txt2img.py from optimizedSD, so the output path can be
        # predicted.  Note the [:150] slice applies to the whole joined path;
        # keep it in sync with the script rather than "fixing" it here.
        os.makedirs(self.outpath, exist_ok=True)
        sample_path = os.path.join(self.outpath, "_".join(re.split(":| ", prompt)))[:150]
        os.makedirs(sample_path, exist_ok=True)
        base_count = len(os.listdir(sample_path))
        outfile = os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.png")
        self.logger.debug(f'Output file will be: {outfile}')

        start = time.perf_counter()
        self.logger.info('About to start the subprocess...')
        proc = await asyncio.create_subprocess_exec('/usr/bin/python3', 'optimizedSD/optimized_txt2img.py', '--H', '448', '--W', '448', '--precision', 'full', '--outdir', self.outpath, '--n_iter', '1', '--n_samples', '1', '--ddim_steps', str(steps), '--seed', str(seed), '--prompt', str(prompt), stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE, cwd=self.sdpath)
        self.logger.info(f'Started SD subprocess, PID: {proc.pid}')
        try:
            procout, _ = await proc.communicate()
        except asyncio.CancelledError:
            # Don't leave the render running once nobody is waiting for it.
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # it already exited
            raise
        # The child's output is not guaranteed to be valid UTF-8.
        procoutput = procout.decode(errors='replace').strip()
        self.logger.info(f'Process output: {procoutput}')
        if proc.returncode != 0:
            self.logger.warning(f'SD subprocess {proc.pid} exited with code {proc.returncode}.')

        complete = time.perf_counter()
        self.logger.info(f'Painting of "{prompt}" complete for {message.author}. Wait time: {start-t:0.5f}; Paint time: {complete-start:0.5f}; Total time: {complete-t:0.5f}.')

        try:
            attachment = discord.File(outfile, description=f'Prompt: "{prompt}"; Starting Seed: {seed}; Steps: {steps}')
        except FileNotFoundError:
            self.logger.error(f'Output file not found: {outfile}')
            await message.reply('An error occurred. Output file not found.')
            return
        await message.reply(file=attachment)
        self.logger.debug('Reply sent')

    async def bobross(self):
        """Supervisor: keep one painting worker running, restarting it whenever it exits."""
        while True:
            self.logger.debug('Bob Ross is preparing a new canvas.')
            self.paintings.append(asyncio.create_task(self.painting()))
            self.logger.debug('Bob Ross is waiting on a commission.')
            results = await asyncio.gather(*self.paintings, return_exceptions=True)
            # gather() returns one entry per task; only exceptions are complaints.
            complaints = [r for r in results if isinstance(r, BaseException)]
            for complaint in complaints:
                self.logger.warning(f'Bob had a complaint: {complaint!r}', exc_info=complaint)
            self.logger.info('All paintings complete.')
            self.paintings = []
|
|
|
|
async def main():
    """Create the bot, load its token and run it alongside the painting supervisor.

    The token comes from the ``HAPPYTREES_TOKEN`` environment variable, falling
    back to a ``.token`` file in the current working directory.  On exit
    (normal, error or cancellation), in-progress paintings are cancelled and
    awaited, and the Discord client is closed.
    """
    logger = logging.getLogger()
    intents = discord.Intents.default()
    # Required to read the text of "!happytree ..." commands.
    intents.message_content = True
    bot = HappyTreesBot(intents=intents)
    try:
        token = os.getenv("HAPPYTREES_TOKEN")
        if not token:
            logger.debug('Getting token from file.')
            with open(".token", encoding="utf-8") as f:
                # strip() also drops a "\r" from CRLF files and stray spaces,
                # either of which would corrupt the token.
                token = f.read().strip()
        await asyncio.gather(bot.start(token), bot.bobross())
    finally:
        if bot.paintings:
            logger.info('Cancelling in-progress paintings.')
        for painting in bot.paintings:
            painting.cancel()
        await asyncio.gather(*bot.paintings, return_exceptions=True)
        logger.info('All paintings have cancelled.')
        # Release the gateway connection and HTTP session.
        if not bot.is_closed():
            await bot.close()
|
|
|
|
# Root logger.  main() logs through this module-level name, and the bot's
# 'discord.HappyTreesBot' logger propagates to it.
logger = logging.getLogger()

if __name__ == "__main__":
    # Configure logging and run the bot only when executed as a script, so
    # importing this module has no side effects.
    logger.setLevel(logging.INFO)
    # mode='w' starts a fresh log file on every run.
    handler = logging.FileHandler(filename='happytrees.log',
                                  encoding='utf-8', mode='w')
    formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}',
                                  '%Y-%m-%d %H:%M:%S', style='{')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info('Keyboard interrupt received, Happy Trees Bot shutting down.')
        sys.exit(0)
|
|
|