157 lines
5.7 KiB
Python
157 lines
5.7 KiB
Python
import os
|
|
from typing import Optional
|
|
from urllib.parse import quote_plus
|
|
|
|
from dotenv import load_dotenv
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
|
|
# Import Base so its metadata can be used to create tables
|
|
from infra.db.base import Base
|
|
|
|
|
|
# Populate the process environment from a .env file (python-dotenv) so the
# POSTGRES_* settings read by PostgresConnection are available when the
# module-level singleton is created at import time.
load_dotenv()
|
|
|
|
|
|
def _normalize_asyncpg_url(url: str) -> str:
|
|
if url.startswith("postgresql+asyncpg://"):
|
|
return url
|
|
if url.startswith("postgresql://"):
|
|
return url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
|
if url.startswith("postgres://"):
|
|
return url.replace("postgres://", "postgresql+asyncpg://", 1)
|
|
return url
|
|
|
|
|
|
class PostgresConnection:
    """Process-wide holder for the async SQLAlchemy engine and session factory.

    Calling ``PostgresConnection()`` always returns the same object: the first
    call builds the engine from environment variables and later calls reuse it.
    """

    _instance: Optional["PostgresConnection"] = None

    def __new__(cls):
        if cls._instance is None:
            instance = super(PostgresConnection, cls).__new__(cls)
            # Publish the instance only after initialization succeeded, so a
            # failure (e.g. a non-integer POSTGRES_POOL_SIZE) cannot leave a
            # half-built singleton without an engine for later callers.
            instance._initialize()
            cls._instance = instance
        return cls._instance

    def _initialize(self) -> None:
        """Build the async engine and session factory from environment variables.

        POSTGRES_URL takes precedence. Otherwise the URL is assembled from
        POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD and
        POSTGRES_DB. POSTGRES_ECHO, POSTGRES_POOL_SIZE and
        POSTGRES_MAX_OVERFLOW tune the engine.

        Raises:
            ValueError: if POSTGRES_POOL_SIZE or POSTGRES_MAX_OVERFLOW is not
                an integer.
        """
        from urllib.parse import quote

        # Prefer POSTGRES_URL; fall back to the individual variables.
        raw_url = os.getenv("POSTGRES_URL")
        if not raw_url:
            host = os.getenv("POSTGRES_HOST", "localhost")
            port = os.getenv("POSTGRES_PORT", "5432")
            user = os.getenv("POSTGRES_USER", "postgres")
            password = os.getenv("POSTGRES_PASSWORD", "")
            database = os.getenv("POSTGRES_DB", "postgres")

            # Percent-encode the credentials so characters such as @, /, : and
            # spaces survive URL parsing. quote(..., safe="") is used rather
            # than quote_plus: quote_plus turns a space into "+", which
            # SQLAlchemy's URL parser (urllib unquote) reads back as a literal
            # "+", corrupting passwords that contain spaces.
            auth = quote(user, safe="")
            if password:
                auth += ":" + quote(password, safe="")
            raw_url = f"postgresql+asyncpg://{auth}@{host}:{port}/{database}"

        url = _normalize_asyncpg_url(raw_url)
        # Log only the part after the last "@" so credentials are not printed.
        print(f"🔌 Connecting to database at: {url.split('@')[-1] if '@' in url else url}")

        self.engine: AsyncEngine = create_async_engine(
            url,
            echo=os.getenv("POSTGRES_ECHO", "false").lower() == "true",
            pool_size=int(os.getenv("POSTGRES_POOL_SIZE", "5")),
            max_overflow=int(os.getenv("POSTGRES_MAX_OVERFLOW", "10")),
            pool_pre_ping=True,
        )
        self.session_maker: async_sessionmaker[AsyncSession] = async_sessionmaker(
            self.engine, expire_on_commit=False
        )

    async def wait_for_connection(self, timeout: int = 60, retry_interval: int = 2) -> bool:
        """Wait until the database answers ``SELECT 1``.

        Args:
            timeout: Seconds after the first attempt at which to give up.
            retry_interval: Seconds to sleep between failed attempts.

        Returns:
            True once a connection attempt succeeds.

        Raises:
            Exception: the last connection error, re-raised once more than
                *timeout* seconds have elapsed.
        """
        import asyncio
        import time

        from sqlalchemy import text

        # monotonic() is unaffected by wall-clock adjustments (NTP sync,
        # manual changes) that could otherwise shorten or stretch the wait.
        start_time = time.monotonic()
        while True:
            try:
                async with self.engine.begin() as conn:
                    await conn.execute(text("SELECT 1"))
                print("✅ Database connection established!")
                return True
            except Exception as e:
                if time.monotonic() - start_time > timeout:
                    print(f"❌ Failed to connect to database after {timeout}s: {e}")
                    raise
                print(f"⚠️ Database not ready, retrying in {retry_interval}s...")
                await asyncio.sleep(retry_interval)

    async def create_tables(self) -> None:
        """Create every table registered on ``Base.metadata`` that does not exist yet."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def verify_and_migrate_schema(self) -> None:
        """Apply schema changes for multi-server (per-guild) support.

        The changes are:
        - add a ``guild_id`` column to ``bot_configs``, ``home_debt`` and
          ``score``;
        - rebuild the ``bot_configs`` primary key as ``(guild_id, key)``;
        - add unique constraints on ``(guild_id, user_id)`` for ``home_debt``
          and ``score``.

        Each step runs in its own SAVEPOINT inside one outer transaction. In
        PostgreSQL a failed statement aborts the whole transaction, so one
        failing step (e.g. a table that does not exist) would otherwise make
        every later statement fail and roll back all of the changes. Here a
        failing step is rolled back on its own and reported, and the remaining
        steps still apply. Failures are logged, not raised (best-effort).
        """
        from sqlalchemy import text

        print("🔄 Verifying database schema...")

        # Each inner tuple is applied atomically (one savepoint).
        steps = [
            # 1. Add the guild_id columns first.
            ("ALTER TABLE bot_configs ADD COLUMN IF NOT EXISTS guild_id BIGINT DEFAULT 0;",),
            ("ALTER TABLE home_debt ADD COLUMN IF NOT EXISTS guild_id BIGINT DEFAULT 0;",),
            ("ALTER TABLE score ADD COLUMN IF NOT EXISTS guild_id BIGINT DEFAULT 0;",),
            # 2. Rebuild the bot_configs primary key. Drop and re-add share one
            #    savepoint: if re-adding fails (e.g. duplicate (guild_id, key)
            #    rows), the drop is rolled back too, so the table is never
            #    left without a primary key.
            (
                """
                DO $$
                BEGIN
                    IF EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'bot_configs_pkey') THEN
                        ALTER TABLE bot_configs DROP CONSTRAINT bot_configs_pkey;
                    END IF;
                END $$;
                """,
                "ALTER TABLE bot_configs ADD PRIMARY KEY (guild_id, key);",
            ),
            # 3. Per-guild unique constraints.
            (
                """
                DO $$
                BEGIN
                    IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'uq_home_debt_guild_user') THEN
                        ALTER TABLE home_debt ADD CONSTRAINT uq_home_debt_guild_user UNIQUE (guild_id, user_id);
                    END IF;
                END $$;
                """,
            ),
            (
                """
                DO $$
                BEGIN
                    IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'uq_score_guild_user') THEN
                        ALTER TABLE score ADD CONSTRAINT uq_score_guild_user UNIQUE (guild_id, user_id);
                    END IF;
                END $$;
                """,
            ),
        ]

        failed = 0
        async with self.engine.begin() as conn:
            for step in steps:
                try:
                    # Leaving this block via an exception rolls back to the
                    # savepoint, which clears PostgreSQL's aborted state so
                    # the outer transaction stays usable.
                    async with conn.begin_nested():
                        for statement in step:
                            await conn.execute(text(statement))
                except Exception as e:
                    failed += 1
                    print(f"⚠️ Schema migration step skipped: {e}")

        if failed:
            print(f"⚠️ Schema verification/migration completed with {failed} skipped step(s).")
        else:
            print("✅ Schema verification/migration completed.")

    def get_engine(self) -> AsyncEngine:
        """Return the shared async engine."""
        return self.engine

    def get_sessionmaker(self) -> async_sessionmaker[AsyncSession]:
        """Return the shared ``AsyncSession`` factory (``expire_on_commit=False``)."""
        return self.session_maker
|
|
|
|
|
|
# Shared singleton instance for the rest of the application to import.
# Creating it here means importing this module reads the POSTGRES_* environment
# variables and builds the engine (see PostgresConnection._initialize).
postgres: PostgresConnection = PostgresConnection()
|