233 lines
6.5 KiB
Python
233 lines
6.5 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import click
|
|
|
|
if sys.version_info >= (3, 11):
|
|
import tomllib
|
|
else:
|
|
import tomli as tomllib
|
|
|
|
from beanie.migrations import template
|
|
from beanie.migrations.database import DBHandler
|
|
from beanie.migrations.models import RunningDirections, RunningMode
|
|
from beanie.migrations.runner import MigrationNode
|
|
|
|
# Configure root logging at import time: INFO level, message text only
# (no level/timestamp prefix).
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
|
|
|
|
|
class MigrationSettings:
    """Settings for one migration run, merged from several sources.

    For every setting except ``use_transaction`` the first truthy value wins,
    looked up in this order:

    1. the keyword argument passed to the constructor;
    2. an environment variable (see :meth:`get_env_value`);
    3. ``[tool.beanie.migrations]`` in ``./pyproject.toml``
       (see :meth:`get_from_toml`);
    4. a built-in default, where one exists.

    Falsy values (``None``, ``""``, ``0``, ``False``) at one level fall
    through to the next. ``use_transaction`` is read from the keyword
    argument only and defaults to ``False``.

    Raises:
        ValueError: if ``connection_uri``, ``database_name`` or ``path`` is
            not supplied by any source.
    """

    # Environment-variable suffixes (after ``BEANIE_`` / ``beanie_``) tried,
    # in order, for settings that accept several spellings.
    _ENV_ALIASES = {
        "connection_uri": (
            "URI",
            "CONNECTION_URI",
            "CONNECTION_STRING",
            "MONGODB_DSN",
            "MONGODB_URI",
        ),
        "database_name": ("DB", "DB_NAME", "DATABASE_NAME"),
    }

    # Case-insensitive spellings read as True when a boolean setting arrives
    # as text (e.g. from an environment variable); any other string is False.
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})

    def __init__(self, **kwargs: Any):
        self.direction = (
            self._resolve(kwargs, "direction") or RunningDirections.FORWARD
        )
        self.distance = int(self._resolve(kwargs, "distance") or 0)
        # These are mandatory: str(None) would otherwise produce the literal
        # string "None" (e.g. a database actually named "None").
        self.connection_uri = str(self._require(kwargs, "connection_uri"))
        self.database_name = str(self._require(kwargs, "database_name"))
        self.path = Path(self._require(kwargs, "path"))
        # Parsed instead of bool()-ed so that text such as "false" or "0"
        # (e.g. BEANIE_ALLOW_INDEX_DROPPING=false) does NOT enable dropping.
        self.allow_index_dropping = self._to_bool(
            self._resolve(kwargs, "allow_index_dropping")
        )
        self.use_transaction = bool(kwargs.get("use_transaction"))

    def _resolve(self, kwargs: dict, field_name: str) -> Any:
        """Return the first truthy value from kwargs, env vars, pyproject.toml."""
        return (
            kwargs.get(field_name)
            or self.get_env_value(field_name)
            or self.get_from_toml(field_name)
        )

    def _require(self, kwargs: dict, field_name: str) -> Any:
        """Like ``_resolve`` but raise ``ValueError`` when nothing is set."""
        value = self._resolve(kwargs, field_name)
        if not value:
            raise ValueError(
                f"Migration setting '{field_name}' is not configured: pass it "
                f"explicitly, set a BEANIE_* environment variable, or add it "
                f"to [tool.beanie.migrations] in pyproject.toml"
            )
        return value

    @classmethod
    def _to_bool(cls, value: Any) -> bool:
        """Interpret *value* as a boolean, parsing strings like "true"/"0"."""
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUE_STRINGS
        return bool(value)

    @staticmethod
    def get_env_value(field_name: str) -> Any:
        """Return the environment-variable value for *field_name*, if any.

        All upper-case ``BEANIE_*`` names are tried before the lower-case
        ``beanie_*`` ones. ``connection_uri`` and ``database_name`` accept
        the aliases listed in ``_ENV_ALIASES``; any other field maps to
        ``BEANIE_<FIELD_NAME>`` / ``beanie_<field_name>``.

        Returns:
            The first non-empty value found, otherwise the result of the last
            lookup (``None``, or ``""`` if that variable is set but empty).
        """
        suffixes = MigrationSettings._ENV_ALIASES.get(
            field_name, (field_name.upper(),)
        )
        env_names = [f"BEANIE_{suffix}" for suffix in suffixes] + [
            f"beanie_{suffix.lower()}" for suffix in suffixes
        ]
        value = None
        for env_name in env_names:
            value = os.environ.get(env_name)
            if value:
                break
        return value

    @staticmethod
    def get_from_toml(field_name: str) -> Any:
        """Return ``[tool.beanie.migrations].<field_name>`` from pyproject.toml.

        The file is resolved relative to the current working directory and is
        re-read on every call. Returns ``None`` when the file, the table, or
        the key is missing.
        """
        path = Path("pyproject.toml")
        if not path.is_file():
            return None
        with path.open("rb") as f:
            toml_data = tomllib.load(f)
        section = (
            toml_data.get("tool", {})
            .get("beanie", {})
            .get("migrations", {})
        )
        return section.get(field_name)
|
|
|
|
|
|
@click.group()
def migrations():
    """Manage Beanie database migrations."""
|
|
|
|
|
|
async def run_migrate(settings: MigrationSettings):
    """Run the migrations described by *settings*.

    Registers the connection with ``DBHandler``, builds the migration chain
    from ``settings.path`` and runs it with the configured direction,
    distance, index-dropping and transaction options. The client returned by
    ``DBHandler.get_cli()`` is closed afterwards, also when building or
    running the migrations raises.
    """
    DBHandler.set_db(settings.connection_uri, settings.database_name)
    try:
        root = await MigrationNode.build(settings.path)
        mode = RunningMode(
            direction=settings.direction, distance=settings.distance
        )
        await root.run(
            mode=mode,
            allow_index_dropping=settings.allow_index_dropping,
            use_transaction=settings.use_transaction,
        )
    finally:
        # Cleanup runs on failure too, so a failed migration does not leave
        # the client open.
        client = DBHandler.get_cli()
        if client:
            await client.close()
|
|
|
|
|
|
@migrations.command()
@click.option(
    "--forward",
    "direction",
    required=False,
    flag_value="FORWARD",
    help="Roll the migrations forward. This is default",
)
@click.option(
    "--backward",
    "direction",
    required=False,
    flag_value="BACKWARD",
    help="Roll the migrations backward",
)
@click.option(
    "-d",
    "--distance",
    required=False,
    help="How many migrations should be done since the current? "
    "0 - all the migrations. Default is 0",
)
@click.option(
    "-uri",
    "--connection-uri",
    required=False,
    type=str,
    help="MongoDB connection URI",
)
@click.option(
    "-db", "--database_name", required=False, type=str, help="DataBase name"
)
@click.option(
    "-p",
    "--path",
    required=False,
    type=str,
    help="Path to the migrations directory",
)
@click.option(
    "--allow-index-dropping/--forbid-index-dropping",
    required=False,
    # None (not False) so the command can tell "flag not given" apart from an
    # explicit --forbid-index-dropping.
    default=None,
    help="if allow-index-dropping is set, Beanie will drop indexes from your collection",
)
@click.option(
    "--use-transaction/--no-use-transaction",
    required=False,
    default=True,
    help="Enable or disable the use of transactions during migration. "
    "When enabled (--use-transaction), Beanie uses transactions for migration, "
    "which necessitates a replica set. When disabled (--no-use-transaction), "
    "migrations occur without transactions.",
)
def migrate(
    direction,
    distance,
    connection_uri,
    database_name,
    path,
    allow_index_dropping,
    use_transaction,
):
    """Run the migrations from the migrations directory.

    Options not given on the command line are read from BEANIE_* environment
    variables, then from [tool.beanie.migrations] in pyproject.toml.
    """
    given = {
        "direction": direction,
        "distance": distance,
        "connection_uri": connection_uri,
        "database_name": database_name,
        "path": path,
    }
    # Forward only options the user actually supplied, so MigrationSettings
    # can fall back to environment variables / pyproject.toml for the rest.
    settings_kwargs = {key: value for key, value in given.items() if value}
    if allow_index_dropping:
        settings_kwargs["allow_index_dropping"] = allow_index_dropping
    settings_kwargs["use_transaction"] = use_transaction
    settings = MigrationSettings(**settings_kwargs)

    # MigrationSettings treats False as "unset" and would fall back to env /
    # pyproject.toml, so an explicit --forbid-index-dropping is enforced here.
    # NOTE(review): relies on click passing None for an unused boolean flag
    # declared with default=None — confirm for the pinned click version.
    if allow_index_dropping is False:
        settings.allow_index_dropping = False

    asyncio.run(run_migrate(settings))
|
|
|
|
|
|
@migrations.command()
@click.option("-n", "--name", required=True, type=str, help="Migration name")
@click.option(
    "-p",
    "--path",
    required=True,
    type=str,
    help="Path to the migrations directory",
)
def new_migration(name, path):
    """Create a new migration file from the Beanie migration template.

    The file is named <YYYYmmddHHMMSS>_<name>.py (local time) and is written
    into the --path directory, which is created if it does not exist yet.
    """
    migrations_dir = Path(path)
    # Create the directory on first use instead of letting shutil.copy fail
    # with FileNotFoundError.
    migrations_dir.mkdir(parents=True, exist_ok=True)
    ts_string = datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"{ts_string}_{name}.py"

    shutil.copy(template.__file__, migrations_dir / file_name)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when the module is executed directly (not imported): hand control
    # to the click command group.
    migrations()
|