docker_utils.py

import boto3
import base64
import docker

def build_docker(tag, path='.'):
    """Build a Docker image and echo the build output to stdout.

    Args:
        tag: Tag applied to the built image.
        path: Value passed as the build ``path`` to the Docker client;
            defaults to the current directory.

    Returns:
        The image object returned by ``images.build``.
    """
    banner = '*' * 30
    print(banner)
    print('Start building')
    daemon = docker.from_env()
    built_image, log_entries = daemon.images.build(path=path, tag=tag, rm=True)
    # Log entries without a 'stream' key are skipped.
    for text in (entry['stream'] for entry in log_entries if 'stream' in entry):
        print(text, end='')
    return built_image

def push_to_ecr(image, ecr_repo_name):
    """Tag ``image`` as ``latest`` in an ECR repository and push it there.

    The repository is created first; if ECR reports that it already exists,
    the existing repository is used.

    Args:
        image: Image object exposing ``tag(repository, tag=...)``, such as
            the value returned by ``build_docker``.
        ecr_repo_name: Name of the ECR repository to push into.

    Returns:
        The full repository path, i.e. the registry endpoint (without the
        ``https://`` scheme) joined with ``ecr_repo_name``.

    Raises:
        Any boto3 error from ``create_repository`` other than
        ``RepositoryAlreadyExistsException`` (for example missing
        permissions) propagates instead of being reported as "existed".
    """
    print('*' * 30)
    print('Start pushing')

    # One ECR client serves both the auth token and the repository creation.
    ecr_client = boto3.Session().client('ecr')
    auth_data = ecr_client.get_authorization_token()['authorizationData'][0]

    # The decoded token has the form "<username>:<password>". Split on the
    # first colon only, so a colon inside the password cannot break unpacking.
    decoded_token = base64.b64decode(auth_data['authorizationToken']).decode()
    username, password = decoded_token.split(':', 1)
    auth_config = {'username': username, 'password': password}

    ecr_url = auth_data['proxyEndpoint']

    client = docker.from_env()

    try:
        ecr_client.create_repository(repositoryName=ecr_repo_name)
    except ecr_client.exceptions.RepositoryAlreadyExistsException:
        # Only this error means the repository is already there; anything
        # else is a real failure and must not be masked.
        print('[Info]Repository {} existed'.format(ecr_repo_name))
    else:
        print('[Info]Repository {} created'.format(ecr_repo_name))

    # Docker repository names carry the registry host but no URL scheme.
    full_repo_name = '{}/{}'.format(
        ecr_url.replace('https://', ''), ecr_repo_name)
    print(full_repo_name)

    image.tag(full_repo_name, tag='latest')

    push_log = client.images.push(full_repo_name, auth_config=auth_config)
    # Strip JSON punctuation so the push log reads as plain status lines.
    print(push_log.replace('"status":"', '').replace('{', '').replace('}', '').replace(']', '').replace('"', ''))
    return full_repo_name

def build_and_push(tag, dockerfile_path, ecr_repo_name):
    """Build a Docker image and push it to ECR in one step.

    Args:
        tag: Tag applied to the locally built image.
        dockerfile_path: Build path handed to ``build_docker``.
        ecr_repo_name: Name of the target ECR repository.

    Returns:
        The value returned by ``push_to_ecr`` for the pushed image.
    """
    built_image = build_docker(tag, dockerfile_path)
    # Blank lines visually separate the build log from the push log.
    print('\n\n\n')
    return push_to_ecr(built_image, ecr_repo_name)

调用格式:

  • tag: Docker 镜像的 tag 名称(即 <tag>)
  • dockerfile_path: Dockerfile的路径
  • ecr_repo_name: AWS ECR Repo的名字,如果没有对应的repo,会自动创建
from docker_utils import build_and_push

# Build the image and push it to ECR; `image` receives the value
# returned by build_and_push.
image = build_and_push(
    tag='xgboost_001',
    dockerfile_path='/home/ec2-user/SageMaker/self_docker',
    ecr_repo_name='xgboost_001',
)

sage_train.py

import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import csv_serializer

import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.predictor import csv_serializer
from sagemaker.tuner import IntegerParameter, ContinuousParameter, HyperparameterTuner

def refresh_training_settings(prefix, train_location, val_location, session, role, output_path):
    """Fill in defaults for unset training settings and upload local data.

    Args:
        prefix: Job name prefix; falls back to ``'Unnamed-Training-Job'``
            when empty or ``None``.
        train_location: Training data location. Anything not starting with
            ``s3://`` is uploaded through ``session.upload_data``.
        val_location: Optional validation data location, handled the same
            way as ``train_location`` when given.
        session: SageMaker session; a new ``sagemaker.Session()`` is created
            when falsy.
        role: Execution role; ``get_execution_role()`` is used when falsy.
        output_path: Output location; defaults to
            ``s3://<default bucket>/<prefix>/output`` when falsy.

    Returns:
        Tuple ``(prefix, train_location, val_location, session, role,
        output_path)`` with defaults applied and uploaded locations
        replaced by the values returned from ``upload_data``.
    """
    # Use the default name only as a fallback, so a caller-supplied prefix
    # is kept for the upload key prefix and the output path.
    prefix = prefix if prefix else 'Unnamed-Training-Job'
    session = session if session else sagemaker.Session()
    role = role if role else get_execution_role()

    if not output_path:
        output_path = 's3://{}/{}/output'.format(session.default_bucket(), prefix)

    if not train_location.startswith('s3://'):
        new_train_location = session.upload_data(train_location, key_prefix=prefix)
        print('Uploading Training Set {} to {}'.format(train_location, new_train_location))
        train_location = new_train_location

    if val_location and not val_location.startswith('s3://'):
        new_val_location = session.upload_data(val_location, key_prefix=prefix)
        print('Uploading Validation Set {} to {}'.format(val_location, new_val_location))
        val_location = new_val_location

    return prefix, train_location, val_location, session, role, output_path

def get_xgb(prefix, xgb_params,
            hyperparameter_ranges,
            session, role,
            instance_count,
            instance_type,
            output_path,
            objective_metric_name,
            objective_type,
            max_jobs,
            max_parallel_jobs
           ):
    """Create an XGBoost estimator, optionally wrapped in a tuner.

    Args:
        prefix: Base name for the training job (and tuning job, if any).
        xgb_params: Hyperparameters for the estimator. When given they must
            include ``num_round``; when falsy a built-in default set is used.
        hyperparameter_ranges: Optional mapping of hyperparameter name to a
            range. A list/tuple value ``['int' | 'float', *args]`` is
            converted to ``IntegerParameter(*args)`` or
            ``ContinuousParameter(*args)``; other values are passed through
            unchanged. Requires ``xgb_params``.
        session: SageMaker session; its ``boto_region_name`` selects the
            container image.
        role: Execution role passed to the estimator.
        instance_count: Number of training instances.
        instance_type: Training instance type.
        output_path: Output location passed to the estimator.
        objective_metric_name: Metric the tuner optimizes.
        objective_type: Optimization direction passed to the tuner.
        max_jobs: Maximum number of tuning jobs.
        max_parallel_jobs: Maximum number of concurrent tuning jobs.

    Returns:
        A ``HyperparameterTuner`` when ``hyperparameter_ranges`` is given,
        otherwise the ``Estimator``.

    Raises:
        ValueError: If ``hyperparameter_ranges`` is given without
            ``xgb_params``, or ``xgb_params`` lacks ``num_round``.
        KeyError: If a list/tuple range names a type other than ``'int'``
            or ``'float'``.
    """
    # Validate all inputs first, so bad arguments fail before any SageMaker
    # object is created.
    if hyperparameter_ranges and not xgb_params:
        raise ValueError('Use hyperparameter_ranges must with xgb_params')
    if xgb_params and 'num_round' not in xgb_params:
        raise ValueError('xgb_params must include parameter : num_round')

    if hyperparameter_ranges:
        range_types = {'int': IntegerParameter,
                       'float': ContinuousParameter}
        # Build a new mapping so the caller's dict is not mutated.
        hyperparameter_ranges = {
            name: (range_types[spec[0]](*spec[1:])
                   if isinstance(spec, (list, tuple)) else spec)
            for name, spec in hyperparameter_ranges.items()
        }

    container = sagemaker.image_uris.retrieve('xgboost', session.boto_region_name, 'latest')

    xgb = sagemaker.estimator.Estimator(container,
                                        role,
                                        instance_count=instance_count,
                                        instance_type=instance_type,
                                        output_path=output_path,
                                        sagemaker_session=session,
                                        base_job_name=prefix
                                       )
    if xgb_params:
        xgb.set_hyperparameters(**xgb_params)
    else:
        # Defaults used when the caller supplies no hyperparameters.
        xgb.set_hyperparameters(max_depth=5,
                                eta=0.2,
                                gamma=4,
                                min_child_weight=6,
                                subsample=0.8,
                                objective='reg:linear',
                                early_stopping_rounds=10,
                                num_round=200)

    if hyperparameter_ranges:
        xgb = HyperparameterTuner(estimator=xgb,
                                  objective_metric_name=objective_metric_name,
                                  objective_type=objective_type,
                                  max_jobs=max_jobs,
                                  max_parallel_jobs=max_parallel_jobs,
                                  hyperparameter_ranges=hyperparameter_ranges,
                                  base_tuning_job_name=prefix
                                 )
        print('Use hyperparameter')
    return xgb
def train_xgboost( 
                  train_location,
                  prefix=None,
                  xgb_params=None,
                  val_location=None,
                  hyperparameter_ranges=None,
                  session=None, 
                  role=None, 
                  instance_count=1, 
                  instance_type='ml.m4.xlarge', 
                  output_path=None,
                  objective_metric_name='validation:rmse',
                  objective_type='Minimize',
                  max_jobs = 20,
                  max_parallel_jobs=3,
                 ):
    """Train (or tune) an XGBoost model on SageMaker and return it.

    Settings left unset are resolved by ``refresh_training_settings`` and
    the estimator or tuner is built by ``get_xgb``. Both data locations are
    passed to SageMaker as CSV inputs. When no validation location is
    given, a warning is printed and the training input is also used as the
    ``validation`` channel.

    Args:
        train_location: Location of the training data.
        prefix: Optional job name prefix.
        xgb_params: Optional XGBoost hyperparameters.
        val_location: Optional location of the validation data.
        hyperparameter_ranges: Optional ranges that enable tuning.
        session: Optional SageMaker session.
        role: Optional execution role.
        instance_count: Number of training instances.
        instance_type: Training instance type.
        output_path: Optional output location.
        objective_metric_name: Metric used by the tuner.
        objective_type: Optimization direction used by the tuner.
        max_jobs: Maximum number of tuning jobs.
        max_parallel_jobs: Maximum number of concurrent tuning jobs.

    Returns:
        The object returned by ``get_xgb``, after ``fit`` has been called.
    """
    resolved = refresh_training_settings(
        prefix=prefix,
        train_location=train_location,
        val_location=val_location,
        session=session,
        role=role,
        output_path=output_path,
    )
    prefix, train_location, val_location, session, role, output_path = resolved

    train_input = sagemaker.inputs.TrainingInput(s3_data=train_location, content_type='csv')
    channels = {'train': train_input}

    if val_location:
        channels['validation'] = sagemaker.inputs.TrainingInput(
            s3_data=val_location, content_type='csv')
    else:
        print('[Warining] There is no Validation Set for model validation!')
        # Without a validation set, the training input fills that channel.
        channels['validation'] = train_input

    model = get_xgb(
        prefix=prefix,
        xgb_params=xgb_params,
        hyperparameter_ranges=hyperparameter_ranges,
        session=session,
        role=role,
        instance_count=instance_count,
        instance_type=instance_type,
        output_path=output_path,
        objective_metric_name=objective_metric_name,
        objective_type=objective_type,
        max_jobs=max_jobs,
        max_parallel_jobs=max_parallel_jobs,
    )

    print('*' * 10)
    summary = (
        ('[Info]Prefix :', prefix),
        ('[Info]Output path :', output_path),
        ('[Info]Training Set Location :', train_location),
        ('[Info]Validation Set Location :', val_location),
        ('[Info]Role :', role),
        ('[Info]Session :', session),
    )
    for label, value in summary:
        print(label, value)

    print('[Info] Starting XGBoost training process')

    model.fit(channels)
    return model
最后修改:2021 年 11 月 02 日 10:37 PM
如果觉得我的文章对你有用,请随意赞赏