from fastapi import APIRouter, Request, Response, status, Depends, HTTPException
from pydantic import EmailStr
from datetime import timedelta
from app import oauth2
from sqlalchemy.orm import Session
from .. import schemas, models, utils
from ..database import get_db
from ..oauth2 import AuthJWT
from ..config import settings
router = APIRouter()

# Token lifetimes read once from settings at import time. The handlers in this
# module treat them as minutes: they build ``timedelta(minutes=...)`` from them
# and multiply by 60 to get cookie max-age in seconds.
ACCESS_TOKEN_EXPIRES_IN, REFRESH_TOKEN_EXPIRES_IN = (
    settings.ACCESS_TOKEN_EXPIRES_IN,
    settings.REFRESH_TOKEN_EXPIRES_IN,
)
# User registration controller.
@router.post('/register', status_code=status.HTTP_201_CREATED, response_model=schemas.UserResponse)
def create_user(payload: schemas.CreateUserSchema, db: Session = Depends(get_db)):
    """Register a new user account and return the created row.

    Declared as a plain ``def`` rather than ``async def``: everything here is
    blocking (synchronous SQLAlchemy session calls and password hashing), so
    FastAPI must run it in its threadpool instead of on the event loop.

    Raises:
        HTTPException(409): an account with this email (compared lowercased)
            already exists.
        HTTPException(400): ``password`` and ``passwordConfirm`` differ.

    Returns:
        The newly persisted ``models.User``, serialized through
        ``schemas.UserResponse`` by the route's ``response_model``.
    """
    # Emails are stored and looked up lowercased so lookups are case-insensitive.
    email = payload.email.lower()

    # Reject duplicate registrations.
    # NOTE(review): this check-then-insert is not atomic; two concurrent requests
    # can both pass it. Only a unique constraint on User.email (not visible here)
    # would prevent a duplicate row — confirm one exists.
    if db.query(models.User).filter(models.User.email == email).first():
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail='Account already exist')

    # Both password fields must match.
    if payload.password != payload.passwordConfirm:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Passwords do not match')

    # Build the row data without mutating the request payload. The confirmation
    # field is dropped, only the password hash is stored, and server-controlled
    # fields are forced.
    # NOTE(review): verified=True means accounts created here never hit the
    # "Please verify your email address" branch in login — confirm intended.
    user_data = payload.dict(exclude={'passwordConfirm'})
    user_data.update(
        password=utils.hash_password(payload.password),
        role='user',
        verified=True,
        email=email,
    )

    new_user = models.User(**user_data)
    db.add(new_user)
    db.commit()
    db.refresh(new_user)  # load DB-generated fields (e.g. id) before returning
    return new_user
# User login controller.
@router.post('/login')
def login(payload: schemas.LoginUserSchema, response: Response, db: Session = Depends(get_db), Authorize: AuthJWT = Depends()):
    """Authenticate a user by email/password and issue JWTs.

    On success, sets these cookies:
      * ``access_token`` (HttpOnly)
      * ``refresh_token`` (HttpOnly)
      * ``logged_in`` (readable by page scripts)

    It then returns the access token in the body as well.

    Raises:
        HTTPException(400): unknown email or wrong password. Both cases use the
            same message so the response does not reveal which one failed.
        HTTPException(401): the credentials are correct but the account is not
            verified.
    """
    email = payload.email.lower()
    user = db.query(models.User).filter(models.User.email == email).first()

    # Verify the password BEFORE looking at the verified flag. Checking
    # `verified` first would let anyone learn that an email is registered but
    # unverified without knowing its password.
    if not user or not utils.verify_password(payload.password, user.password):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail='Incorrect Email or Password')

    if not user.verified:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail='Please verify your email address')

    # Issue the access and refresh tokens; the subject is the user id as a string.
    subject = str(user.id)
    access_token = Authorize.create_access_token(
        subject=subject, expires_time=timedelta(minutes=ACCESS_TOKEN_EXPIRES_IN))
    refresh_jwt = Authorize.create_refresh_token(
        subject=subject, expires_time=timedelta(minutes=REFRESH_TOKEN_EXPIRES_IN))

    # Cookie lifetimes mirror the token lifetimes, converted from minutes to seconds.
    access_max_age = ACCESS_TOKEN_EXPIRES_IN * 60
    refresh_max_age = REFRESH_TOKEN_EXPIRES_IN * 60

    # NOTE(review): secure=False lets these cookies travel over plain HTTP;
    # consider enabling it outside local development.
    response.set_cookie(key='access_token', value=access_token,
                        max_age=access_max_age, expires=access_max_age,
                        path='/', domain=None, secure=False, httponly=True, samesite='lax')
    response.set_cookie(key='refresh_token', value=refresh_jwt,
                        max_age=refresh_max_age, expires=refresh_max_age,
                        path='/', domain=None, secure=False, httponly=True, samesite='lax')
    # Non-HttpOnly flag so frontend code can tell that a session exists.
    response.set_cookie(key='logged_in', value='True',
                        max_age=access_max_age, expires=access_max_age,
                        path='/', domain=None, secure=False, httponly=False, samesite='lax')

    return {'status': 'success', 'access_token': access_token}
# Access-token refresh controller.
@router.get('/refresh')
def refresh_token(response: Response, request: Request, Authorize: AuthJWT = Depends(), db: Session = Depends(get_db)):
    """Issue a new access token from a valid refresh token.

    On success, resets the ``access_token`` (HttpOnly) and ``logged_in``
    cookies and returns the new access token in the body.

    Raises:
        HTTPException(400): the refresh token is missing ("Please provide
            refresh token") or rejected by ``AuthJWT``. For a rejection, the
            detail is the raised exception's class name.
        HTTPException(401): the token has no subject, or its user no longer
            exists.
    """
    # Only the JWT validation is guarded. Previously the `try` also wrapped the
    # 401 raises below, so `except Exception` swallowed them and turned them
    # into a 400 with detail "HTTPException".
    try:
        Authorize.jwt_refresh_token_required()
        user_id = Authorize.get_jwt_subject()
    except Exception as e:
        error = type(e).__name__
        if error == 'MissingTokenError':
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail='Please provide refresh token') from e
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=error) from e

    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail='Could not refresh access token')

    user = db.query(models.User).filter(models.User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail='The user belonging to this token no longer exists')

    access_token = Authorize.create_access_token(
        subject=str(user.id), expires_time=timedelta(minutes=ACCESS_TOKEN_EXPIRES_IN))

    # Cookie lifetime mirrors the token lifetime, converted from minutes to seconds.
    access_max_age = ACCESS_TOKEN_EXPIRES_IN * 60
    response.set_cookie(key='access_token', value=access_token,
                        max_age=access_max_age, expires=access_max_age,
                        path='/', domain=None, secure=False, httponly=True, samesite='lax')
    response.set_cookie(key='logged_in', value='True',
                        max_age=access_max_age, expires=access_max_age,
                        path='/', domain=None, secure=False, httponly=False, samesite='lax')
    return {'access_token': access_token}
# Logout controller; access is guarded by the `require_user` dependency.
@router.get('/logout', status_code=status.HTTP_200_OK)
def logout(response: Response, Authorize: AuthJWT = Depends(), user_id: str = Depends(oauth2.require_user)):
    """Log the caller out by clearing their authentication cookies.

    ``user_id`` is not used in the body. Declaring the
    ``oauth2.require_user`` dependency is what protects the route. Presumably
    it rejects requests without a valid access token; confirm in oauth2.
    """
    Authorize.unset_jwt_cookies()
    # Replace the script-readable `logged_in` flag with an immediately
    # expiring, empty cookie.
    response.set_cookie(key='logged_in', value='', max_age=-1)
    result = {'status': 'success'}
    return result