From 69a160c6f2b6ebd0ed6b437f244a281db400c115 Mon Sep 17 00:00:00 2001 From: ReenigneArcher <42013603+ReenigneArcher@users.noreply.github.com> Date: Sat, 9 Nov 2024 16:49:41 -0500 Subject: [PATCH] feat(github): add github integration --- requirements.txt | 1 + src/__main__.py | 26 +++--- src/crypto.py | 69 ++++++++++++++++ src/database.py | 22 +++++ src/discord/bot.py | 7 +- src/discord/cogs/github_commands.py | 123 ++++++++++++++++++++++++++++ src/globals.py | 2 + src/keep_alive.py | 20 ----- src/reddit/bot.py | 11 +-- src/webapp.py | 86 +++++++++++++++++++ 10 files changed, 321 insertions(+), 46 deletions(-) create mode 100644 src/crypto.py create mode 100644 src/database.py create mode 100644 src/discord/cogs/github_commands.py create mode 100644 src/globals.py delete mode 100644 src/keep_alive.py create mode 100644 src/webapp.py diff --git a/requirements.txt b/requirements.txt index 2f48774..7581d55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ praw==7.8.1 py-cord==2.6.1 python-dotenv==1.0.1 requests==2.32.3 +requests-oauthlib==2.0.0 diff --git a/src/__main__.py b/src/__main__.py index 7968744..3b05c76 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,5 +1,4 @@ # standard imports -import os import time # development imports @@ -8,33 +7,28 @@ # local imports if True: # hack for flake8 + from src import globals from src.discord import bot as d_bot - from src import keep_alive + from src import webapp from src.reddit import bot as r_bot def main(): - # to run in replit - try: - os.environ['REPL_SLUG'] - except KeyError: - pass # not running in replit - else: - keep_alive.keep_alive() # Start the web server + webapp.start() # Start the web server - discord_bot = d_bot.Bot() - discord_bot.start_threaded() # Start the discord bot + globals.DISCORD_BOT = d_bot.Bot() + globals.DISCORD_BOT.start_threaded() # Start the discord bot - reddit_bot = r_bot.Bot() - reddit_bot.start_threaded() # Start the reddit bot + globals.REDDIT_BOT = r_bot.Bot() + globals.REDDIT_BOT.start_threaded() # Start the reddit bot try: - while discord_bot.bot_thread.is_alive() or reddit_bot.bot_thread.is_alive(): + while globals.DISCORD_BOT.bot_thread.is_alive() or globals.REDDIT_BOT.bot_thread.is_alive(): time.sleep(0.5) except KeyboardInterrupt: print("Keyboard Interrupt Detected") - discord_bot.stop() - reddit_bot.stop() + globals.DISCORD_BOT.stop() + globals.REDDIT_BOT.stop() if __name__ == '__main__': # pragma: no cover diff --git a/src/crypto.py b/src/crypto.py new file mode 100644 index 0000000..766bb92 --- /dev/null +++ b/src/crypto.py @@ -0,0 +1,69 @@ +# standard imports +import os + +# lib imports +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption +from datetime import datetime, timedelta, UTC + +# local imports +from src import common + +CERT_FILE = os.path.join(common.data_dir, "cert.pem") +KEY_FILE = os.path.join(common.data_dir, "key.pem") + + +def check_expiration(cert_path: str) -> int: + with open(cert_path, "rb") as cert_file: + cert_data = cert_file.read() + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + expiry_date = cert.not_valid_after_utc + return (expiry_date - datetime.now(UTC)).days + + +def generate_certificate(): + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=4096, + ) + subject = issuer = x509.Name([ + x509.NameAttribute(x509.NameOID.COMMON_NAME, u"localhost"), + ]) + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + private_key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.now(UTC) + ).not_valid_after( + datetime.now(UTC) + timedelta(days=365) + ).sign(private_key, hashes.SHA256()) + + with open(KEY_FILE, "wb") as f: + f.write(private_key.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=NoEncryption(), + )) + + with open(CERT_FILE, "wb") as f: + f.write(cert.public_bytes(Encoding.PEM)) + + +def initialize_certificate() -> tuple[str, str]: + print("Initializing SSL certificate") + if os.path.exists(CERT_FILE) and os.path.exists(KEY_FILE): + cert_expires_in = check_expiration(CERT_FILE) + print(f"Certificate expires in {cert_expires_in} days.") + if cert_expires_in >= 90: + return CERT_FILE, KEY_FILE + print("Generating new certificate") + generate_certificate() + return CERT_FILE, KEY_FILE diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000..b27fdd4 --- /dev/null +++ b/src/database.py @@ -0,0 +1,22 @@ +# standard imports +import shelve +import threading + + +class Database: + def __init__(self, db_path): + self.db_path = db_path + self.lock = threading.Lock() + + def __enter__(self): + self.lock.acquire() + self.db = shelve.open(self.db_path, writeback=True) + return self.db + + def __exit__(self, exc_type, exc_val, exc_tb): + self.sync() + self.db.close() + self.lock.release() + + def sync(self): + self.db.sync() diff --git a/src/discord/bot.py b/src/discord/bot.py index a9baf6c..64a1622 100644 --- a/src/discord/bot.py +++ b/src/discord/bot.py @@ -7,7 +7,8 @@ import discord # local imports -from src.common import bot_name, get_avatar_bytes, org_name +from src.common import bot_name, data_dir, get_avatar_bytes, org_name +from src.database import Database from src.discord.tasks import daily_task from src.discord.views import DonateCommandView @@ -30,6 +31,7 @@ def __init__(self, *args, **kwargs): self.bot_thread = threading.Thread(target=lambda: None) self.token = os.environ['DISCORD_BOT_TOKEN'] + self.db = Database(db_path=os.path.join(data_dir, 'discord_bot_database')) self.load_extension( name='src.discord.cogs', @@ -37,6 +39,9 @@ def __init__(self, *args, **kwargs): store=False, ) + with self.db as db: + db['oauth_states'] = {} # clear any oauth states from previous sessions + async def on_ready(self): """ Bot on ready event. diff --git a/src/discord/cogs/github_commands.py b/src/discord/cogs/github_commands.py new file mode 100644 index 0000000..2054ab4 --- /dev/null +++ b/src/discord/cogs/github_commands.py @@ -0,0 +1,123 @@ +# standard imports +import os + +# lib imports +import discord +import requests +from requests_oauthlib import OAuth2Session + + +class GitHubCommandsCog(discord.Cog): + def __init__(self, bot): + self.bot = bot + self.token = os.getenv("GITHUB_TOKEN") + self.org_name = os.getenv("GITHUB_ORG_NAME", "LizardByte") + self.graphql_url = "https://api.github.com/graphql" + self.headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + @discord.slash_command( + name="get_sponsors", + description="Get list of GitHub sponsors", + default_member_permissions=discord.Permissions(manage_guild=True), + ) + async def get_sponsors( + self, + ctx: discord.ApplicationContext, + ): + """ + Get list of GitHub sponsors. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + query = """ + query { + organization(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + edges { + node { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + name + monthlyPriceInDollars + } + } + } + } + } + } + """ % self.org_name + + response = requests.post(self.graphql_url, json={'query': query}, headers=self.headers) + data = response.json() + + if 'errors' in data: + print(data['errors']) + await ctx.respond("An error occurred while fetching sponsors.", ephemeral=True) + return + + message = "List of GitHub sponsors" + for edge in data['data']['organization']['sponsorshipsAsMaintainer']['edges']: + sponsor = edge['node']['sponsorEntity'] + tier = edge['node'].get('tier', {}) + tier_info = f" - Tier: {tier.get('name', 'N/A')} (${tier.get('monthlyPriceInDollars', 'N/A')}/month)" + message += f"\n* [{sponsor['login']}]({sponsor['url']}){tier_info}" + + embed = discord.Embed(title="GitHub Sponsors", color=0x00ff00, description=message) + + await ctx.respond(embed=embed, ephemeral=True) + + @discord.slash_command( + name="link_github", + description="Validate GitHub sponsor status" + ) + async def link_github(self, ctx: discord.ApplicationContext): + """ + Link Discord account with GitHub account, by validating Discord user's "GitHub" connected account status. + + User to login to Discord via OAuth2, and check if their connected GitHub account is a sponsor of the project. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + discord_oauth = OAuth2Session( + os.environ['DISCORD_CLIENT_ID'], + redirect_uri=os.environ['DISCORD_REDIRECT_URI'], + scope=[ + "identify", + "connections", + ], + ) + authorization_url, state = discord_oauth.authorization_url("https://discord.com/oauth2/authorize") + + with self.bot.db as db: + db['oauth_states'] = db.get('oauth_states', {}) + db['oauth_states'][str(ctx.author.id)] = state + db.sync() + + # Store the state in the user's session or database + await ctx.respond(f"Please authorize the application by clicking [here]({authorization_url}).", ephemeral=True) + + +def setup(bot: discord.Bot): + bot.add_cog(GitHubCommandsCog(bot=bot)) diff --git a/src/globals.py b/src/globals.py new file mode 100644 index 0000000..f185cab --- /dev/null +++ b/src/globals.py @@ -0,0 +1,2 @@ +DISCORD_BOT = None +REDDIT_BOT = None diff --git a/src/keep_alive.py b/src/keep_alive.py deleted file mode 100644 index 74ab1c9..0000000 --- a/src/keep_alive.py +++ /dev/null @@ -1,20 +0,0 @@ -from flask import Flask -from threading import Thread -import os - -app = Flask('') - - -@app.route('/') -def main(): - return f"{os.environ['REPL_SLUG']} is live!" - - -def run(): - app.run(host="0.0.0.0", port=8080) - - -def keep_alive(): - server = Thread(name="Flask", target=run) - server.setDaemon(daemonic=True) - server.start() diff --git a/src/reddit/bot.py b/src/reddit/bot.py index 7520b9e..620d0d6 100644 --- a/src/reddit/bot.py +++ b/src/reddit/bot.py @@ -1,7 +1,6 @@ # standard imports from datetime import datetime import os -import requests import shelve import sys import threading @@ -10,6 +9,7 @@ # lib imports import praw from praw import models +import requests # local imports from src import common @@ -31,14 +31,7 @@ def __init__(self, **kwargs): self.user_agent = kwargs.get('user_agent', f'{common.bot_name} {self.version}') self.avatar = kwargs.get('avatar', common.get_bot_avatar(gravatar=os.environ['GRAVATAR_EMAIL'])) self.subreddit_name = kwargs.get('subreddit', os.getenv('PRAW_SUBREDDIT', 'LizardByte')) - - if not kwargs.get('redirect_uri', None): - try: # for running in replit - self.redirect_uri = f'https://{os.environ["REPL_SLUG"]}.{os.environ["REPL_OWNER"].lower()}.repl.co' - except KeyError: - self.redirect_uri = os.getenv('REDIRECT_URI', 'http://localhost:8080') - else: - self.redirect_uri = kwargs['redirect_uri'] + self.redirect_uri = kwargs.get('redirect_uri', os.getenv('REDIRECT_URI', 'http://localhost:8080')) # directories self.data_dir = common.data_dir diff --git a/src/webapp.py b/src/webapp.py new file mode 100644 index 0000000..ebc9cfe --- /dev/null +++ b/src/webapp.py @@ -0,0 +1,86 @@ +# standard imports +import os +from threading import Thread + +# lib imports +from flask import Flask, request +from requests_oauthlib import OAuth2Session + +# local imports +from src import crypto +from src import globals + + +DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") +DISCORD_CLIENT_SECRET = os.getenv("DISCORD_CLIENT_SECRET") +DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI") + +app = Flask('LizardByte-bot') + + +@app.route('/') +def main(): + return "LizardByte-bot is live!" + + +@app.route("/discord/callback") +def discord_callback(): + # get all active states from the global state manager + with globals.DISCORD_BOT.db as db: + active_states = db['oauth_states'] + + discord_oauth = OAuth2Session(DISCORD_CLIENT_ID, redirect_uri=DISCORD_REDIRECT_URI) + token = discord_oauth.fetch_token("https://discord.com/api/oauth2/token", + client_secret=DISCORD_CLIENT_SECRET, + authorization_response=request.url) + + # Fetch the user's Discord profile + response = discord_oauth.get("https://discord.com/api/users/@me") + discord_user = response.json() + + # if the user is not in the active states, return an error + if discord_user['id'] not in active_states: + return "Invalid state" + + # remove the user from the active states + del active_states[discord_user['id']] + + # Fetch the user's connected accounts + connections_response = discord_oauth.get("https://discord.com/api/users/@me/connections") + connections = connections_response.json() + + with globals.DISCORD_BOT.db as db: + db['discord_users'] = db.get('discord_users', {}) + db['discord_users'][discord_user['id']] = { + 'discord_username': discord_user['username'], + 'discord_global_name': discord_user['global_name'], + 'github_id': None, + 'github_username': None, + 'token': token, # TODO: should we store the token at all? + } + + for connection in connections: + if connection['type'] == 'github': + db['discord_users'][discord_user['id']]['github_id'] = connection['id'] + db['discord_users'][discord_user['id']]['github_username'] = connection['name'] + + return "Discord account linked successfully!" + + +def run(): + cert_file, key_file = crypto.initialize_certificate() + + app.run( + host="0.0.0.0", + port=8080, + ssl_context=(cert_file, key_file) + ) + + +def start(): + server = Thread( + name="Flask", + daemon=True, + target=run, + ) + server.start()