84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
|
|
from datetime import datetime, timedelta, timezone
from typing import Optional

import bcrypt
from fastapi import Depends, HTTPException, status, Security
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.session import get_db
from app.models.user import User
from app.services.api_key import APIKeyService
|
||
|
|
|
||
|
|
# OAuth2 "password" flow bearer-token extractor. tokenUrl is the endpoint a
# client calls to obtain an access token.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login/access-token",
)

# Reads the API key from the "X-API-Key" request header. With auto_error=False
# a missing header is passed through (as None) rather than rejected here, so
# get_api_key_user can raise its own 401 response.
api_key_header = APIKeyHeader(
    auto_error=False,
    name="X-API-Key",
)
|
||
|
|
|
||
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain_password: The password supplied by the user.
        hashed_password: The stored bcrypt hash, as a ``str``.

    Returns:
        ``True`` if the password matches the hash. ``False`` if it does not
        match, or if *hashed_password* is not a value bcrypt can parse.
    """
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"), hashed_password.encode("utf-8")
        )
    except ValueError:
        # bcrypt.checkpw raises ValueError (e.g. "Invalid salt") when the stored
        # value is malformed or not a bcrypt hash (such as an empty string).
        # Such a hash can never verify, so report a failed match instead of
        # letting the error surface as a server error during login.
        return False
|
||
|
|
|
||
|
|
def get_password_hash(password: str) -> str:
    """Return a bcrypt hash of *password*, salted with a fresh ``gensalt()``.

    The password is UTF-8 encoded before hashing, and the resulting hash bytes
    are decoded back to a UTF-8 ``str``.
    """
    salt = bcrypt.gensalt()
    hashed_bytes = bcrypt.hashpw(password.encode("utf-8"), salt)
    return hashed_bytes.decode("utf-8")
|
||
|
|
|
||
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT containing *data* plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed in the token. The dict is copied, so the caller's
            object is not modified.
        expires_delta: Token lifetime, measured from now. When ``None`` (the
            default), ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes is used.
            An explicit ``timedelta(0)`` is honoured rather than being replaced
            by the default.

    Returns:
        The encoded JWT, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    to_encode = data.copy()
    # Compare against None explicitly: a zero timedelta is falsy and would
    # otherwise silently fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware UTC "now". datetime.utcnow() is deprecated since Python 3.12
    # and returns a naive value. python-jose converts the datetime to a numeric
    # timestamp, so the encoded ``exp`` is the same instant as before.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||
|
|
|
||
|
|
def get_current_user(
    db: Session = Depends(get_db),
    token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the authenticated user from an OAuth2 bearer token.

    The token is decoded with ``settings.SECRET_KEY`` and
    ``settings.ALGORITHM``. Its ``sub`` claim is matched against
    ``User.username``.

    Raises:
        HTTPException: 401 "Could not validate credentials" if the token cannot
            be decoded, has no ``sub`` claim, or names no existing user.
            401 "Inactive user" if the matching user's ``is_active`` is falsy.
    """
    # One shared 401 response for every token/lookup failure.
    auth_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        raise auth_error

    subject = claims.get("sub")
    if subject is None:
        raise auth_error

    account = db.query(User).filter(User.username == subject).first()
    if account is None:
        raise auth_error

    if not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive user",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return account
|
||
|
|
|
||
|
|
def get_api_key_user(
    db: Session = Depends(get_db),
    api_key: Optional[str] = Security(api_key_header),
) -> User:
    """Resolve the user who owns the API key sent in the ``X-API-Key`` header.

    On success the key's last-used record is updated via
    ``APIKeyService.update_last_used`` and the owning user is returned.

    Raises:
        HTTPException: 401 in each of these cases:
            - the header is missing or empty;
            - the key is unknown, or has no owning user;
            - the key is inactive;
            - the owning user is inactive.
    """
    # api_key_header is built with auto_error=False, so a missing header arrives
    # here as None (hence Optional[str]) and is rejected explicitly.
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key header missing",
        )

    api_key_obj = APIKeyService.get_api_key_by_key(db=db, key=api_key)
    if not api_key_obj:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )

    if not api_key_obj.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive API key",
        )

    # Apply the same account checks as get_current_user. Without them a
    # deactivated user keeps access through an API key, and a key with no
    # owner would return None despite the ``-> User`` annotation.
    user = api_key_obj.user
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive user",
        )

    # Record usage only once the key has actually authenticated the request.
    APIKeyService.update_last_used(db=db, api_key=api_key_obj)
    return user
|