本文以Python为例:创建函数时,运行环境请选择Python 3.9。示例代码如下(您可以根据实际应用情况进行修改):
# -*- coding: utf-8 -*-
import json
import logging
import os

try:
    import pymysql
except ImportError:
    # pymysql could not be imported: install it into the working directory and
    # retry. Using sys.executable guarantees pip targets the interpreter that is
    # actually running this code, and check_call fails loudly if pip fails.
    # NOTE(review): the retry only succeeds if './' is on sys.path — confirm for
    # the deployment layout.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'pymysql', '-t', './'])
    import pymysql
from aliyunsdkcore.acs_exception.exceptions import ServerException
from aliyunsdkcore.auth.credentials import StsTokenCredential
from aliyunsdkcore.client import AcsClient
from aliyunsdkkms.request.v20160120.GetRandomPasswordRequest import GetRandomPasswordRequest
from aliyunsdkkms.request.v20160120.GetSecretValueRequest import GetSecretValueRequest
from aliyunsdkkms.request.v20160120.PutSecretValueRequest import PutSecretValueRequest
from aliyunsdkkms.request.v20160120.UpdateSecretVersionStageRequest import UpdateSecretVersionStageRequest
from aliyunsdkrds.request.v20140815.DescribeDBInstancesRequest import DescribeDBInstancesRequest

# Root logger at INFO so the rotation steps' progress messages are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
def handler(event, context):
    """Function Compute entry point: run one step of KMS secret rotation.

    Args:
        event: JSON text containing "SecretName", "RegionId", "Step" and an
            optional "VersionId".
        context: Function Compute context. Its STS credentials sign the API
            calls, and its requestId is used as the version id when the event
            does not carry one.

    Returns:
        dict: ``{"VersionId": <version id used for this step>}``.

    Raises:
        ValueError: if the secret's SecretType is not "Generic", or Step is not
            one of "new", "set", "test", "end".
    """
    payload = json.loads(event)
    secret_name = payload['SecretName']
    region_id = payload['RegionId']
    step = payload['Step']
    # Fall back to this invocation's request id when no version id is supplied.
    version_id = payload.get('VersionId') or context.requestId

    sts = context.credentials
    credentials = StsTokenCredential(sts.accessKeyId, sts.accessKeySecret, sts.securityToken)
    client = AcsClient(region_id=region_id, credential=credentials)
    # Route KMS calls to the region's VPC endpoint.
    client.add_endpoint(region_id, 'kms', "kms-vpc." + region_id + ".aliyuncs.com")

    if "Generic" != get_secret_value(client, secret_name)['SecretType']:
        logger.error("Secret %s is not enabled for rotation" % secret_name)
        raise ValueError("Secret %s is not enabled for rotation" % secret_name)

    phases = (("new", new_phase), ("set", set_phase), ("test", test_phase), ("end", end_phase))
    for phase_name, phase_func in phases:
        if step == phase_name:
            phase_func(client, secret_name, version_id)
            break
    else:
        logger.error("handler: Invalid step parameter %s for secret %s" % (step, secret_name))
        raise ValueError("Invalid step parameter %s for secret %s" % (step, secret_name))
    return {"VersionId": version_id}
def new_phase(client, secret_name, version_id):
    """Create the ACSPending version of the secret unless it already exists.

    The new value is a copy of ACSCurrent with two changes:
    - AccountName is switched to the paired account (get_alt_account_name).
    - AccountPassword is replaced by a KMS-generated random password.

    Characters listed in the EXCLUDE_CHARACTERS environment variable are
    excluded from the password. The default set is slash, at-sign, double
    quote, single quote and backslash.

    Raises:
        ServerException: for any KMS error other than
            'Forbidden.ResourceNotFound' while reading ACSPending.
    """
    current_dict = get_secret_dict(client, secret_name, "ACSCurrent")
    try:
        get_secret_dict(client, secret_name, "ACSPending", version_id)
    except ServerException as err:
        # Only "resource not found" means the pending version still has to be created.
        if err.error_code != 'Forbidden.ResourceNotFound':
            raise
    else:
        logger.info("new: Successfully retrieved secret for %s." % secret_name)
        return

    current_dict['AccountName'] = get_alt_account_name(current_dict['AccountName'])
    excluded = os.environ.get('EXCLUDE_CHARACTERS', '/@"\'\\')
    current_dict['AccountPassword'] = get_random_password(client, excluded)['RandomPassword']
    put_secret_value(client, secret_name, version_id, json.dumps(current_dict),
                     json.dumps(['ACSPending']))
    logger.info(
        "new: Successfully put secret for secret_name %s and version %s." % (secret_name, version_id))
def set_phase(client, secret_name, version_id):
    """Make the ACSPending credentials valid in the MySQL database.

    Steps:
      1. If the pending credentials can already log in, return (a retry after
         a successful earlier run).
      2. Check that the pending secret targets the paired account
         (get_alt_account_name) on the same Endpoint as ACSCurrent.
      3. Check that the ACSCurrent credentials still log in.
      4. Load the master secret named by ACSCurrent's 'MasterSecret'. Require
         that its Endpoint equals the current one, or that the current
         instance is an RDS replica of the master.
      5. Over a master connection:
         - create the pending user if it is absent;
         - copy the current user's grants and TLS requirement to it;
         - set its password and commit.

    Args:
        client: AcsClient used for KMS reads and the RDS replica lookup.
        secret_name: Name of the secret being rotated.
        version_id: Version id holding the ACSPending value.

    Raises:
        ValueError: if a validation fails or a database cannot be reached.
        KeyError: if the current secret has no 'MasterSecret' field.
    """
    current_dict = get_secret_dict(client, secret_name, "ACSCurrent")
    pending_dict = get_secret_dict(client, secret_name, "ACSPending", version_id)

    # Pending credentials that already work mean this step has completed before.
    conn = get_connection(pending_dict)
    if conn:
        conn.close()
        logger.info(
            "set: ACSPending secret is already set as password in MySQL DB for secret secret_name %s." % secret_name)
        return

    # Only the paired account on the same Endpoint may be modified.
    if get_alt_account_name(current_dict['AccountName']) != pending_dict['AccountName']:
        _set_phase_fail("Attempting to modify user %s other than current user or rotation %s" % (
            pending_dict['AccountName'], current_dict['AccountName']))
    if current_dict['Endpoint'] != pending_dict['Endpoint']:
        _set_phase_fail("Attempting to modify user for Endpoint %s other than current Endpoint %s" % (
            pending_dict['Endpoint'], current_dict['Endpoint']))

    # The current credentials must still work before anything is changed.
    conn = get_connection(current_dict)
    if not conn:
        _set_phase_fail("Unable to access the given database using current credentials for secret %s" % secret_name)
    conn.close()

    if 'MasterSecret' not in current_dict:
        raise KeyError("MasterSecret key is missing from secret JSON")
    master_secret = current_dict['MasterSecret']
    master_dict = get_secret_dict(client, master_secret, "ACSCurrent")
    # is_rds_replica_database takes the client as its first argument; it was
    # previously called without it, which raised TypeError whenever the
    # endpoints differed.
    if (current_dict['Endpoint'] != master_dict['Endpoint']
            and not is_rds_replica_database(client, current_dict, master_dict)):
        _set_phase_fail("Current database Endpoint %s is not the same Endpoint as/rds replica of master %s" % (
            current_dict['Endpoint'], master_dict['Endpoint']))

    conn = get_connection(master_dict)
    if not conn:
        _set_phase_fail(
            "Unable to access the given database using credentials in master secret secret %s" % master_secret)
    try:
        _sync_pending_user(conn, current_dict, pending_dict)
        conn.commit()
        logger.info("set: Successfully changed password for %s in MySQL DB for secret secret_name %s." % (
            pending_dict['AccountName'], secret_name))
    finally:
        conn.close()


def _set_phase_fail(message):
    """Log ``message`` with the "set: " prefix, then raise it as a ValueError."""
    logger.error("set: " + message)
    raise ValueError(message)


def _sync_pending_user(conn, current_dict, pending_dict):
    """Make the pending MySQL user mirror the current user, with the new password.

    Creates the pending user if absent, then copies the current user's grants
    and TLS requirement and sets its password. Runs on a connection opened
    with the master credentials; the caller commits and closes it.
    """
    pending_user = pending_dict['AccountName']
    pending_password = pending_dict['AccountPassword']
    current_user = current_dict['AccountName']
    with conn.cursor() as cur:
        cur.execute("SELECT User FROM mysql.user WHERE User = %s", pending_user)
        if cur.rowcount == 0:
            cur.execute("CREATE USER %s IDENTIFIED BY %s", (pending_user, pending_password))

        # Re-issue each of the current user's grants for the pending user.
        cur.execute("SHOW GRANTS FOR %s", current_user)
        for row in cur.fetchall():
            # NOTE(review): XA_RECOVER_ADMIN grants are skipped, presumably
            # because the master account cannot re-grant it — confirm.
            if 'XA_RECOVER_ADMIN' in row[0]:
                continue
            # Drop the original grantee: keep the text before ' TO '.
            grant = row[0].split(' TO ')
            # pymysql interpolates args with Python %-formatting, so literal '%'
            # characters (e.g. a '%' host or database pattern) must be doubled.
            new_grant_escaped = grant[0].replace('%', '%%')
            cur.execute(new_grant_escaped + " TO %s ", (pending_user,))

        cur.execute("SELECT VERSION()")
        ver = cur.fetchone()[0]

        # Copy the current user's REQUIRE (TLS) setting to the pending user.
        escaped_encryption_statement = get_escaped_encryption_statement(ver)
        cur.execute("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s",
                    current_user)
        tls_options = cur.fetchone()
        ssl_type = tls_options[0]
        if not ssl_type:
            cur.execute(escaped_encryption_statement + " NONE", pending_user)
        elif ssl_type == "ANY":
            cur.execute(escaped_encryption_statement + " SSL", pending_user)
        elif ssl_type == "X509":
            cur.execute(escaped_encryption_statement + " X509", pending_user)
        else:
            cur.execute(escaped_encryption_statement + " CIPHER %s AND ISSUER %s AND SUBJECT %s",
                        (pending_user, tls_options[1], tls_options[2], tls_options[3]))

        password_option = get_password_option(ver)
        cur.execute("SET PASSWORD FOR %s = " + password_option, (pending_user, pending_password))
def test_phase(client, secret_name, version_id):
    """Check that the ACSPending credentials can log in and run a trivial query.

    Raises:
        ValueError: if no connection can be opened with the pending secret.
    """
    pending_dict = get_secret_dict(client, secret_name, "ACSPending", version_id)
    conn = get_connection(pending_dict)
    if not conn:
        logger.error(
            "test: Unable to access the given database with pending secret of secret secret_name %s" % secret_name)
        raise ValueError("Unable to access the given database with pending secret of secret secret_name %s" % secret_name)

    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT NOW()")
            conn.commit()
    finally:
        conn.close()
    logger.info("test: Successfully accessed into MySQL DB with ACSPending secret in %s." % secret_name)
def end_phase(client, secret_name, version_id):
    """Finish rotation: move ACSCurrent to version_id, then remove its ACSPending label."""
    stage_moves = (
        ('ACSCurrent', {'move_to_version': version_id}),
        ('ACSPending', {'remove_from_version': version_id}),
    )
    for stage, target in stage_moves:
        update_secret_version_stage(client, secret_name, stage, **target)
    logger.info(
        "end: Successfully update ACSCurrent stage to version %s for secret %s." % (version_id, secret_name))
def get_connection(secret_dict):
    """Open a MySQL connection for the credentials in ``secret_dict``.

    'Port' defaults to 3306 and 'DBName' to None. The TLS choice comes from
    get_ssl_config. When that allows a fallback and the first attempt fails,
    the connection is retried without TLS.

    Returns:
        The connection, or None if every attempt failed.
    """
    port = int(secret_dict.get('Port', 3306))
    dbname = secret_dict.get('DBName')
    use_ssl, fall_back = get_ssl_config(secret_dict)
    conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
    if not conn and fall_back:
        conn = connect_and_authenticate(secret_dict, port, dbname, False)
    return conn
def get_ssl_config(secret_dict):
    """Return ``(use_ssl, fall_back)`` derived from the optional 'SSL' field.

    - A bool, or the string "true"/"false" in any case, is used as-is, with no
      fallback.
    - A missing field or any other value gives ``(True, True)``: try TLS first,
      and allow a retry without it.
    """
    setting = secret_dict.get('SSL')
    if isinstance(setting, bool):
        return setting, False
    if isinstance(setting, str):
        explicit = {"true": (True, False), "false": (False, False)}
        return explicit.get(setting.lower(), (True, True))
    return True, True
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
    """Try a single pymysql connection with the credentials in ``secret_dict``.

    Args:
        secret_dict: Must contain 'Endpoint', 'AccountName' and 'AccountPassword'.
        port: TCP port of the MySQL server.
        dbname: Database to select, or None.
        use_ssl: When true, connect with TLS using the CA bundle below.

    Returns:
        The open connection, or None if pymysql raised OperationalError (for
        example bad credentials, an unreachable host or a TLS failure).
    """
    # NOTE(review): the CA bundle path must exist in the runtime image — confirm.
    ssl = {'ca': '/opt/python/certs/cert.pem'} if use_ssl else None
    try:
        conn = pymysql.connect(host=secret_dict['Endpoint'], user=secret_dict['AccountName'],
                               password=secret_dict['AccountPassword'],
                               port=port, database=dbname, connect_timeout=5, ssl=ssl)
    except pymysql.OperationalError as e:
        # Inspect str(e) rather than e.args[1]: not every OperationalError carries
        # an (errno, message) pair, and indexing args could itself raise here.
        if 'certificate verify failed: IP address mismatch' in str(e):
            logger.error(
                "Hostname verification failed when establishing SSL/TLS Handshake with Endpoint: %s" % secret_dict[
                    'Endpoint'])
        return None
    logger.info("Successfully established %s connection as user '%s' with Endpoint: '%s'" % (
        "SSL/TLS" if use_ssl else "non SSL/TLS", secret_dict['AccountName'], secret_dict['Endpoint']))
    return conn
def get_secret_dict(client, secret_name, stage, version_id=None):
    """Fetch a secret version and return its SecretData parsed as JSON.

    Args:
        client: AcsClient for KMS.
        secret_name: Name of the secret.
        stage: Version stage to read, e.g. "ACSCurrent" or "ACSPending".
        version_id: Optional version id; only sent when truthy.

    Raises:
        KeyError: if 'Endpoint', 'AccountName' or 'AccountPassword' is absent
            (the first missing one is reported).
    """
    secret = get_secret_value(client, secret_name, version_id=version_id or None, stage=stage)
    secret_dict = json.loads(secret['SecretData'])
    missing = next((key for key in ('Endpoint', 'AccountName', 'AccountPassword') if key not in secret_dict), None)
    if missing is not None:
        raise KeyError("%s key is missing from secret JSON" % missing)
    return secret_dict
def get_alt_account_name(current_account_name):
    """Return the paired account name that rotation alternates with.

    A name ending in "_rt" maps back to the name without the suffix; any other
    name gets "_rt" appended.

    Raises:
        ValueError: if appending the suffix would exceed 16 characters.
    """
    rotation_suffix = "_rt"
    # NOTE(review): 16 presumably matches MySQL 5.6's user-name length limit.
    max_length = 16
    if current_account_name.endswith(rotation_suffix):
        return current_account_name[:-len(rotation_suffix)]
    new_account_name = current_account_name + rotation_suffix
    if len(new_account_name) > max_length:
        # The message previously claimed "_rotation" is appended; the suffix is "_rt".
        raise ValueError(
            "Unable to rotate user, account_name length with %s appended would exceed %d characters"
            % (rotation_suffix, max_length))
    return new_account_name
def get_password_option(version):
    """Return the SQL placeholder for the new password in SET PASSWORD.

    Version strings starting with "8" get a bare placeholder; all others wrap
    it in PASSWORD(), which MySQL 8.0 no longer provides.
    """
    return "%s" if version.startswith("8") else "PASSWORD(%s)"
def get_escaped_encryption_statement(version):
    """Return the statement prefix that sets a user's REQUIRE (TLS) clause.

    The '%%' is pymysql-escaped and becomes the '%' host after parameter
    substitution.

    Returns:
        "GRANT USAGE ... REQUIRE" for version strings starting with "5.6", and
        "ALTER USER ... REQUIRE" for every other version.
    """
    if not version.startswith("5.6"):
        return "ALTER USER %s@'%%' REQUIRE"
    return "GRANT USAGE ON *.* TO %s@'%%' REQUIRE"
def is_rds_replica_database(client, replica_dict, master_dict):
    """Tell whether replica_dict's Endpoint is the master RDS instance or one of its replicas.

    Instance ids are taken from the first DNS label of each Endpoint. The
    replica's id is looked up with DescribeDBInstances. The result is True when
    the described instance either is the master itself, or reports the master
    as its MasterInstanceId (a read-only replica).

    Returns:
        bool: False on any lookup error or when no instance is found.
    """
    replica_instance_id = _endpoint_to_instance_id(replica_dict['Endpoint'])
    master_instance_id = _endpoint_to_instance_id(master_dict['Endpoint'])
    try:
        describe_response = describe_db_instances(client, replica_instance_id)
    except Exception as err:
        # Best effort: a failed lookup is treated as "not a replica".
        logger.warning("Encountered error while verifying rds replica status: %s" % err)
        return False
    instances = (describe_response.get('Items') or {}).get("DBInstance")
    if not instances:
        logger.info("Cannot verify replica status - no RDS instance found with identifier: %s" % replica_instance_id)
        return False
    current_instance = instances[0]
    # Previously only DBInstanceId (the replica's own id) was compared, so a
    # genuine read-only replica never matched; MasterInstanceId names its primary.
    return master_instance_id in (current_instance.get('DBInstanceId'),
                                  current_instance.get('MasterInstanceId'))


def _endpoint_to_instance_id(endpoint):
    """Return the instance id from an endpoint's first label, minus a trailing 'io'."""
    label = endpoint.split(".")[0]
    # NOTE(review): the original removed every "io" substring, corrupting ids
    # that contain "io" internally; only a trailing "io" is stripped now.
    # Confirm against the RDS endpoint naming in use.
    if label.endswith('io'):
        label = label[:-len('io')]
    return label
def get_secret_value(client, secret_name, version_id=None, stage=None):
    """Call KMS GetSecretValue and return the decoded JSON response.

    version_id and stage are only added to the request when truthy.
    """
    req = GetSecretValueRequest()
    req.set_accept_format('json')
    req.set_SecretName(secret_name)
    for setter, value in ((req.set_VersionId, version_id), (req.set_VersionStage, stage)):
        if value:
            setter(value)
    raw = client.do_action_with_exception(req)
    return json.loads(raw)
def put_secret_value(client, secret_name, version_id, secret_data, version_stages=None):
    """Call KMS PutSecretValue to store ``secret_data`` as version ``version_id``.

    Args:
        version_stages: Optional JSON-encoded list of stage labels; only sent
            when truthy.

    Returns:
        dict: The decoded JSON response.
    """
    req = PutSecretValueRequest()
    req.set_accept_format('json')
    req.set_SecretName(secret_name)
    req.set_VersionId(version_id)
    if version_stages:
        req.set_VersionStages(version_stages)
    req.set_SecretData(secret_data)
    raw = client.do_action_with_exception(req)
    return json.loads(raw)
def get_random_password(client, exclude_characters=None):
    """Call KMS GetRandomPassword and return the decoded JSON response.

    exclude_characters is only sent when truthy.
    """
    req = GetRandomPasswordRequest()
    req.set_accept_format('json')
    if exclude_characters:
        req.set_ExcludeCharacters(exclude_characters)
    raw = client.do_action_with_exception(req)
    return json.loads(raw)
def update_secret_version_stage(client, secret_name, version_stage, remove_from_version=None, move_to_version=None):
    """Call KMS UpdateSecretVersionStage and return the decoded JSON response.

    Args:
        version_stage: The stage label to move or remove.
        remove_from_version: Version to detach the label from; sent only when truthy.
        move_to_version: Version to attach the label to; sent only when truthy.
    """
    req = UpdateSecretVersionStageRequest()
    req.set_accept_format('json')
    req.set_VersionStage(version_stage)
    req.set_SecretName(secret_name)
    for setter, value in ((req.set_RemoveFromVersion, remove_from_version),
                          (req.set_MoveToVersion, move_to_version)):
        if value:
            setter(value)
    raw = client.do_action_with_exception(req)
    return json.loads(raw)
def describe_db_instances(client, db_instance_id):
    """Call RDS DescribeDBInstances for one instance id and return the decoded JSON response."""
    req = DescribeDBInstancesRequest()
    req.set_accept_format('json')
    req.set_DBInstanceId(db_instance_id)
    raw = client.do_action_with_exception(req)
    return json.loads(raw)