# Happy Trees Discord Bot
# A Discord bot to access a local Stable Diffusion model

import os
import sys
import logging
import asyncio
import time
import re
import random
import hashlib
import tempfile
import discord
from discord.ext import commands
import sqlite3
from contextlib import closing

# First-token keywords that select the help / license replies instead of a painting.
HELP_WORDS = ("help", "--help", "-help", "-h", "--h")
LICENSE_WORDS = ("license", "--license", "-license", "-l", "--l")

# Options a user may pass as "--name value" or "--name=value".
VALID_OPTIONS = ("ddim_steps", "n_samples", "seed", "strength")

# Inclusive (minimum, maximum, error reply) for each integer-valued option.
INT_OPTION_LIMITS = {
    "seed": (0, 1000000, 'Invalid seed value. Seed must be an integer between 0 and 1000000.'),
    "n_samples": (1, 4, 'Invalid n_samples value. This GPU is old and cannot do more than 4 at a time.'),
    "ddim_steps": (1, 80, 'Invalid ddim_steps value. Limited to 80 for time considerations.'),
}
STRENGTH_ERROR = 'Invalid strength value. Strength must be a decimal between 0 and 1.'


def parse_option_value(name, raw):
    """Convert and range-check the text a user supplied for option *name*.

    Args:
        name: One of VALID_OPTIONS.
        raw: The user's text for the value, e.g. "42" or "0.6".

    Returns:
        An int for seed / n_samples / ddim_steps, a float for strength.

    Raises:
        ValueError: carrying a user-facing message when *raw* is not a number
            of the right kind or is out of range.
    """
    if name == "strength":
        try:
            value = float(raw)
        except ValueError:
            raise ValueError(STRENGTH_ERROR) from None
        # Exclusive bounds; NaN also fails this chained comparison.
        if not 0 < value < 1:
            raise ValueError(STRENGTH_ERROR)
        return value
    low, high, error = INT_OPTION_LIMITS[name]
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(error) from None
    if not low <= value <= high:
        raise ValueError(error)
    return value


class EnlargeButton(discord.ui.Button):
    """'Embiggen' button that queues a real-ESRGAN upscale of one stored image.

    The image's path relative to the bot's output directory is stored as the
    button's custom_id, so the callback knows which file to upscale.
    """

    def __init__(self, relfile):
        super().__init__(style=discord.ButtonStyle.secondary, label='Embiggen', custom_id=relfile)
        self.logger = logging.getLogger('discord.EnlargeButton')

    async def callback(self, interaction):
        """Queue an 'embiggen' commission and tell the user their place in line."""
        assert self.view is not None
        client = interaction.client
        t = time.perf_counter()
        self.logger.info(f'Queueing request for {interaction.user} to embiggen {self.custom_id}.')
        # Same 8-tuple shape as 'paint' commissions; the unused slots are None.
        client.commissions.put_nowait((t, 'embiggen', interaction.message, self.custom_id, None, None, None, None))
        queuelen = client.commissions.qsize()
        remain = max(client.currtasktime - time.time(), 0)
        # Rough estimate in minutes; the constants look like ~7 minutes per queued job (hand-tuned).
        waittime = (remain / 60) + (queuelen * 7) - 6.5
        await interaction.response.send_message(f'{interaction.user} you are number {queuelen} in the Happy Trees processing queue. Estimated wait time is {waittime:0.2f} minutes.')


class EnlargeView(discord.ui.View):
    """View holding a single EnlargeButton for *outfile*.

    With no arguments the view is empty; setup_hook registers such an instance
    via add_view.
    """

    def __init__(self, outfile=None, outpath=None):
        super().__init__(timeout=None)
        if not outfile:
            return
        self.logger = logging.getLogger('discord.EnlargeView')
        relfile = os.path.relpath(outfile, outpath)
        self.logger.info(f'Creating embiggen button for {relfile}.')
        self.add_item(EnlargeButton(relfile))


class GridButton(discord.ui.Button):
    """Button that returns one sample image out of a grid reply.

    custom_id is the sample's path relative to the bot's output directory.
    """

    def __init__(self, sample, samplefile, row):
        super().__init__(style=discord.ButtonStyle.secondary, label=f'Image {sample + 1}', custom_id=samplefile, row=row)
        self.logger = logging.getLogger('discord.GridButton')

    async def callback(self, interaction):
        """Reply with the single sample image named by this button's custom_id."""
        assert self.view is not None
        self.logger.info(f'Received a request to return single file {self.custom_id}')
        outpath = interaction.client.outpath
        samplefile = os.path.join(outpath, self.custom_id)
        try:
            # discord.File opens a path argument on construction, so a missing
            # sample fails here, before any response has been sent.
            attachment = discord.File(samplefile, description=f'Sample {self.custom_id}.')
        except OSError:
            self.logger.warning(f'Sample file {samplefile} could not be opened.')
            await interaction.response.send_message(f'An error occurred. Sample {self.custom_id} not found.')
            return
        await interaction.response.send_message(file=attachment, view=EnlargeView(samplefile, outpath))


class GridSelect(discord.ui.View):
    """View with one GridButton per sample image of a grid reply.

    With no arguments the view is empty; setup_hook registers such an instance
    via add_view.
    """

    def __init__(self, samples=None, n_rows=None, outfile=None, outpath=None):
        super().__init__(timeout=None)
        if not samples:
            return
        self.logger = logging.getLogger('discord.GridSelect')
        relfile = os.path.relpath(outfile, outpath)
        relpath, gridfile = os.path.split(relfile)
        # Grid files are written as "grid_<seed>_<count>.png" (see _paint). The
        # individual samples are assumed to be "seed_<seed+i>_<count+i:05>.png"
        # in the same directory, following optimizedSD's naming — TODO confirm.
        _, seed, base = os.path.splitext(gridfile)[0].split('_')
        first_seed, first_base = int(seed), int(base)
        # Buttons per row = ceil(samples / n_rows).
        per_row, leftover = divmod(samples, n_rows)
        if leftover:
            per_row += 1
        self.logger.info(f'Creating buttons for GridView for {outfile}.')
        for sample in range(samples):
            samplefile = os.path.join(relpath, f'seed_{first_seed + sample}_{first_base + sample:05}.png')
            self.add_item(GridButton(sample, samplefile, sample // per_row))


class HappyTreesBot(commands.Bot):
    """Discord bot that queues image requests and runs them one at a time.

    Requests ("commissions") are 8-tuples
    (t, command, message, prompt, samples, seed, steps, strength) placed on
    self.commissions and consumed by the painting() worker.
    """

    def __init__(self):
        intents = discord.Intents.default()
        intents.message_content = True
        super().__init__(command_prefix=commands.when_mentioned_or('!happytree'), intents=intents)
        self.logger = logging.getLogger('discord.HappyTreesBot')
        self.commissions = asyncio.Queue()
        # Worker tasks started by bobross(); main() cancels them on shutdown.
        self.paintings = []
        # Wall-clock time (time.time()) at which the current job is estimated to finish.
        self.currtasktime = 0
        self.outpath = os.path.abspath(os.path.expanduser('~/happytreesbot/outputs'))
        self.sdpath = os.path.abspath(os.path.expanduser('~/stable-diffusion'))
        self.ganpath = os.path.abspath(os.path.expanduser('~/real-ESRGAN'))
        # users.db is resolved relative to the current working directory.
        with closing(sqlite3.connect("users.db")) as conn:
            with closing(conn.cursor()) as curs:
                curs.execute("CREATE TABLE IF NOT EXISTS users(id integer PRIMARY KEY, first_use integer, last_use integer)")

    async def setup_hook(self) -> None:
        self.add_view(GridSelect())
        self.add_view(EnlargeView())

    async def on_ready(self):
        self.logger.info(f'Logged in to Discord as {self.user}!')
        print(f'Invite me to your server at https://discord.com/oauth2/authorize?client_id={self.application_id}&scope=bot&permissions=35904')

    def _mention_prefixes(self):
        """Return the raw mention forms that address this bot (plain and legacy nickname form)."""
        return (f'<@{self.application_id}>', f'<@!{self.application_id}>')

    async def on_message(self, message):
        """Route '!happytree' messages, @mentions of the bot, and DMs to take_commission."""
        if message.author == self.user:
            return
        if message.content.startswith('!happytree'):
            self.logger.info(f'Command message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        elif message.content.startswith(self._mention_prefixes()):
            self.logger.info(f'Mention message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        elif isinstance(message.channel, discord.channel.DMChannel):
            self.logger.info(f'Direct message from {message.author} ({message.author.id}): {message.content}')
            await self.take_commission(message)
        else:
            self.logger.debug(f'Non-command message from {message.author} ({message.author.id}): {message.content}')

    # NOTE(review): the README and license URLs appear to be missing from the
    # usage/license texts ("repository: ." / "license: ,") — TODO restore them.
    async def usage(self, message):
        """Reply to *message* with the bot's help text."""
        await message.reply(
            'Welcome to the Happy Trees Bot. I take your requests and draw them using an old GTX 980, '
            'the Stable Diffusion image generation algorithm and the real-ESRGAN upscaling algorithm. '
            'You can commission an image by sending me a _DM_, sending a message that starts by '
            "_@mentioning_ me in a channel that I'm in, or prefixing your message with _!happytree_ "
            "in a channel that I'm in.\n\n"
            'I have two main modes of operation:\n'
            " **txt2img:** This is my default mode and here I'll interpret most text as part of a prompt "
            'and generate an image based on that.\n'
            ' **img2img:** You can use this by attaching an image to your message and '
            "I'll try to use your image as a guide for my output.\n\n"
            'There are also certain special options you can pass me to affect how I work. These are:\n'
            ' *--seed [0-1000000]* will use a specific number instead of a random seed for the generation process.\n'
            ' *--n_samples [1-4]* will determine how many sample images I make for you. **Default is 4.**\n'
            ' *--ddim_steps [1-80]* will cause me to spend more or less compute time on the image. **Default is 50.**\n'
            ' *--strength [0.00-1.00]* will set how much liberty I should take in deviating from your input image, '
            'with 1 being to basically ignore the input image. **Default is 0.75.**\n\n'
            'For more details and examples, please check out the README file in the Happy Trees Bot Git repository: .'
        )

    async def license(self, message):
        """DM the author of *message* the first-use notice and license summary."""
        await message.author.send(
            'Hi there! This seems to be your first time using the Happy Trees Bot, so I just want to be sure '
            'that you know how to use the bot and are aware of the license restrictions that the bot operates '
            'under from the Stable Diffusion library.\n\n'
            '**NOTE:** All requests run one at a time because I only have one GPU. Please be respectful of '
            "others who might want to request some images and don't queue up too much art all at once. "
            '*Thank you*\n\n'
            ' You can interact with the bot via DM, or via a message that starts with an @mention of the bot '
            'in a channel that the bot is in, or via a message that starts with !happytree in a channel that '
            'the bot is in. For full details on how to use the bot, use any of those methods with the command '
            '"help".\n\n'
            'All use of the Happy Trees Bot is subject to the Stable Diffusion license: , including '
            '_Attachment A: Use Restrictions_. If you **do not agree** to the terms of that license, please '
            '**discontinue use** of the Happy Trees Bot. You should read the license for full details, but my '
            "short summary here is: don't be a jerk, don't be a creep, and don't break the law. "
            'Hopefully you can manage that.'
        )

    def _record_user(self, user_id):
        """Insert or refresh *user_id* in users.db; return True if the user was already known."""
        with closing(sqlite3.connect("users.db")) as conn:
            with closing(conn.cursor()) as curs:
                existing = curs.execute('SELECT 1 FROM users WHERE id=?', (user_id,)).fetchone()
                curs.execute(
                    "INSERT INTO users (id, first_use, last_use) VALUES (?, strftime('%s', 'now'), strftime('%s', 'now')) "
                    "ON CONFLICT(id) DO UPDATE SET last_use=strftime('%s', 'now');",
                    (user_id,))
            conn.commit()
        return existing is not None

    async def take_commission(self, message):
        """Parse a request, validate its options and queue it for painting.

        Leading '!happytree' / @mention tags are stripped. The remaining tokens
        are either a first-token keyword from HELP_WORDS / LICENSE_WORDS,
        '--option value' / '--option=value' pairs, or prompt words. First-time
        users are sent the license notice. Invalid options get an error reply
        and nothing is queued.
        """
        t = time.perf_counter()
        tokens = message.content.split()
        mentions = self._mention_prefixes()
        skip = 0
        while skip < len(tokens) and (tokens[skip].startswith('!happytree') or tokens[skip].startswith(mentions)):
            skip += 1
        parsed = tokens[skip:]

        existing_user = self._record_user(message.author.id)

        first = parsed[0].lower() if parsed else ''
        if parsed:
            self.logger.debug(f'First message token: {parsed[0]}.')
        if first in LICENSE_WORDS:
            await self.license(message)
            return
        if not existing_user:
            # Checked before "help" so a first-time user always gets the notice once.
            await self.license(message)
        if not parsed or first in HELP_WORDS:
            # A message consisting only of the bot's tag also gets the usage text.
            await self.usage(message)
            return

        options = {"seed": random.randint(0, 1000000), "n_samples": 4, "ddim_steps": 50, "strength": 0.75}
        words = []
        pending = ''  # option name still waiting for its value in the next token
        for token in parsed:
            if pending:
                name, raw, pending = pending, token, ''
            elif token.startswith('--'):
                name, has_value, raw = token[2:].partition('=')
                if name not in VALID_OPTIONS:
                    await message.reply(f'Invalid option name: {name}. Valid options are "--ddim_steps", "--n_samples", "--seed", and "--strength".')
                    return
                if not has_value:
                    pending = name
                    continue
            else:
                words.append(token)
                continue
            try:
                options[name] = parse_option_value(name, raw)
            except ValueError as err:
                await message.reply(str(err))
                return
        if pending:
            await message.reply(f'Missing value for option --{pending}.')
            return

        prompt = " ".join(words)
        seed = options["seed"]
        samples = options["n_samples"]
        steps = options["ddim_steps"]
        strength = options["strength"]

        self.logger.info(f'Queueing request for {message.author}. Prompt: "{prompt}"; Samples: {samples}; Seed: {seed}; Steps: {steps}; Strength: {strength}.')
        self.commissions.put_nowait((t, 'paint', message, prompt, samples, seed, steps, strength))
        position = self.commissions.qsize()
        remain = max(self.currtasktime - time.time(), 0)
        # Rough estimate: current job remainder + this job (30s + 2s per step per sample) + ~7 min per job ahead.
        waittime = ((remain + 30 + (2 * steps * samples)) / 60) + ((position - 1) * 7)
        await message.reply(f'{message.author} you are number {position} in the Happy Trees processing queue. Estimated wait time is {waittime:0.2f} minutes.')

    async def _wait_logged(self, proc, label):
        """Log *proc*'s PID, wait for it, then log its combined output and any non-zero exit status."""
        self.logger.info(f'Started {label} subprocess, PID: {proc.pid}')
        procout, _ = await proc.communicate()
        # stderr is redirected into stdout by the callers; tolerate non-UTF-8 bytes from the tools.
        self.logger.info(f'Process output: {procout.decode(errors="replace").strip()}')
        if proc.returncode != 0:
            self.logger.warning(f'{label} subprocess exited with status {proc.returncode}.')

    async def _reply_with_file(self, message, path, description, view=None):
        """Reply to *message* with the file at *path*, or with an error if it cannot be opened."""
        try:
            # discord.File opens a path argument on construction.
            attachment = discord.File(path, description=description)
        except OSError:
            self.logger.exception(f'Could not open output file {path}.')
            await message.reply('An error occurred. Output file not found.')
            return
        if view is None:
            await message.reply(file=attachment)
        else:
            await message.reply(file=attachment, view=view)
        self.logger.debug('Reply sent')

    async def _embiggen(self, t, message, relfile):
        """Upscale the image at *relfile* (relative to outpath) with real-ESRGAN and reply with it."""
        self.currtasktime = time.time() + 30
        os.makedirs(self.outpath, exist_ok=True)
        with tempfile.NamedTemporaryFile(suffix=".png", dir=self.outpath) as fp:
            start = time.perf_counter()
            proc = await asyncio.create_subprocess_exec(
                os.path.join(self.ganpath, 'realesrgan-ncnn-vulkan'),
                '-i', os.path.join(self.outpath, relfile),
                '-o', fp.name,
                '-n', 'realesrgan-x4plus',
                stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE, cwd=self.ganpath)
            await self._wait_logged(proc, 'real-ESRGAN')
            complete = time.perf_counter()
            self.logger.info(f'Embiggening of "{relfile}" complete. Output at {fp.name}. Wait time: {start-t:0.5f}; Paint time: {complete-start:0.5f}; Total time: {complete-t:0.5f}.')
            # Reply while the temporary file still exists.
            await self._reply_with_file(message, fp.name, f'Embiggened version of {relfile}')

    async def _paint(self, t, message, prompt, samples, seed, steps, strength):
        """Run optimizedSD txt2img (or img2img when an image is attached) and reply with the result."""
        self.logger.info(f'Bob Ross is painting "{prompt}" for {message.author}.')
        self.currtasktime = time.time() + 30 + (2 * steps * samples)
        # This section copies the filename generation code from
        # optimized_txt2img.py from optimizedSD
        os.makedirs(self.outpath, exist_ok=True)
        # NOTE(review): prompt is user text; '/' or '..' in it become path
        # components here (and presumably in optimizedSD's own naming, which this
        # must match), so directories can be created outside outpath.
        promptpath = "_".join(re.split(":| ", prompt))[:150]
        sample_path = os.path.join(self.outpath, promptpath)
        os.makedirs(sample_path, exist_ok=True)
        # A short, fixed-length alias (sha256 of the prompt dir name) pointing at sample_path.
        hashdirpath = os.path.join(self.outpath, 'hashed')
        os.makedirs(hashdirpath, exist_ok=True)
        hash_path = os.path.join(hashdirpath, hashlib.sha256(promptpath.encode()).hexdigest())
        if not os.path.exists(hash_path):
            os.symlink(sample_path, hash_path, target_is_directory=True)
        base_count = len(os.listdir(sample_path))

        attachment = ""
        if message.attachments:
            upload = message.attachments[0]
            # content_type may be None when Discord does not report one.
            if (upload.content_type or '').startswith('image/'):
                base_count += 1
                attachment = os.path.join(sample_path, f"init_{seed}_{base_count:05}_{os.path.basename(upload.filename)}")
                await upload.save(attachment)

        if samples > 1:
            outfile = os.path.join(hash_path, f"grid_{seed}_{base_count:04}.png")
        else:
            outfile = os.path.join(hash_path, f"seed_{seed}_{base_count:05}.png")
        self.logger.debug(f'Output file will be: {outfile}')
        num_rows = int(samples ** 0.5)

        args = ['--H', '448', '--W', '448', '--precision', 'full', '--outdir', self.outpath]
        if attachment:
            script = 'optimizedSD/optimized_img2img.py'
            args += ['--init-img', str(attachment), '--strength', str(strength)]
        else:
            script = 'optimizedSD/optimized_txt2img.py'
        args += ['--n_iter', '1', '--n_samples', str(samples), '--n_rows', str(num_rows),
                 '--ddim_steps', str(steps), '--seed', str(seed), '--prompt', str(prompt)]

        start = time.perf_counter()
        self.logger.info('About to start the subprocess...')
        proc = await asyncio.create_subprocess_exec(
            '/usr/bin/python3', script, *args,
            stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE, cwd=self.sdpath)
        await self._wait_logged(proc, 'SD')
        complete = time.perf_counter()
        self.logger.info(f'Painting of "{prompt}" complete for {message.author}. Output at {outfile}. Wait time: {start-t:0.5f}; Paint time: {complete-start:0.5f}; Total time: {complete-t:0.5f}.')

        if samples > 1:
            view = GridSelect(samples, num_rows, outfile, self.outpath)
        else:
            view = EnlargeView(outfile, self.outpath)
        await self._reply_with_file(message, outfile, f'Prompt: "{prompt}"; Starting Seed: {seed}; Steps: {steps}', view)

    async def painting(self):
        """Worker loop: take commissions off the queue and run them one at a time, forever.

        task_done() is called for every dequeued item, even when handling it raises.
        """
        while True:
            self.logger.debug('Bob Ross is ready to paint!')
            t, command, message, prompt, samples, seed, steps, strength = await self.commissions.get()
            try:
                if command == "embiggen":
                    await self._embiggen(t, message, prompt)
                elif command == "paint":
                    await self._paint(t, message, prompt, samples, seed, steps, strength)
                else:
                    self.logger.warning(f'Unknown command in queue: {command}')
            finally:
                self.commissions.task_done()

    async def bobross(self):
        """Keep one painting() worker alive, restarting it whenever it exits with an exception."""
        while True:
            self.logger.debug('Bob Ross is preparing a new canvas.')
            self.paintings.append(asyncio.create_task(self.painting()))
            self.logger.debug('Bob Ross is waiting on a commission.')
            complaints = await asyncio.gather(*self.paintings, return_exceptions=True)
            if complaints:
                self.logger.warning(f'Bob had some complaints: {complaints}')
            self.logger.info('All paintings complete.')
            self.paintings = []


async def main():
    """Start the bot and its painting supervisor; on exit cancel workers and close the bot.

    The token comes from $HAPPYTREES_TOKEN, falling back to the file ".token".
    """
    bot = HappyTreesBot()
    supervisor = None
    try:
        token = os.getenv("HAPPYTREES_TOKEN")
        if not token:
            logger.debug('Getting token from file.')
            with open(".token", encoding="utf-8") as f:
                token = f.read().strip()
        supervisor = asyncio.create_task(bot.bobross())
        await asyncio.gather(bot.start(token), supervisor)
    finally:
        # Stop bobross first, otherwise it would start a replacement worker as
        # soon as the paintings below report their cancellation.
        pending = []
        if supervisor is not None:
            supervisor.cancel()
            pending.append(supervisor)
        logger.info('Cancelling in-progress paintings.')
        for painting in bot.paintings:
            painting.cancel()
        await asyncio.gather(*pending, *bot.paintings, return_exceptions=True)
        logger.info('All paintings have cancelled.')
        if not bot.is_closed():
            await bot.close()


logger = logging.getLogger()


def configure_logging():
    """Attach a file handler for happytrees.log (truncated on each start) to the root logger at INFO."""
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(filename='happytrees.log', encoding='utf-8', mode='w')
    formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}', '%Y-%m-%d %H:%M:%S', style='{')
    handler.setFormatter(formatter)
    logger.addHandler(handler)


if __name__ == "__main__":
    configure_logging()
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info('Keyboard interrupt received, Happy Trees Bot shutting down.')
        sys.exit(0)