docker_utils.py
import boto3
import base64
import docker
def build_docker(tag, path='.'):
    """Build a Docker image and echo the build output to stdout.

    Args:
        tag: Tag applied to the built image.
        path: Build path handed to ``images.build`` (defaults to the
            current directory).

    Returns:
        The image object returned by ``images.build``.
    """
    banner = '*' * 30
    print(banner)
    print('Start building')

    built_image, log_entries = docker.from_env().images.build(
        path=path, tag=tag, rm=True)

    # Only entries that carry a 'stream' key contain build output text.
    for text in (entry['stream'] for entry in log_entries if 'stream' in entry):
        print(text, end='')
    return built_image
def push_to_ecr(image, ecr_repo_name):
    """Tag ``image`` for AWS ECR and push it, creating the repository if needed.

    Args:
        image: Docker image object exposing ``tag(repository, tag=...)``,
            e.g. the value returned by ``build_docker``.
        ecr_repo_name: Name of the ECR repository to push to.

    Returns:
        The full repository path ``<registry-host>/<ecr_repo_name>`` that the
        image was tagged with and pushed to.

    Raises:
        RuntimeError: If the Docker daemon reports an error while pushing.
        Exception: Any error raised by ``create_repository`` other than
            "repository already exists" now propagates instead of being
            reported as an existing repository.
    """
    print('*' * 30)
    print('Start pushing')

    # One ECR client serves both the auth-token request and repository creation.
    ecr_client = boto3.Session().client('ecr')
    auth_data = ecr_client.get_authorization_token()['authorizationData'][0]

    # The decoded token has the form "<username>:<password>"; split only on the
    # first colon so a password containing ':' stays intact.
    decoded_token = base64.b64decode(auth_data['authorizationToken']).decode()
    username, password = decoded_token.split(':', 1)
    auth_config = {'username': username, 'password': password}
    ecr_url = auth_data['proxyEndpoint']

    try:
        ecr_client.create_repository(repositoryName=ecr_repo_name)
    except ecr_client.exceptions.RepositoryAlreadyExistsException:
        # Only this specific failure means the repository is already there;
        # credential/permission errors must not be masked.
        print('[Info]Repository {} existed'.format(ecr_repo_name))
    else:
        print('[Info]Repository {} created'.format(ecr_repo_name))

    ecr_repo_name = '{}/{}'.format(
        ecr_url.replace('https://', ''), ecr_repo_name)
    print(ecr_repo_name)
    image.tag(ecr_repo_name, tag='latest')

    client = docker.from_env()
    # Stream decoded events so that errors reported inside the push log
    # (e.g. authentication failures) are detected instead of ignored.
    push_events = client.images.push(
        ecr_repo_name, tag='latest', auth_config=auth_config,
        stream=True, decode=True)
    for event in push_events:
        if 'error' in event:
            raise RuntimeError(
                'Pushing {} failed: {}'.format(ecr_repo_name, event['error']))
        status = event.get('status')
        if status:
            if 'id' in event:
                print('{}: {}'.format(event['id'], status))
            else:
                print(status)
    return ecr_repo_name
def build_and_push(tag, dockerfile_path, ecr_repo_name):
    """Build an image with ``build_docker`` and push it with ``push_to_ecr``.

    Args:
        tag: Tag forwarded to ``build_docker``.
        dockerfile_path: Path forwarded to ``build_docker`` as its build path.
        ecr_repo_name: Repository name forwarded to ``push_to_ecr``.

    Returns:
        The value returned by ``push_to_ecr``.
    """
    built_image = build_docker(tag, dockerfile_path)
    # Visual gap between the build output and the push output.
    print('\n\n\n')
    return push_to_ecr(built_image, ecr_repo_name)
调用格式:
tag: docker image tag <tag> 的名字
dockerfile_path: Dockerfile 的路径
ecr_repo_name: AWS ECR Repo 的名字,如果没有对应的 repo,会自动创建
# Usage example: build the image from the given directory and push it to the
# ECR repository 'xgboost_001'. `image` receives the value returned by
# build_and_push (the repository path returned by push_to_ecr).
from docker_utils import build_and_push
image = build_and_push(tag='xgboost_001', dockerfile_path='/home/ec2-user/SageMaker/self_docker', ecr_repo_name='xgboost_001')
sage_train.py
import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import csv_serializer
import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import csv_serializer
from sagemaker.tuner import IntegerParameter, ContinuousParameter, HyperparameterTuner
def refresh_training_settings(prefix, train_location, val_location, session, role, output_path):
    """Fill in missing training settings and upload local data sets to S3.

    Any falsy ``prefix``, ``session``, ``role`` or ``output_path`` is replaced
    by a default; a caller-supplied value is kept as is. Data locations that
    do not start with ``s3://`` are uploaded via ``session.upload_data`` and
    replaced by the returned location.

    Args:
        prefix: Job name prefix; defaults to ``'Unnamed-Training-Job'``.
        train_location: Training data location (local path or ``s3://`` URI).
        val_location: Optional validation data location.
        session: SageMaker session; defaults to ``sagemaker.Session()``.
        role: Execution role; defaults to ``get_execution_role()``.
        output_path: Output location; defaults to
            ``s3://<default bucket>/<prefix>/output``.

    Returns:
        Tuple ``(prefix, train_location, val_location, session, role,
        output_path)`` with defaults applied.
    """
    # Only fall back to the default name when the caller gave no prefix;
    # previously the supplied prefix was always overwritten.
    if not prefix:
        prefix = 'Unnamed-Training-Job'
    if not session:
        session = sagemaker.Session()
    if not role:
        role = get_execution_role()
    if not output_path:
        output_path = 's3://{}/{}/output'.format(session.default_bucket(), prefix)

    if not train_location.startswith('s3://'):
        new_train_location = session.upload_data(train_location, key_prefix=prefix)
        print('Uploading Training Set {} to {}'.format(train_location, new_train_location))
        train_location = new_train_location

    if val_location and not val_location.startswith('s3://'):
        new_val_location = session.upload_data(val_location, key_prefix=prefix)
        print('Uploading Validation Set {} to {}'.format(val_location, new_val_location))
        val_location = new_val_location

    return prefix, train_location, val_location, session, role, output_path
def get_xgb(prefix, xgb_params,
            hyperparameter_ranges,
            session, role,
            instance_count,
            instance_type,
            output_path,
            objective_metric_name,
            objective_type,
            max_jobs,
            max_parallel_jobs
            ):
    """Create an XGBoost estimator, optionally wrapped in a hyperparameter tuner.

    Args:
        prefix: Base name for the training job / tuning job.
        xgb_params: Dict of XGBoost hyperparameters; must contain
            ``num_round`` when given. When falsy, a built-in default set of
            hyperparameters is applied.
        hyperparameter_ranges: Optional dict of tuning ranges. A list/tuple
            value ``[kind, *args]`` with ``kind`` in ``{'int', 'float'}`` is
            converted to ``IntegerParameter(*args)`` /
            ``ContinuousParameter(*args)``; any other value is passed through
            unchanged. The caller's dict is not modified.
        session: SageMaker session used for the estimator and region lookup.
        role: Execution role passed to the estimator.
        instance_count: Number of training instances.
        instance_type: Training instance type.
        output_path: Output location passed to the estimator.
        objective_metric_name: Metric the tuner optimises.
        objective_type: Optimisation direction passed to the tuner.
        max_jobs: Maximum number of tuning jobs.
        max_parallel_jobs: Maximum number of concurrent tuning jobs.

    Returns:
        A ``sagemaker.estimator.Estimator`` or, when ``hyperparameter_ranges``
        is given, a ``HyperparameterTuner`` wrapping it.

    Raises:
        ValueError: If ``hyperparameter_ranges`` is given without
            ``xgb_params``, or ``xgb_params`` lacks ``num_round``.
        KeyError: If a list/tuple range uses a kind other than ``'int'`` or
            ``'float'``.
    """
    # Validate the arguments before any SageMaker object is created.
    if hyperparameter_ranges and not xgb_params:
        raise ValueError('Use hyperparameter_ranges must with xgb_params')
    if xgb_params and 'num_round' not in xgb_params:
        raise ValueError('xgb_params must include parameter : num_round')

    tuning_ranges = None
    if hyperparameter_ranges:
        range_types = {'int': IntegerParameter,
                       'float': ContinuousParameter
                       }
        # Build a new dict so the caller's ranges are left untouched.
        tuning_ranges = {}
        for name, spec in hyperparameter_ranges.items():
            if isinstance(spec, (list, tuple)):
                tuning_ranges[name] = range_types[spec[0]](*spec[1:])
            else:
                tuning_ranges[name] = spec

    container = sagemaker.image_uris.retrieve('xgboost', session.boto_region_name, 'latest')
    xgb = sagemaker.estimator.Estimator(container,
                                        role,
                                        instance_count=instance_count,
                                        instance_type=instance_type,
                                        output_path=output_path,
                                        sagemaker_session=session,
                                        base_job_name=prefix
                                        )
    if xgb_params:
        xgb.set_hyperparameters(**xgb_params)
    else:
        # Defaults used when the caller supplies no hyperparameters.
        xgb.set_hyperparameters(max_depth=5,
                                eta=0.2,
                                gamma=4,
                                min_child_weight=6,
                                subsample=0.8,
                                objective='reg:linear',
                                early_stopping_rounds=10,
                                num_round=200)

    if tuning_ranges:
        xgb = HyperparameterTuner(estimator=xgb,
                                  objective_metric_name=objective_metric_name,
                                  objective_type=objective_type,
                                  max_jobs=max_jobs,
                                  max_parallel_jobs=max_parallel_jobs,
                                  hyperparameter_ranges=tuning_ranges,
                                  base_tuning_job_name=prefix
                                  )
        print('Use hyperparameter')
    return xgb
def train_xgboost(
        train_location,
        prefix=None,
        xgb_params=None,
        val_location=None,
        hyperparameter_ranges=None,
        session=None,
        role=None,
        instance_count=1,
        instance_type='ml.m4.xlarge',
        output_path=None,
        objective_metric_name='validation:rmse',
        objective_type='Minimize',
        max_jobs = 20,
        max_parallel_jobs=3,
        ):
    """Prepare the settings, build the XGBoost model object and fit it.

    The settings are first passed through ``refresh_training_settings``; the
    model object comes from ``get_xgb``. Both data channels are read as CSV.
    When no validation location is available, a warning is printed and the
    training channel is also used as the ``validation`` channel.

    Args:
        train_location: Training data location.
        prefix: Job name prefix.
        xgb_params: XGBoost hyperparameters forwarded to ``get_xgb``.
        val_location: Optional validation data location.
        hyperparameter_ranges: Optional tuning ranges forwarded to ``get_xgb``.
        session: Optional SageMaker session.
        role: Optional execution role.
        instance_count: Number of training instances.
        instance_type: Training instance type.
        output_path: Optional output location.
        objective_metric_name: Tuning objective metric.
        objective_type: Tuning objective direction.
        max_jobs: Maximum number of tuning jobs.
        max_parallel_jobs: Maximum number of concurrent tuning jobs.

    Returns:
        The object returned by ``get_xgb`` after ``fit`` has been called on it.
    """
    settings = refresh_training_settings(
        prefix, train_location, val_location, session, role, output_path)
    prefix, train_location, val_location, session, role, output_path = settings

    train_channel = sagemaker.inputs.TrainingInput(
        s3_data=train_location, content_type='csv')
    if not val_location:
        print('[Warining] There is no Validation Set for model validation!')
        # Without a validation set the training data doubles as validation.
        validation_channel = train_channel
    else:
        validation_channel = sagemaker.inputs.TrainingInput(
            s3_data=val_location, content_type='csv')

    xgb = get_xgb(
        prefix, xgb_params, hyperparameter_ranges,
        session, role,
        instance_count, instance_type,
        output_path,
        objective_metric_name, objective_type,
        max_jobs, max_parallel_jobs,
    )

    print('*' * 10)
    summary = (
        ('[Info]Prefix :', prefix),
        ('[Info]Output path :', output_path),
        ('[Info]Training Set Location :', train_location),
        ('[Info]Validation Set Location :', val_location),
        ('[Info]Role :', role),
        ('[Info]Session :', session),
    )
    for label, value in summary:
        print(label, value)
    print('[Info] Starting XGBoost training process')

    xgb.fit({'train': train_channel, 'validation': validation_channel})
    return xgb