Refactor code structure for improved readability and maintainability
This commit is contained in:
+107
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db_session
|
||||
|
||||
SESSION_COOKIE_KEYS = [
|
||||
"next-auth.session-token",
|
||||
"__Secure-next-auth.session-token",
|
||||
"authjs.session-token",
|
||||
"__Secure-authjs.session-token",
|
||||
]
|
||||
|
||||
|
||||
def _jwt_secret() -> str:
|
||||
return settings.mobile_jwt_secret or settings.nextauth_secret
|
||||
|
||||
|
||||
def create_access_token(user_id: str) -> str:
|
||||
now = dt.datetime.now(dt.timezone.utc)
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int((now + dt.timedelta(days=7)).timestamp()),
|
||||
}
|
||||
secret = _jwt_secret()
|
||||
if not secret:
|
||||
raise RuntimeError("Missing MOBILE_JWT_SECRET or NEXTAUTH_SECRET")
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
|
||||
async def _get_user_by_id(db: AsyncSession, user_id: str) -> dict[str, Any] | None:
|
||||
result = await db.execute(
|
||||
text(
|
||||
'SELECT id, email, name, image, role FROM "User" WHERE id = :user_id LIMIT 1'
|
||||
),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
row = result.mappings().first()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
async def _get_user_from_session_cookie(db: AsyncSession, request: Request) -> dict[str, Any] | None:
|
||||
token = None
|
||||
for key in SESSION_COOKIE_KEYS:
|
||||
value = request.cookies.get(key)
|
||||
if value:
|
||||
token = value
|
||||
break
|
||||
|
||||
if not token:
|
||||
return None
|
||||
|
||||
result = await db.execute(
|
||||
text(
|
||||
'SELECT u.id, u.email, u.name, u.image, u.role '
|
||||
'FROM "Session" s '
|
||||
'JOIN "User" u ON u.id = s."userId" '
|
||||
'WHERE s."sessionToken" = :token AND s.expires > NOW() '
|
||||
'LIMIT 1'
|
||||
),
|
||||
{"token": token},
|
||||
)
|
||||
row = result.mappings().first()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
async def resolve_current_user(db: AsyncSession, request: Request) -> dict[str, Any] | None:
|
||||
auth = request.headers.get("authorization", "")
|
||||
if auth.lower().startswith("bearer "):
|
||||
token = auth.split(" ", 1)[1].strip()
|
||||
secret = _jwt_secret()
|
||||
if not secret:
|
||||
return None
|
||||
try:
|
||||
payload = jwt.decode(token, secret, algorithms=["HS256"])
|
||||
subject = payload.get("sub")
|
||||
if isinstance(subject, str) and subject:
|
||||
return await _get_user_by_id(db, subject)
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
return await _get_user_from_session_cookie(db, request)
|
||||
|
||||
|
||||
async def require_current_user(db: AsyncSession, request: Request) -> dict[str, Any]:
|
||||
user = await resolve_current_user(db, request)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
return user
|
||||
|
||||
|
||||
async def require_mod_user(request: Request, db: AsyncSession = Depends(get_db_session)) -> dict[str, Any]:
|
||||
user = await resolve_current_user(db, request)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
if user.get("role") not in ("MOD", "ADMIN"):
|
||||
raise HTTPException(status_code=403, detail="Forbidden: MOD or ADMIN role required")
|
||||
return user
|
||||
@@ -0,0 +1,45 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
app_name: str = "reader-api"
|
||||
app_env: str = "development"
|
||||
|
||||
database_url: str
|
||||
mongodb_uri: str
|
||||
|
||||
google_client_id: str = ""
|
||||
nextauth_secret: str = ""
|
||||
mobile_jwt_secret: str = ""
|
||||
|
||||
cors_origins: str = "*"
|
||||
r2_account_id: str = ""
|
||||
r2_access_key_id: str = ""
|
||||
r2_secret_access_key: str = ""
|
||||
r2_bucket_name: str = ""
|
||||
r2_public_base_url: str = ""
|
||||
|
||||
deepseek_key: str = ""
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
openrouter_key: str = ""
|
||||
openrouter_paused: str = "true"
|
||||
|
||||
@property
|
||||
def google_client_id_list(self) -> list[str]:
|
||||
raw = (self.google_client_id or "").strip()
|
||||
if not raw:
|
||||
return []
|
||||
return [item.strip() for item in raw.split(",") if item.strip()]
|
||||
|
||||
|
||||
@property
|
||||
def cors_origin_list(self) -> list[str]:
|
||||
raw = (self.cors_origins or "*").strip()
|
||||
if raw == "*":
|
||||
return ["*"]
|
||||
return [item.strip() for item in raw.split(",") if item.strip()]
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,45 @@
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _normalize_database_url(url: str) -> str:
|
||||
# Strip Prisma-only query params (e.g. ?schema=public) that asyncpg doesn't accept
|
||||
from urllib.parse import urlparse, urlencode, parse_qsl, urlunparse
|
||||
|
||||
_PRISMA_ONLY_PARAMS = {"schema", "connection_limit", "pool_timeout", "connect_timeout", "sslmode"}
|
||||
|
||||
if url.startswith("postgresql+asyncpg://"):
|
||||
scheme = "postgresql+asyncpg"
|
||||
elif url.startswith("postgresql://") or url.startswith("postgres://"):
|
||||
scheme = "postgresql+asyncpg"
|
||||
else:
|
||||
return url
|
||||
|
||||
parsed = urlparse(url)
|
||||
clean_params = [(k, v) for k, v in parse_qsl(parsed.query) if k not in _PRISMA_ONLY_PARAMS]
|
||||
clean_url = urlunparse((
|
||||
scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
urlencode(clean_params),
|
||||
parsed.fragment,
|
||||
))
|
||||
return clean_url
|
||||
|
||||
|
||||
engine = create_async_engine(_normalize_database_url(settings.database_url), pool_pre_ping=True)
|
||||
SessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
mongo_client = AsyncIOMotorClient(settings.mongodb_uri)
|
||||
mongo_db = mongo_client.get_default_database()
|
||||
|
||||
|
||||
async def get_db_session() -> AsyncSession:
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
+1202
File diff suppressed because it is too large
Load Diff
+2417
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user