# Source code for myqueue.schedulers.slurm

from __future__ import annotations
import os
import subprocess
from math import ceil

from myqueue.task import Task
from myqueue.schedulers import Scheduler, SchedulerError


class SLURM(Scheduler):
    """Scheduler backend for SLURM.

    Talks to a SLURM installation exclusively through its command-line
    tools (``sbatch``, ``scancel``, ``scontrol``, ``squeue``, ``sacct``,
    ``sinfo``) via :mod:`subprocess`.
    """

    def submit(self,
               task: Task,
               dry_run: bool = False,
               verbose: bool = False) -> int:
        """Submit *task* with ``sbatch`` and return the SLURM job id.

        Parameters
        ----------
        task:
            Task to submit; resources, working folder and command are
            read from it.
        dry_run:
            If true, do not submit; optionally print what would be run
            and return the dummy id ``1``.
        verbose:
            Together with *dry_run*: print the sbatch command line and
            the generated batch script.

        Raises
        ------
        SchedulerError
            If ``sbatch`` exits with a non-zero status; the message is
            sbatch's combined stderr + stdout.
        """
        nodelist = self.config.nodes
        # Pick a partition/node type matching the task's resource request.
        nodes, nodename, nodedct = task.resources.select(nodelist)
        ntasks = task.resources.processes
        cpus_per_task = task.resources.cores // ntasks
        name = task.cmd.short_name
        sbatch = ['sbatch',
                  f'--partition={nodename}',
                  f'--job-name={name}',
                  # assumes tmax is in seconds; sbatch --time takes
                  # minutes — TODO confirm unit of tmax
                  f'--time={ceil(task.resources.tmax / 60)}',
                  f'--ntasks={ntasks}',
                  f'--cpus-per-task={cpus_per_task}',
                  f'--nodes={nodes}',
                  f'--chdir={task.folder}',
                  f'--output={name}.%j.out',
                  f'--error={name}.%j.err']
        # Global extra arguments from the config plus per-node extras.
        extra_args = self.config.extra_args + nodedct.get('extra_args', [])
        sbatch += extra_args
        if task.dtasks:
            # Only start after all dependency jobs finished successfully.
            ids = ':'.join(str(tsk.id) for tsk in task.dtasks)
            sbatch.append(f'--dependency=afterok:{ids}')
        env = []
        cmd = str(task.cmd)
        if task.resources.processes > 1:
            # Default OMP_NUM_THREADS to 1 for MPI jobs unless the
            # caller's environment already sets it.
            if 'OMP_NUM_THREADS' not in os.environ:
                env.append(('OMP_NUM_THREADS', '1'))
            mpiexec = self.config.mpiexec
            if 'mpiargs' in nodedct:
                mpiexec += ' ' + nodedct['mpiargs']
            # Run under mpiexec and swap in the parallel Python
            # interpreter configured for this queue.
            cmd = mpiexec + ' ' + cmd.replace('python3',
                                              self.config.parallel_python)
        # Use bash for the script (login shell so profiles are sourced).
        script = '#!/bin/bash -l\n'
        # Add environment variables collected above.
        if len(env) > 0:
            script += '\n'.join(f'export {name}={val}'
                                for name, val in env)
            script += '\n'
        home = self.config.home
        script += (
            'export MYQUEUE_TASK_ID=$SLURM_JOB_ID\n'
            f'mq={home}/.myqueue/slurm-$MYQUEUE_TASK_ID\n')
        script += self.get_venv_activation_line()
        # Marker files record progress: -0 started, -1 finished OK,
        # -2 failed (and the script exits non-zero).
        script += (
            '(touch $mq-0 && \\\n'
            f' cd {str(task.folder)!r} && \\\n'
            f' {cmd} && \\\n'
            ' touch $mq-1) || \\\n'
            '(touch $mq-2; exit 1)\n')
        if dry_run:
            if verbose:
                print(' \\\n '.join(sbatch))
                print(script)
            return 1
        # NOTE(review): an earlier comment here claimed a "clean set of
        # environment variables without any MPI stuff", but the full
        # os.environ is passed through — confirm intended behavior.
        p = subprocess.run(sbatch, input=script.encode(),
                           capture_output=True, env=os.environ)
        if p.returncode:
            raise SchedulerError((p.stderr + p.stdout).decode())
        # sbatch prints "Submitted batch job <id>"; the id is the last
        # whitespace-separated token.
        return int(p.stdout.split()[-1].decode())

    def cancel(self, id: int) -> None:
        """Cancel job *id* with ``scancel`` (exit status is ignored)."""
        subprocess.run(['scancel', str(id)])

    def hold(self, id: int) -> None:
        """Put job *id* on hold with ``scontrol hold``."""
        subprocess.run(['scontrol', 'hold', str(id)])

    def release_hold(self, id: int) -> None:
        """Release a held job *id* with ``scontrol release``."""
        subprocess.run(['scontrol', 'release', str(id)])

    def get_ids(self) -> set[int]:
        """Return the ids of this user's jobs currently known to SLURM.

        Parses ``squeue --user $USER`` output, skipping the header line;
        the job id is the first column of each row.
        """
        user = os.environ.get('USER', 'test')
        cmd = ['squeue', '--user', user]
        p = subprocess.run(cmd, stdout=subprocess.PIPE)
        queued = {int(line.split()[0].decode())
                  for line in p.stdout.splitlines()[1:]}
        return queued

    def maxrss(self, id: int) -> int:
        """Return the maximum resident memory (bytes) used by job *id*.

        Uses ``sacct -o MaxRSS``; returns 0 if ``sacct`` is not
        installed or no ``K``-suffixed value is found.
        """
        cmd = ['sacct',
               '-j', str(id),
               '-n',
               '--units=K',
               '-o', 'MaxRSS']
        try:
            p = subprocess.run(cmd, stdout=subprocess.PIPE)
        except FileNotFoundError:
            return 0  # sacct unavailable on this system — best effort
        mem = 0
        for line in p.stdout.splitlines():
            line = line.strip()
            if line.endswith(b'K'):
                # NOTE(review): '--units=K' is kibibytes, but the value
                # is scaled by 1000, not 1024 — confirm intended.
                mem = max(mem, int(line[:-1]) * 1000)
        return mem

    def get_config(self,
                   queue: str = '') -> tuple[list[tuple[str, int, str]],
                                             list[str]]:
        """Return ``(nodes, extra_args)`` discovered via ``sinfo``.

        Each node entry is ``(partition-name, cores, memory)`` with the
        default-partition marker ``*`` and the memory ``+`` suffix
        stripped and ``M`` appended.  The *queue* argument is accepted
        for interface compatibility but not used here.
        """
        cmd = ['sinfo',
               '--noheader',
               '-O', 'CPUs,Memory,Partition']
        p = subprocess.run(cmd, stdout=subprocess.PIPE)
        nodes = []
        for line in p.stdout.decode().splitlines():
            cores, mem, name = line.split()
            nodes.append((name.rstrip('*'),
                          int(cores),
                          mem.rstrip('+') + 'M'))
        return nodes, []