Files
virtus-bot/infra/db/postgres.py
T

157 lines
5.7 KiB
Python

import os
from typing import Optional
from urllib.parse import quote_plus
from dotenv import load_dotenv
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
# Import Base so its metadata can be used to create the tables
from infra.db.base import Base
# Load variables from a .env file into the process environment so the
# POSTGRES_* settings read below via os.getenv can be supplied from it.
load_dotenv()
def _normalize_asyncpg_url(url: str) -> str:
if url.startswith("postgresql+asyncpg://"):
return url
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://", 1)
if url.startswith("postgres://"):
return url.replace("postgres://", "postgresql+asyncpg://", 1)
return url
class PostgresConnection:
    """Process-wide holder for the async SQLAlchemy engine and session factory.

    Instantiating the class always returns the same object: the first call
    builds the engine from ``POSTGRES_*`` environment variables, later calls
    reuse it.
    """

    _instance: Optional["PostgresConnection"] = None

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish the singleton only after initialisation succeeded, so a
            # failure (e.g. a non-numeric POSTGRES_POOL_SIZE) is not cached as a
            # half-built instance lacking ``engine`` / ``session_maker``.
            instance._initialize()
            cls._instance = instance
        return cls._instance

    def _initialize(self) -> None:
        """Create the async engine and session factory from environment variables.

        ``POSTGRES_URL`` takes precedence; otherwise the URL is assembled from
        ``POSTGRES_HOST``, ``POSTGRES_PORT``, ``POSTGRES_USER``,
        ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``. ``POSTGRES_ECHO``,
        ``POSTGRES_POOL_SIZE`` and ``POSTGRES_MAX_OVERFLOW`` tune the engine.
        """
        from urllib.parse import quote

        raw_url = os.getenv("POSTGRES_URL")
        if not raw_url:
            host = os.getenv("POSTGRES_HOST", "localhost")
            port = os.getenv("POSTGRES_PORT", "5432")
            user = os.getenv("POSTGRES_USER", "postgres")
            password = os.getenv("POSTGRES_PASSWORD", "")
            database = os.getenv("POSTGRES_DB", "postgres")
            # Percent-encode the credentials so characters such as "@", ":",
            # "/" or "%" cannot break the URL. quote(safe="") is used instead of
            # quote_plus: quote_plus turns a space into "+", and plain
            # percent-decoding of the userinfo would keep that "+" literally.
            auth = quote(user, safe="")
            if password:
                auth += ":" + quote(password, safe="")
            raw_url = f"postgresql+asyncpg://{auth}@{host}:{port}/{database}"

        url = _normalize_asyncpg_url(raw_url)
        # Print only what follows the last "@" so credentials never reach stdout.
        print(f"🔌 Connecting to database at: {url.split('@')[-1] if '@' in url else url}")

        self.engine: AsyncEngine = create_async_engine(
            url,
            echo=os.getenv("POSTGRES_ECHO", "false").lower() == "true",
            pool_size=int(os.getenv("POSTGRES_POOL_SIZE", "5")),
            max_overflow=int(os.getenv("POSTGRES_MAX_OVERFLOW", "10")),
            pool_pre_ping=True,
        )
        self.session_maker: async_sessionmaker[AsyncSession] = async_sessionmaker(
            self.engine, expire_on_commit=False
        )

    async def wait_for_connection(self, timeout: int = 60, retry_interval: int = 2) -> bool:
        """Wait until ``SELECT 1`` succeeds, retrying every ``retry_interval`` seconds.

        Returns:
            True once a round-trip to the database succeeded.

        Raises:
            The last connection error, once more than ``timeout`` seconds have
            elapsed since the first attempt.
        """
        import asyncio
        import time

        from sqlalchemy import text

        # monotonic() is unaffected by wall-clock jumps (NTP sync, manual
        # changes) that could otherwise shorten or stretch the wait.
        start_time = time.monotonic()
        while True:
            try:
                async with self.engine.begin() as conn:
                    await conn.execute(text("SELECT 1"))
                print("✅ Database connection established!")
                return True
            except Exception as e:
                if time.monotonic() - start_time > timeout:
                    print(f"❌ Failed to connect to database after {timeout}s: {e}")
                    raise
                print(f"⚠️ Database not ready, retrying in {retry_interval}s...")
                await asyncio.sleep(retry_interval)

    async def create_tables(self) -> None:
        """Create every table registered on ``Base.metadata``."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def verify_and_migrate_schema(self) -> None:
        """Best-effort, re-runnable migration of the schema to multi-server support.

        Steps:
          1. add a ``guild_id BIGINT DEFAULT 0`` column to ``bot_configs``,
             ``home_debt`` and ``score``;
          2. make ``(guild_id, key)`` the primary key of ``bot_configs``;
          3. add ``(guild_id, user_id)`` unique constraints to ``home_debt`` and
             ``score``.

        Each step runs inside its own SAVEPOINT. A failing step is rolled back on
        its own and reported. Without the savepoint, PostgreSQL would abort the
        whole transaction: every later statement would fail, and the final
        commit would discard the steps that had succeeded.
        """
        print("🔄 Verifying database schema...")
        from sqlalchemy import text

        queries = [
            # 1. Add the tenant column first; existing rows get guild_id = 0.
            "ALTER TABLE bot_configs ADD COLUMN IF NOT EXISTS guild_id BIGINT DEFAULT 0;",
            "ALTER TABLE home_debt ADD COLUMN IF NOT EXISTS guild_id BIGINT DEFAULT 0;",
            "ALTER TABLE score ADD COLUMN IF NOT EXISTS guild_id BIGINT DEFAULT 0;",
            # 2. Swap the bot_configs primary key to (guild_id, key) in ONE
            #    statement, so a failed ADD (e.g. duplicate keys) also undoes the
            #    DROP. Skipped once the primary key already includes guild_id.
            #    The existing PK is found by type rather than by an assumed name.
            """
            DO $$
            DECLARE
                pk_name text;
            BEGIN
                IF NOT EXISTS (
                    SELECT 1
                    FROM pg_constraint c
                    JOIN pg_attribute a
                      ON a.attrelid = c.conrelid AND a.attnum = ANY (c.conkey)
                    WHERE c.conrelid = 'bot_configs'::regclass
                      AND c.contype = 'p'
                      AND a.attname = 'guild_id'
                ) THEN
                    SELECT conname INTO pk_name
                    FROM pg_constraint
                    WHERE conrelid = 'bot_configs'::regclass AND contype = 'p';
                    IF pk_name IS NOT NULL THEN
                        EXECUTE 'ALTER TABLE bot_configs DROP CONSTRAINT ' || quote_ident(pk_name);
                    END IF;
                    ALTER TABLE bot_configs ADD PRIMARY KEY (guild_id, key);
                END IF;
            END $$;
            """,
            # 3. Per-guild uniqueness for user rows.
            """
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'uq_home_debt_guild_user') THEN
                    ALTER TABLE home_debt ADD CONSTRAINT uq_home_debt_guild_user UNIQUE (guild_id, user_id);
                END IF;
            END $$;
            """,
            """
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'uq_score_guild_user') THEN
                    ALTER TABLE score ADD CONSTRAINT uq_score_guild_user UNIQUE (guild_id, user_id);
                END IF;
            END $$;
            """,
        ]

        async with self.engine.begin() as conn:
            for step, query in enumerate(queries, start=1):
                try:
                    # SAVEPOINT: a failure here rolls back only this step and
                    # leaves the outer transaction usable for the next ones.
                    async with conn.begin_nested():
                        await conn.execute(text(query))
                except Exception as exc:
                    # Deliberately best-effort: report and keep going rather
                    # than aborting startup.
                    print(f"⚠️ Schema migration step {step} failed and was skipped: {exc}")
        print("✅ Schema verification/migration completed.")

    def get_engine(self) -> AsyncEngine:
        """Return the shared async engine."""
        return self.engine

    def get_sessionmaker(self) -> async_sessionmaker[AsyncSession]:
        """Return the shared ``AsyncSession`` factory (``expire_on_commit=False``)."""
        return self.session_maker
# Shared singleton instance for the rest of the application. Because it is
# created at module level, importing this module reads the POSTGRES_*
# environment variables and builds the engine and session factory.
postgres = PostgresConnection()