-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7e6f0a3
commit 59c7695
Showing
3 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/bash | ||
# WF 2024-01-22 | ||
for package in wikibot3rd tests | ||
do | ||
isort $package/*.py | ||
black $package/*.py | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
Created on 2024-01-22 | ||
@author: wf | ||
""" | ||
import os | ||
|
||
import yaml | ||
|
||
from tests.basetest import BaseTest | ||
from wikibot3rd.sso import SSO | ||
|
||
|
||
class TestSSO(BaseTest): | ||
""" | ||
test single sign on | ||
""" | ||
|
||
def get_credentials(self): | ||
credentials_file = os.path.expanduser("~/.mediawiki-japi/cr_credentials.yaml") | ||
with open(credentials_file, "r") as file: | ||
credentials = yaml.safe_load(file) | ||
username = credentials["username"] | ||
password = credentials["password"] | ||
return username, password | ||
|
||
def test_mw_sso(self): | ||
""" | ||
test mediawiki single sign on | ||
""" | ||
if not self.inPublicCI(): | ||
debug = self.debug | ||
debug = True | ||
db_username, db_password = self.get_credentials() | ||
sso = SSO( | ||
"cr.bitplan.com", | ||
"crwiki", | ||
db_username=db_username, | ||
db_password=db_password, | ||
debug=debug, | ||
) | ||
port_avail = sso.check_port() | ||
if not port_avail: | ||
print(f"SQL Port {sso.sql_port} not accessible") | ||
print("You might want to try opening an SSL tunnel to the port with") | ||
print(f"ssh -L {sso.sql_port}:{sso.host}:{sso.sql_port} {sso.host}") | ||
wiki_user = self.getWikiUser("cr") | ||
is_valid = sso.check_credentials( | ||
username=wiki_user.user, password=wiki_user.get_password() | ||
) | ||
self.assertTrue(is_valid) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
""" | ||
Created on 2024-01-22 | ||
@author: wf | ||
with ChatGPT-4 prompting | ||
""" | ||
import base64 | ||
import hashlib | ||
import socket | ||
import traceback | ||
from typing import Optional | ||
|
||
import mysql.connector | ||
from mysql.connector import pooling | ||
|
||
|
||
class SSO: | ||
""" | ||
A class to implement MediaWiki single sign-on support. | ||
This class provides functionality to connect to a MediaWiki database, | ||
verify user credentials, and handle database connections with pooling. | ||
Attributes: | ||
host (str): The host of the MediaWiki database. | ||
database (str): The name of the MediaWiki database. | ||
sql_port (int): The SQL port for the database connection. | ||
db_username (Optional[str]): The database username. | ||
db_password (Optional[str]): The database password. | ||
with_pool (bool): Flag to determine if connection pooling is used. | ||
timeout (float): The timeout for checking SQL port availability. | ||
debug (Optional[bool]): Flag to enable debug mode. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
host: str, | ||
database: str, | ||
sql_port: int = 3306, | ||
db_username: Optional[str] = None, | ||
db_password: Optional[str] = None, | ||
with_pool: bool = True, | ||
timeout: float = 3, | ||
debug: Optional[bool] = False, | ||
): | ||
""" | ||
Constructs all the necessary attributes for the SSO object. | ||
Args: | ||
host (str): The host of the MediaWiki database. | ||
database (str): The name of the MediaWiki database. | ||
sql_port (int): The SQL port for the database connection. | ||
db_username (Optional[str]): The database username. | ||
db_password (Optional[str]): The database password. | ||
with_pool (bool): Flag to determine if connection pooling is used. | ||
timeout (float): The timeout for checking SQL port availability. | ||
debug (Optional[bool]): Flag to enable debug mode. | ||
""" | ||
self.host = host | ||
self.database = database | ||
self.sql_port = sql_port | ||
self.timeout = timeout | ||
self.db_username = db_username | ||
self.db_password = db_password | ||
self.debug = debug | ||
self.pool = self.get_pool() if with_pool else None | ||
|
||
def get_pool(self) -> pooling.MySQLConnectionPool: | ||
""" | ||
Creates a connection pool for the database. | ||
Returns: | ||
MySQLConnectionPool: A pool of database connections. | ||
""" | ||
pool_config = { | ||
"pool_name": "mypool", | ||
"pool_size": 2, | ||
"host": self.host, | ||
"user": self.db_username, | ||
"password": self.db_password, | ||
"database": self.database, | ||
"raise_on_warnings": True, | ||
} | ||
return pooling.MySQLConnectionPool(**pool_config) | ||
|
||
def check_port(self) -> bool: | ||
""" | ||
Checks if the specified SQL port is accessible on the configured host. | ||
Returns: | ||
bool: True if the port is accessible, False otherwise. | ||
""" | ||
try: | ||
with socket.create_connection( | ||
(self.host, self.sql_port), timeout=self.timeout | ||
): | ||
return True | ||
except socket.error as ex: | ||
if self.debug: | ||
print(f"Connection to {self.host} port {self.sql_port} failed: {ex}") | ||
traceback.print_exc() | ||
return False | ||
|
||
def verify_password(self, password: str, hash_value: str) -> bool: | ||
""" | ||
Verifies a password against a stored hash value. | ||
Args: | ||
password (str): The password to verify. | ||
hash_value (str): The stored hash value to compare against. | ||
Returns: | ||
bool: True if the password matches the hash value, False otherwise. | ||
""" | ||
parts = hash_value.split(":") | ||
if len(parts) != 7: | ||
raise ValueError("Invalid hash format") | ||
|
||
( | ||
_, | ||
pbkdf2_indicator, | ||
hash_algorithm, | ||
iterations, | ||
_, | ||
salt, | ||
hashed_password, | ||
) = parts | ||
|
||
if pbkdf2_indicator != "pbkdf2": | ||
raise ValueError("verify_password expects pbkdf2 hashes") | ||
|
||
iterations = int(iterations) | ||
|
||
def fix_base64_padding(string: str) -> str: | ||
return string + "=" * (-len(string) % 4) | ||
|
||
salt = fix_base64_padding(salt) | ||
hashed_password = fix_base64_padding(hashed_password) | ||
|
||
salt = base64.b64decode(salt) | ||
hashed_password = base64.b64decode(hashed_password) | ||
|
||
if hash_algorithm not in hashlib.algorithms_available: | ||
raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}") | ||
|
||
new_hash = hashlib.pbkdf2_hmac( | ||
hash_algorithm, password.encode("utf-8"), salt, iterations | ||
) | ||
return new_hash == hashed_password | ||
|
||
def check_credentials(self, username: str, password: str) -> bool: | ||
""" | ||
Checks the validity of MediaWiki username and password. | ||
Args: | ||
username (str): The MediaWiki username. | ||
password (str): The password to verify. | ||
Returns: | ||
bool: True if the credentials are valid, False otherwise. | ||
""" | ||
is_valid = False | ||
try: | ||
connection = ( | ||
self.pool.get_connection() | ||
if self.pool | ||
else mysql.connector.connect( | ||
host=self.host, | ||
user=self.db_username, | ||
password=self.db_password, | ||
database=self.database, | ||
) | ||
) | ||
mw_username = username[0].upper() + username[1:] | ||
cursor = connection.cursor(dictionary=True) | ||
cursor.execute( | ||
"SELECT user_password FROM `user` WHERE user_name = %s", (mw_username,) | ||
) | ||
result = cursor.fetchone() | ||
|
||
if result: | ||
stored_hash = result["user_password"] | ||
is_valid = self.verify_password(password, stored_hash.decode("utf-8")) | ||
elif self.debug: | ||
print( | ||
f"Username {mw_username} not found in {self.database} on {self.host}" | ||
) | ||
|
||
cursor.close() | ||
except Exception as ex: | ||
if self.debug: | ||
print(f"Database error: {ex}") | ||
traceback.print_exc() | ||
finally: | ||
if connection and connection.is_connected(): | ||
connection.close() | ||
return is_valid |