fabriquedoc/backend/authentication.py
François Pelletier 4a6bfc951f Gros refactoring
2024-12-31 17:00:07 -05:00

102 lines
3.4 KiB
Python

from datetime import timedelta, datetime, timezone
from typing import Annotated
import jwt
from fastapi import Depends, HTTPException
from passlib.exc import InvalidTokenError
from starlette import status
from config import SECRET_KEY, get_fake_users_db, pwd_context, oauth2_scheme, ALGORITHM
from models import UserInDB, User
import base64
import hashlib
import hmac
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against a stored password hash.

    Two stored formats are accepted:

    * Modular-crypt strings starting with ``$`` (the format passlib's
      ``pwd_context.hash`` emits) are checked with ``pwd_context.verify``.
    * Anything else is treated as the legacy format: base64 of a 16-byte salt
      followed by ``HMAC-SHA256(key=SHA256(salt), msg=password)``.

    Args:
        plain_password: The password supplied by the client.
        hashed_password: The stored hash for the user.

    Returns:
        True if the password matches, False otherwise. False is also returned
        when either argument is missing or malformed.
    """
    if isinstance(hashed_password, str) and hashed_password.startswith("$"):
        # "$" is not in the base64 alphabet, so a legacy hash can never take
        # this branch; only modular-crypt hashes do.
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except (TypeError, ValueError):
            # Unrecognised/malformed hash, or a non-string password.
            return False

    # Legacy format: salt (first 16 bytes) + HMAC-SHA256 digest.
    try:
        decoded = base64.b64decode(hashed_password)
        salt, stored_hash = decoded[:16], decoded[16:]
        key = hashlib.sha256(salt).digest()
        new_hash = hmac.new(
            key, plain_password.encode("utf-8"), hashlib.sha256
        ).digest()
    except (AttributeError, TypeError, ValueError):
        # binascii.Error (bad base64) and UnicodeEncodeError are ValueError
        # subclasses; AttributeError covers a non-str password such as None.
        return False
    # Constant-time comparison to avoid leaking how many bytes matched.
    return hmac.compare_digest(new_hash, stored_hash)
def get_password_hash(password):
    """Hash *password* with the shared passlib ``pwd_context`` and return it."""
    hashed = pwd_context.hash(password)
    return hashed
def get_user(db, username: str):
    """Look up *username* in *db* and wrap the stored record in ``UserInDB``.

    Args:
        db: Mapping whose values are keyword arguments for ``UserInDB``,
            keyed by username.
        username: Name to look up.

    Returns:
        A ``UserInDB`` built from the stored record, or ``None`` when
        *username* is not a key of *db*.
    """
    if username not in db:
        return None
    record = db[username]
    return UserInDB(**record)
from fastapi import HTTPException, status
def authenticate_user(fake_db, username: str, password: str):
    """Return the user whose credentials match, or reject the request.

    Args:
        fake_db: User store passed through to ``get_user``.
        username: Username submitted by the client.
        password: Plain-text password submitted by the client.

    Returns:
        The user record returned by ``get_user`` for *username*.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            user does not exist or the password does not match. Both cases
            share one message so responses do not reveal which usernames exist.
    """
    import logging

    logger = logging.getLogger(__name__)
    invalid_credentials = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )

    user = get_user(fake_db, username)
    if not user:
        # Never log the submitted password or the stored hash.
        logger.debug("Authentication failed: unknown user %r", username)
        raise invalid_credentials
    if not verify_password(password, user.hashed_password):
        logger.debug("Authentication failed: bad password for user %r", username)
        raise invalid_credentials
    return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Encode *data* as a signed JWT carrying an ``exp`` expiry claim.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The dict is
            copied, so the caller's object is not modified. Any ``exp`` key it
            already contains is replaced.
        expires_delta: Lifetime of the token. ``None`` means 15 minutes. Any
            timedelta, including ``timedelta(0)``, is used exactly as given.

    Returns:
        The token produced by ``jwt.encode`` using ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy and must not be
    # mistaken for "no lifetime given".
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC so the "exp" claim is an unambiguous instant.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    """Resolve the bearer *token* to the stored user it was issued for.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The user record returned by ``get_user`` for the token's ``sub`` claim.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token fails to decode or verify (e.g. bad signature,
              expired, malformed);
            - it lacks a string ``sub`` claim;
            - it names a user that is not in the user store.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError as exc:
        # PyJWT's base class for decode failures (ExpiredSignatureError,
        # InvalidSignatureError, DecodeError, ...). The module-level
        # ``InvalidTokenError`` comes from passlib.exc and is never raised by
        # jwt.decode, so catching it let bad tokens surface as 500 errors.
        raise credentials_exception from exc

    username = payload.get("sub")
    # Reject a missing or non-string subject with a 401 instead of letting
    # model validation fail further down.
    if not isinstance(username, str):
        raise credentials_exception
    token_data = User(username=username)

    user = get_user(get_fake_users_db(), username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    """Return the authenticated user unless the account is flagged as disabled.

    Args:
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        *current_user* unchanged when ``current_user.disabled`` is falsy.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.disabled`` is
            truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")