使用Jinja2批量生成Slurm作业脚本

目录

课题项目对比 HPC 上数据处理算法在不同节点数下的运行时间。 为了得到有说服力的结果,需要多次运行同一测试用例并统计运行时间。 不过,多次重复运行同一个脚本的结果往往不够准确,反复读取相同数据可能受 HPC 并行文件系统缓存机制影响而导致速度变快。 因此最好能在每次运行时使用不同的数据。

一种方案是使用单一作业脚本,在 Python 代码中随机选择数据目录,但该方案的作业脚本每次运行都会使用不同的输入数据,不利于结果重现。 因此,本文批量生成作业脚本,数据目录在脚本生成时随机定义并写入到作业脚本中。

本文介绍如何使用 Python 模板库 Jinja2 批量生成 Slurm 作业脚本。 可以看到仅使用 Jinja2 提供的最基础的模板功能就能实现作业脚本的自动批量生成。

创建模板

首先创建 Slurm 作业运行需要的 Shell 脚本模板。

使用 Jinja2 模板的两种语法:

  • 变量替换: {{ variable_name }},将替换为变量的值,主要用于将运行参数导入到模板中
  • 条件语句:{%- if %} ... {%- else %} ... {%- endif %},根据条件输出不同的代码段,主要用于为串行作业和并行作业配置不同的参数和脚本

定义 Slurm 参数:

#!/bin/bash
#SBATCH -J reki
#SBATCH -p {{ partition }}
{%- if is_parallel %}
#SBATCH -N {{ nodes }}
#SBATCH --ntasks-per-node={{ ntasks_per_node }}
{%- endif %}
#SBATCH -o {{ job_name }}.%j.out
#SBATCH -e {{ job_name }}.%j.err
#SBATCH --comment=Grapes_gfs_post
#SBATCH --no-requeue

其中模板参数 is_parallel 表示作业是否为并行作业。 串行作业仅需要设置队列 partition,并行作业还需要设置节点数 nodes 和每个节点的任务数 ntasks_per_node

运行 Python 脚本:

{%- if is_parallel %}
module load compiler/intel/composer_xe_2017.2.174
module load mpi/intelmpi/2017.2.174
# module load apps/eccodes/2.17.0/intel

export I_MPI_PMI_LIBRARY=/opt/gridview/slurm17/lib/libpmi.so

srun --mpi=pmi2 python -m {{ model_path }} {{ options }}
{%- else  %}
python -m {{ model_path }} {{ options }}
{%- endif %}

串行直接运行 Python 脚本,并行作业需要设置并行环境,并使用 srun 运行 MPI 程序。

渲染模板

渲染模板为 Slurm 作业脚本需要四步:

  • 加载模板
  • 设置模板参数
  • 使用参数渲染模板
  • 将结果写入到文件中

假设上述模板保存为 slurm_job.sh,在该目录下的另一个 Python 脚本中实现模板渲染。

从本地模板文件中载入模板对象 template

file_loader = FileSystemLoader(Path(__file__).parent)
env = Environment(loader=file_loader)

template = env.get_template("slurm_job.sh")

在字典对象 job_params 中设置模板参数:

job_params = dict(
    job_name=output_script_path.stem,
    is_parallel=True,
    partition=partition,
    nodes=nodes,
    ntasks_per_node=ntasks_per_node,
    model_path="reki_data_tool.postprocess.grid.gfs.ne",
    work_directory=work_directory.absolute(),
    options=f"""dask-v1 \\
        --start-time={start_time_label} \\
        --forecast-time={forecast_time_label}h \\
        --output-file-path={output_file_path} \\
        --engine=mpi"""
)

使用参数渲染模板,得到作业脚本的字节数组 task_script_content

task_script_content = template.render(**job_params)

将结果写入到脚本文件 output_script_path 中:

with open(output_script_path, "w") as f:
    f.write(task_script_content)

批量生成

将以上代码封装到命令行CLI函数中,如果没有给定 start_timeforecast_time,则会自动生成随机值。

app = typer.Typer()

@app.command("dask-v1")
def create_dask_v1_task(
        output_script_path: Path = typer.Option(Path(OUTPUT_DIRECTORY, "11-dask-v1", "dask_v1_case_1.sh")),
        work_directory: Path = typer.Option(Path(OUTPUT_DIRECTORY)),
        start_time: str = typer.Option(None),
        forecast_time: str = typer.Option(None),
        nodes: int = 1,
        ntasks_per_node: int = 32,
        partition: str = "normal"
):
    if start_time is None:
        start_time = get_random_start_time()
    else:
        start_time = pd.to_datetime(start_time, format="%Y%m%d%H")
    start_time_label = start_time.strftime("%Y%m%d%H")

    if forecast_time is None:
        forecast_time = get_random_forecast_time()
    else:
        forecast_time = pd.to_timedelta(forecast_time)
    forecast_time_label = f"{int(forecast_time / pd.Timedelta(hours=1)):03}"

    # ...skip...

批量调用该命令行接口就可以生成多个作业脚本。 下面脚本使用 typer 的 CliRunner 模拟执行命令行接口,分别生成节点数为 1/2/4/8 的测试用例,每种用例生成 20 个脚本。

from typer.testing import CliRunner

runner = CliRunner()

nodes_list = (1, 2, 4, 8)
count = 20
partition = "normal"

script_base_directory = Path(CASE_BASE_DIRECTORY, "script")

for node_count in nodes_list:
    for test_index in range(1, count+1):
        logger.info(f"create job script for NODE {node_count} TEST {test_index}...")
        script_path = Path(script_base_directory, f"node_{node_count:02}", f"test_{test_index:02}.cmd")
        script_path.parent.mkdir(parents=True, exist_ok=True)

        work_dir = Path(CASE_BASE_DIRECTORY, f"node_{node_count:02}", f"test_{test_index:02}")
        work_dir.mkdir(parents=True, exist_ok=True)
        result = runner.invoke(app, [
            "dask-v1",
            "--output-script-path", script_path.absolute(),
            "--work-directory", work_dir.absolute(),
            "--nodes", node_count,
            "--partition", partition
        ])

参考

Jinja2 官网

CMA-PI HPC 的 Slurm 简介:《NWPC高性能计算机环境介绍:作业管理

使用 Jinja2 定义批处理系统作业脚本的一篇示例文档:《MDBenchmark documentation: Defining host templates