导入需要的库(library)

import json
import os
import tarfile
import tempfile

import boto3
import sagemaker
from sagemaker import get_execution_role

训练方法

# Default ECR image of the custom XGBoost training container.
XGBOOST_IMAGE_URI = '337058716437.dkr.ecr.ca-central-1.amazonaws.com/xgboost_001'


def _create_tar_file(source_files, target=None):
    """Pack *source_files* into a gzip-compressed tar archive.

    Every file is stored at the root of the archive under its basename, so
    any directory components of the input paths are dropped.

    Args:
        source_files: Iterable of local file paths to add.
        target: Output archive path. When falsy, a new temporary file is used.

    Returns:
        The path of the archive that was written.
    """
    if target:
        filename = target
    else:
        # mkstemp hands back an already-open descriptor; close it so it does
        # not leak — tarfile reopens the file by name below.
        fd, filename = tempfile.mkstemp()
        os.close(fd)

    with tarfile.open(filename, mode="w:gz") as tar:
        for source_file in source_files:
            tar.add(source_file, arcname=os.path.basename(source_file))
    return filename


def _json_encode_hyperparameters(hyperparameters):
    """Return a new dict with ``str`` keys and JSON-encoded string values."""
    return {str(k): json.dumps(v) for k, v in hyperparameters.items()}


def _prepare_hyperparameters(hyperparameters, submit_directory, script=None):
    """Build the encoded hyperparameter dict without mutating the caller's dict.

    Args:
        hyperparameters: User hyperparameters, or None.
        submit_directory: S3 URI of the source bundle, stored under
            ``sagemaker_submit_directory``.
        script: Entry-point script name. Used as ``sagemaker_program`` only
            when the caller did not already set that key.

    Returns:
        A fresh dict with every value JSON-encoded.
    """
    params = dict(hyperparameters or {})
    if script:
        params.setdefault("sagemaker_program", script)
    params["sagemaker_submit_directory"] = submit_directory
    return _json_encode_hyperparameters(params)


def train_my_xgboost(train, code_files, script, hyperparameters=None, role=None,
                     prefix=None, bucket=None, extra_channels=None,
                     image_uri=XGBOOST_IMAGE_URI, instance_type="local",
                     instance_count=1):
    """Launch a SageMaker training job on the custom XGBoost container.

    Local code files are bundled into ``sourcedir.tar.gz`` (in the current
    working directory) and uploaded to S3. The resulting S3 URI is passed to
    the container via the ``sagemaker_submit_directory`` hyperparameter.

    Args:
        train: S3 URI of the training data. It is passed as the ``train``
            channel, which the container sees under /opt/ml/input/data/train/.
        code_files: Local paths of the code files to bundle, or a list (or a
            single string) whose first element is the S3 URI of an
            already-uploaded source bundle.
        script: Entry-point script name. It is used as ``sagemaker_program``
            unless *hyperparameters* already sets that key.
        hyperparameters: Extra hyperparameters. The dict is copied, never
            modified.
        role: IAM role ARN. Defaults to ``get_execution_role()``.
        prefix: S3 key prefix for the uploaded code, also used as the
            training-job base name. The code goes under ``<prefix>/code``, or
            ``code`` when no prefix is given.
        bucket: S3 bucket for the code. Defaults to the session's default
            bucket.
        extra_channels: Optional mapping ``{channel_name: uri}`` of additional
            input channels. Channel ``name`` is mounted at
            /opt/ml/input/data/<name>/.
        image_uri: Docker image of the training container.
        instance_type: Training instance type; "local" runs in local mode.
        instance_count: Number of training instances.

    Returns:
        The fitted ``sagemaker.estimator.Estimator``.

    Raises:
        ValueError: If *code_files* is empty.
    """
    if isinstance(code_files, str):
        code_files = [code_files]
    if not code_files:
        raise ValueError("code_files must contain at least one path or S3 URI")

    sagemaker_session = sagemaker.session.Session()

    # Fall back to the session's default bucket.
    if not bucket:
        print('Using default bucket ', end='')
        bucket = sagemaker_session.default_bucket()
        print(bucket)

    if code_files[0].startswith('s3://'):
        # The bundle is already in S3: the container needs a single URI, not
        # the whole list.
        sources = code_files[0]
    else:
        print('Uploading code to S3: ', end='')
        archive = _create_tar_file(code_files, "sourcedir.tar.gz")
        key_prefix = f"{prefix}/code" if prefix else "code"
        sources = sagemaker_session.upload_data(archive, bucket, key_prefix)
        print(sources)

    encoded_hyperparameters = _prepare_hyperparameters(
        hyperparameters, sources, script
    )

    if not role:
        print('Getting default Role ', end='')
        role = get_execution_role()
        print(role)

    est = sagemaker.estimator.Estimator(
        image_uri,
        role,
        train_instance_count=instance_count,
        train_instance_type=instance_type,
        base_job_name=prefix,
        hyperparameters=encoded_hyperparameters,
    )

    # Each key becomes an input channel, mounted at /opt/ml/input/data/<key>/.
    inputs = {"train": train}
    if extra_channels:
        inputs.update(extra_channels)
    est.fit(inputs)
    return est

调用

# Example invocation. The guard keeps a training job from being launched
# merely by importing this module.
if __name__ == "__main__":
    # S3 location of the training data
    train = 's3://sagemaker-ca-central-1-337058716437/script-mode-container-2/train/'
    # Local code files to bundle into sourcedir.tar.gz
    code_files = ["source_dir/train.py", "source_dir/utils.py", "source_dir/run.sh"]
    # Entry-point script executed inside the container
    script = 'train.py'

    prefix = 'xxx'
    role = 'arn:aws:iam::337058716437:role/SageMaker-Execution'

    hyperparameters = {
        # Reuse `script` so the entry point is declared in one place only.
        "sagemaker_program": script,
        "hp1": {'xgboost': '123',
                'test': 'ttt',
                },
        "hp2": 300,
        "hp3": 0.001,
    }
    train_my_xgboost(train, code_files, script, hyperparameters=hyperparameters,
                     role=role,
                     prefix=prefix,
                     )
最后修改：2021 年 08 月 10 日 11:30 AM
如果觉得我的文章对你有用，请随意赞赏