"""Happy Trees Discord Bot.

A Discord bot to access a local Stable Diffusion model. Users commission a
painting with ``!happytree <prompt> [--seed N] [--n_samples N]
[--ddim_steps N]`` (or by starting a message with a mention of the bot).
Requests are queued and rendered one at a time by running optimizedSD's
``optimized_txt2img.py`` in a subprocess; the resulting PNG is posted as a
reply to the requesting message.
"""

import os
import sys
import logging
import asyncio
import time
import shlex
import re
import random

import discord


# Largest seed a user may request; also the upper bound of the random default.
MAX_SEED = 1000000

# Accepted options: name -> (minimum, maximum, reply sent when above maximum).
OPTION_LIMITS = {
    "seed": (0, MAX_SEED, 'Invalid seed value. Seed must be <= 1000000.'),
    "n_samples": (1, 4, 'Invalid n_samples value. This GPU is old and cannot do more than 4 at a time.'),
    "ddim_steps": (1, 80, 'Invalid ddim_steps value. Limited to 80 for time considerations.'),
}

# Reply for an option value that is not an integer or is below its minimum.
INVALID_VALUE_REPLY = 'Invalid option value. Please use a positive integer.'


class CommissionError(ValueError):
    """A problem with a user's request; ``str(err)`` is the reply to send back."""


def _invalid_name_reply(name):
    """Return the reply for an unrecognised ``--name`` option."""
    return (f'Invalid option name: {name}. Valid options are "--ddim_steps", '
            f'"--n_samples", and "--seed".')


def _option_value(name, raw):
    """Convert *raw* to an int for option *name*, enforcing OPTION_LIMITS.

    Raises:
        CommissionError: if *raw* is not an integer or is out of range.
    """
    try:
        value = int(raw)
    except ValueError:
        raise CommissionError(INVALID_VALUE_REPLY) from None
    minimum, maximum, too_large_reply = OPTION_LIMITS[name]
    if value > maximum:
        raise CommissionError(too_large_reply)
    if value < minimum:
        # Negative values (and 0 for the count options) would otherwise be
        # passed straight through to the Stable Diffusion script.
        raise CommissionError(INVALID_VALUE_REPLY)
    return value


def parse_commission(content):
    """Parse a command message into ``(prompt, samples, seed, steps)``.

    The first token (the ``!happytree`` command or the bot mention) is
    discarded. The rest is split with shell-style quoting. Tokens starting
    with ``--`` are options, given either as ``--name=value`` or as
    ``--name value``; every other non-empty token is joined with single
    spaces to form the prompt. Unset options default to a random seed in
    ``[0, MAX_SEED]``, 4 samples and 50 steps.

    Raises:
        CommissionError: with a user-facing message when the request is
            malformed (unbalanced quotes, unknown option, missing or
            out-of-range value, empty or path-like prompt).
    """
    try:
        tokens = shlex.split(content)[1:]
    except ValueError as err:
        # e.g. the apostrophe in "Bob's cabin" is an unclosed shell quote.
        raise CommissionError(
            f'Could not parse your request ({err}). If your prompt contains '
            f'an apostrophe, wrap the whole prompt in double quotes.') from err

    values = {"seed": random.randint(0, MAX_SEED), "n_samples": 4, "ddim_steps": 50}
    words = []
    pending = ""  # option named by a bare "--name" token, awaiting its value
    for token in tokens:
        if token.startswith('--'):
            if pending:
                raise CommissionError(f'Missing value for option --{pending}.')
            name, has_value, raw = token[2:].partition('=')
            if name not in OPTION_LIMITS:
                raise CommissionError(_invalid_name_reply(name))
            if has_value:
                values[name] = _option_value(name, raw)
            else:
                pending = name
        elif pending:
            values[pending] = _option_value(pending, token)
            pending = ""
        elif token:
            words.append(token)
    if pending:
        raise CommissionError(f'Missing value for option --{pending}.')

    prompt = " ".join(words)
    if not prompt:
        raise CommissionError('Please include a prompt describing what to paint.')
    # The prompt becomes a directory name under the output path (both here
    # and in the Stable Diffusion script), so refuse anything that could make
    # os.path.join escape that directory.
    if '/' in prompt or '\\' in prompt or prompt in ('.', '..'):
        raise CommissionError('Prompts may not contain "/" or "\\", or be "." or "..".')
    return prompt, values["n_samples"], values["seed"], values["ddim_steps"]


class HappyTreesBot(discord.Client):
    """Discord client that queues painting commissions and renders them serially."""

    # Rough minutes per queued painting, used only for the wait-time estimate.
    MINUTES_PER_PAINTING = 6

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger('discord.HappyTreesBot')
        # Pending (t, message, prompt, samples, seed, steps) tuples.
        self.commissions = asyncio.Queue()
        # Worker tasks started by bobross(); cancelled by main() on shutdown.
        self.paintings = []
        self.outpath = os.path.abspath(os.path.expanduser('~/happytreesbot/outputs'))
        self.sdpath = os.path.abspath(os.path.expanduser('~/stable-diffusion'))
        # Interpreter used to run the Stable Diffusion script.
        self.python = '/usr/bin/python3'

    async def on_ready(self):
        """Log the login and print an invite URL for this application."""
        self.logger.info(f'Logged in to Discord as {self.user}!')
        print(f'Invite me to your server at https://discord.com/oauth2/authorize?client_id={self.application_id}&scope=bot&permissions=35904')

    async def on_message(self, message):
        """Dispatch ``!happytree`` commands and leading bot mentions."""
        if message.author == self.user:
            return
        # Accept both the plain and the legacy "nickname" mention syntax.
        mentions = (f'<@{self.application_id}>', f'<@!{self.application_id}>')
        if message.content.startswith('!happytree'):
            self.logger.info(f'Command message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        elif message.content.startswith(mentions):
            self.logger.info(f'Mention message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        else:
            self.logger.debug(f'Non-command message from {message.author} ({message.author.id}): {message.content}')

    async def take_commission(self, message):
        """Validate a command message and add it to the painting queue.

        Invalid requests get an explanatory reply and are not queued; valid
        ones get a reply with their queue position and a wait estimate.
        """
        t = time.perf_counter()
        try:
            prompt, samples, seed, steps = parse_commission(message.content)
        except CommissionError as err:
            await message.reply(str(err))
            return
        self.logger.info(f'Queueing request for {message.author}. Prompt: "{prompt}"; Samples: {samples}; Seed: {seed}; Steps: {steps}.')
        self.commissions.put_nowait((t, message, prompt, samples, seed, steps))
        # qsize() counts only waiting requests, not one currently painting.
        position = self.commissions.qsize()
        waittime = position * self.MINUTES_PER_PAINTING
        await message.reply(f'{message.author} you are number {position} in the Happy Trees processing queue. Estimated wait time is {waittime} minutes.')

    def _output_file(self, prompt, seed):
        """Create the prompt's output directory and return the expected PNG path."""
        # This section copies the filename generation code from
        # optimized_txt2img.py from optimizedSD, so the bot can predict where
        # the script will write its image. NOTE: [:150] truncates the whole
        # joined path, not just the prompt part -- kept to match the script.
        os.makedirs(self.outpath, exist_ok=True)
        sample_path = os.path.join(self.outpath, "_".join(re.split(":| ", prompt)))[:150]
        os.makedirs(sample_path, exist_ok=True)
        # Only one worker runs at a time, so this count is still accurate
        # when the script computes it.
        base_count = len(os.listdir(sample_path))
        return os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.png")

    async def _paint(self, t, message, prompt, samples, seed, steps):
        """Render one commission with the SD script and reply with the image.

        Raises:
            OSError: if the expected output file cannot be opened (after a
                reply telling the user something went wrong).
        """
        self.logger.info(f'Bob Ross is painting "{prompt}" for {message.author}.')
        outfile = self._output_file(prompt, seed)
        self.logger.debug(f'Output file will be: {outfile}')
        start = time.perf_counter()
        self.logger.info('About to start the subprocess...')
        proc = await asyncio.create_subprocess_exec(
            self.python, 'optimizedSD/optimized_txt2img.py',
            '--H', '448', '--W', '448',
            '--precision', 'full',
            '--outdir', self.outpath,
            '--n_iter', '1',
            # NOTE(review): the parsed `samples` value is logged and queued but
            # not passed here; every commission renders a single image.
            '--n_samples', '1',
            '--ddim_steps', str(steps),
            '--seed', str(seed),
            '--prompt', str(prompt),
            stderr=asyncio.subprocess.STDOUT,
            stdout=asyncio.subprocess.PIPE,
            cwd=self.sdpath)
        self.logger.info(f'Started SD subprocess, PID: {proc.pid}')
        try:
            procout, _ = await proc.communicate()
        except asyncio.CancelledError:
            # Don't leave a GPU job running after the bot stops waiting on it.
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            raise
        procoutput = procout.decode(errors="replace").strip()
        self.logger.info(f'Process output: {procoutput}')
        if proc.returncode != 0:
            self.logger.warning(f'SD subprocess exited with code {proc.returncode}.')
        complete = time.perf_counter()
        self.logger.info(f'Painting of "{prompt}" complete for {message.author}. Wait time: {start-t:0.5f}; Paint time: {complete-start:0.5f}; Total time: {complete-t:0.5f}.')
        try:
            # discord.File opens the path immediately, so a missing output
            # surfaces here as an OSError (not while sending).
            painting = discord.File(outfile, description=f'Prompt: "{prompt}"; Starting Seed: {seed}; Steps: {steps}')
        except OSError:
            await message.reply('An error occurred. Output file not found.')
            raise
        await message.reply(file=painting)
        self.logger.debug('Reply sent')

    async def painting(self):
        """Worker loop: take commissions off the queue and paint them one by one.

        Runs until cancelled or until a commission raises; bobross() restarts
        it in the latter case.
        """
        while True:
            self.logger.debug('Bob Ross is ready to paint!')
            commission = await self.commissions.get()
            try:
                await self._paint(*commission)
            finally:
                # Mark the item handled even when painting fails.
                self.commissions.task_done()

    async def bobross(self):
        """Supervise the painting worker, restarting it whenever it dies."""
        while True:
            self.logger.debug('Bob Ross is preparing a new canvas.')
            self.paintings.append(asyncio.create_task(self.painting()))
            self.logger.debug('Bob Ross is waiting on a commission.')
            # painting() loops forever, so this returns only once the worker
            # has raised; return_exceptions keeps the supervisor alive.
            results = await asyncio.gather(*self.paintings, return_exceptions=True)
            complaints = [r for r in results if isinstance(r, BaseException)]
            if complaints:
                self.logger.warning(f'Bob had some complaints: {complaints}')
            self.logger.info('All paintings complete.')
            self.paintings = []


# Root logger; handlers are attached below when run as a script.
logger = logging.getLogger()


async def main():
    """Run the bot and its painting supervisor, cleaning both up on exit.

    The bot token comes from the HAPPYTREES_TOKEN environment variable or,
    failing that, from a ``.token`` file in the current directory.
    """
    intents = discord.Intents.default()
    intents.message_content = True
    bot = HappyTreesBot(intents=intents)
    supervisor = None
    try:
        token = os.getenv("HAPPYTREES_TOKEN")
        if not token:
            logger.debug('Getting token from file.')
            with open(".token", encoding="utf-8") as f:
                token = f.read().strip()
        supervisor = asyncio.create_task(bot.bobross())
        await asyncio.gather(bot.start(token), supervisor)
    finally:
        # Cancel the supervisor as well, so it cannot spawn a replacement
        # worker after bot.start() has failed.
        if supervisor is not None:
            supervisor.cancel()
        if bot.paintings:
            logger.info('Cancelling in-progress paintings.')
        for painting in bot.paintings:
            painting.cancel()
        tasks = [task for task in (supervisor, *bot.paintings) if task is not None]
        await asyncio.gather(*tasks, return_exceptions=True)
        logger.info('All paintings have cancelled.')
        if not bot.is_closed():
            await bot.close()


if __name__ == "__main__":
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(filename='happytrees.log', encoding='utf-8', mode='w')
    formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}', '%Y-%m-%d %H:%M:%S', style='{')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info('Keyboard interrupt received, Happy Trees Bot shutting down.')
        sys.exit(0)