# Source code for psyrun.jobs (page title truncated during extraction; module name presumed)

"""Handling and processing of job trees."""

import itertools
import os
import os.path
import warnings

from psyrun.exceptions import JobsRunningWarning
from psyrun.utils.doc import inherit_docs

class Job(object):
    """Describes a single processing job.

    Parameters
    ----------
    name : str
        Name of the job.
    submit_fn : function
        Function to use to submit the job for processing.
    submit_kwargs : dict
        Additional keyword arguments to the submit function (in addition to
        *name* and *depends_on*).
    dependencies : sequence of str
        Files that this job depends on. They are used to decide whether the
        job's targets are up-to-date.
    targets : sequence of str
        Files created by this job.

    Attributes
    ----------
    name : str
        Name of the job.
    submit_fn : function
        Function to use to submit the job for processing.
    submit_kwargs : dict
        Additional keyword arguments to the submit function (in addition to
        *name* and *depends_on*).
    dependencies : sequence of str
        Files that this job depends on.
    targets : sequence of str
        Files created by this job.
    """

    def __init__(self, name, submit_fn, submit_kwargs, dependencies, targets):
        self.name = name
        self.submit_fn = submit_fn
        self.submit_kwargs = submit_kwargs
        self.dependencies = dependencies
        self.targets = targets
class JobArray(object):
    """Array of *n* similar jobs created from patterns.

    Every occurrence of ``'%a'`` in the dependency patterns, the target
    patterns and the ``'args'`` entry of *submit_kwargs* is replaced by the
    index of the individual job (``0`` to ``n - 1``).

    Parameters
    ----------
    n : int
        Number of jobs in the array.
    name : str
        Name of the job array.
    submit_fn : function
        Function to submit the whole array at once. It receives *n* as its
        first positional argument. `Submit` falls back to submitting the
        individual jobs if it raises `NotImplementedError`.
    single_submit_fn : function
        Function to submit a single job of the array.
    submit_kwargs : dict
        Additional keyword arguments to the submit functions. An ``'args'``
        entry, if present, is expanded per job.
    dependency_patterns : sequence of str
        Dependency file patterns; ``'%a'`` is replaced by the job index.
    target_patterns : sequence of str
        Target file patterns; ``'%a'`` is replaced by the job index.

    Attributes
    ----------
    n, name, submit_fn, single_submit_fn, submit_kwargs
        As passed to the constructor.
    dependency_patterns, target_patterns
        As passed to the constructor.
    jobs : list of Job
        The individual jobs, named ``'0'`` to ``str(n - 1)``.
    dependencies : iterator of str
        Expanded dependencies of all jobs in the array (recomputed on each
        access).
    targets : iterator of str
        Expanded targets of all jobs in the array (recomputed on each
        access).
    """

    def __init__(
            self, n, name, submit_fn, single_submit_fn, submit_kwargs,
            dependency_patterns, target_patterns):
        self.n = n
        self.name = name
        self.submit_fn = submit_fn
        self.single_submit_fn = single_submit_fn
        self.submit_kwargs = submit_kwargs
        self.dependency_patterns = dependency_patterns
        self.target_patterns = target_patterns

        self.jobs = []
        for i in range(self.n):
            index = str(i)
            dependencies = [
                d.replace('%a', index) for d in self.dependency_patterns]
            targets = [t.replace('%a', index) for t in self.target_patterns]
            # Copy, so that the per-job substitution does not modify the
            # submit_kwargs shared by the whole array.
            job_kwargs = dict(self.submit_kwargs)
            if 'args' in job_kwargs:
                job_kwargs['args'] = [
                    a.replace('%a', index) for a in job_kwargs['args']]
            self.jobs.append(Job(
                index, self.single_submit_fn, job_kwargs, dependencies,
                targets))

    @property
    def dependencies(self):
        for i in range(self.n):
            for d in self.dependency_patterns:
                yield d.replace('%a', str(i))

    @property
    def targets(self):
        for i in range(self.n):
            for t in self.target_patterns:
                yield t.replace('%a', str(i))
class JobChain(object):
    """Chain of jobs to run in succession.

    Parameters
    ----------
    name : str
        Name of the job chain.
    jobs : sequence of Job
        Jobs to run in succession.

    Attributes
    ----------
    name : str
        Name of the job chain.
    jobs : sequence of Job
        Jobs to run in succession.
    dependencies : sequence
        Dependencies that need to be available before the job chain can be
        run (equivalent to the dependencies of the first job in the chain;
        empty for an empty chain).
    targets : sequence of str
        Files created or updated by the job chain (equivalent to the targets
        of the last job in the chain; empty for an empty chain).
    """

    def __init__(self, name, jobs):
        self.name = name
        self.jobs = jobs

    @property
    def dependencies(self):
        if not self.jobs:
            return []
        return self.jobs[0].dependencies

    @property
    def targets(self):
        if not self.jobs:
            return []
        return self.jobs[-1].targets
class JobGroup(object):
    """Group of jobs that can run in parallel.

    Parameters
    ----------
    name : str
        Name of the job group.
    jobs : sequence of Job
        Jobs to run in the job group.

    Attributes
    ----------
    name : str
        Name of the job group.
    jobs : sequence of Job
        Jobs to run in the job group.
    dependencies : iterator
        Dependencies that need to be available before the job group can be
        run (the concatenation of all the group's jobs' dependencies).
    targets : iterator of str
        Files that will be created or updated by the group's jobs (the
        concatenation of all the group's jobs' targets).
    """

    def __init__(self, name, jobs):
        self.name = name
        self.jobs = jobs

    @property
    def dependencies(self):
        # Flatten: the result must be individual dependencies, not the
        # per-job sequences (consistent with `targets`).
        return itertools.chain.from_iterable(
            j.dependencies for j in self.jobs)

    @property
    def targets(self):
        return itertools.chain.from_iterable(j.targets for j in self.jobs)
class JobTreeVisitor(object):
    """Abstract base class to implement visitors on trees of jobs.

    Base class to implement visitors following the Visitor pattern to
    traverse the tree constructed out of `Job`, `JobArray`, `JobChain`, and
    `JobGroup` instances.

    A deriving class should overwrite `visit_job`, `visit_chain`, and
    `visit_group`. `visit_array` defaults to `visit_group`, so it only needs
    to be overwritten when job arrays require special treatment.

    Use the `visit` method to start visiting a tree of jobs.
    """

    def __init__(self):
        # Maps node types to the visitor method handling them.
        self._dispatcher = {
            Job: self.visit_job,
            JobArray: self.visit_array,
            JobChain: self.visit_chain,
            JobGroup: self.visit_group,
        }

    def visit_job(self, job):
        """Visit a single `Job`."""
        raise NotImplementedError()

    def visit_array(self, job_array):
        """Visit a `JobArray` (by default treated like a `JobGroup`)."""
        return self.visit_group(job_array)

    def visit_chain(self, chain):
        """Visit a `JobChain`."""
        raise NotImplementedError()

    def visit_group(self, group):
        """Visit a `JobGroup`."""
        raise NotImplementedError()

    def visit(self, job):
        """Visit all jobs in the tree *job*.

        Raises
        ------
        KeyError
            If *job* is not (a subclass of) a known job tree node type.
        """
        # Walk the MRO so that subclasses of the node types are dispatched to
        # the handler of their closest known base class.
        for cls in type(job).__mro__:
            handler = self._dispatcher.get(cls)
            if handler is not None:
                return handler(job)
        raise KeyError(type(job))
@inherit_docs
class Submit(JobTreeVisitor):
    """Submit all jobs that are not up-to-date.

    The constructor will call `visit`.

    Parameters
    ----------
    job : job tree
        Tree of jobs to submit.
    names : dict
        Maps jobs to their names. (Can be obtained with `Fullname`.)
    uptodate : Uptodate
        Up-to-date status of the jobs; its ``status`` attribute maps jobs to
        their up-to-date status.

    Attributes
    ----------
    names : dict
        Maps jobs to their names.
    uptodate : Uptodate
        Up-to-date status of the jobs (see ``uptodate.status``).
    """

    def __init__(self, job, names, uptodate):
        super(Submit, self).__init__()
        self.names = names
        self.uptodate = uptodate
        # IDs returned by the submit functions that the next submitted job
        # has to wait for; extended while descending into chains.
        self._depends_on = []
        self.visit(job)

    def visit_job(self, job):
        if self.uptodate.status[job]:
            print('-', self.names[job])
            return []
        else:
            print('.', self.names[job])
            return [job.submit_fn(
                name=self.names[job], depends_on=self._depends_on,
                **job.submit_kwargs)]

    def visit_array(self, job):
        if self.uptodate.status[job]:
            print('-', self.names[job])
            return []
        else:
            print('.', self.names[job])
            try:
                return [job.submit_fn(
                    job.n, name=self.names[job], depends_on=self._depends_on,
                    **job.submit_kwargs)]
            except NotImplementedError:
                # Array submission is not supported; submit the jobs
                # individually instead.
                return self.visit_group(job)

    def visit_group(self, group):
        # Flatten the lists of submitted IDs in linear time.
        return list(itertools.chain.from_iterable(
            self.visit(job) for job in group.jobs))

    def visit_chain(self, chain):
        old_depends_on = self._depends_on
        job_ids = []
        try:
            for job in chain.jobs:
                ids = self.visit(job)
                job_ids.extend(ids)
                # Later jobs wait for the most recently submitted jobs. If
                # nothing was submitted (up-to-date job), keep waiting on the
                # previous submissions instead of dropping them.
                if ids:
                    self._depends_on = old_depends_on + ids
        finally:
            self._depends_on = old_depends_on
        return job_ids
@inherit_docs
class Clean(JobTreeVisitor):
    """Clean all target files and supporting files of jobs that are outdated.

    The constructor will call `visit`.

    Parameters
    ----------
    job : job tree
        Tree of jobs to clean.
    task : TaskDef
        Task that generated the job tree.
    names : dict
        Maps jobs to their names. (Can be obtained with `Fullname`.)
    uptodate : Uptodate, optional
        Up-to-date status of the jobs (can be obtained with `Uptodate`). If
        not provided, all jobs are treated as outdated.

    Attributes
    ----------
    task : TaskDef
        Task that generated the job tree.
    names : dict
        Maps jobs to their names.
    uptodate : dict
        Maps jobs to their up-to-date status.
    """

    def __init__(self, job, task, names, uptodate=None):
        super(Clean, self).__init__()
        self.task = task
        self.names = names
        if uptodate is None:
            self.uptodate = {}
        else:
            self.uptodate = uptodate.status
        self.visit(job)

    def visit_job(self, job):
        if self.uptodate.get(job, False):
            return

        workdir = os.path.join(self.task.workdir, self.task.name)
        # There are no supporting files to remove if the working directory
        # was never created.
        if os.path.isdir(workdir):
            prefix = self.names[job]
            # NOTE(review): plain prefix matching also removes files of other
            # jobs whose name starts with this job's name (e.g. 'x:1' vs
            # 'x:10'); confirm whether a stricter match is needed.
            for item in os.listdir(workdir):
                if item.startswith(prefix):
                    os.remove(os.path.join(workdir, item))
        for t in job.targets:
            if os.path.exists(t):
                os.remove(t)

    def visit_chain(self, chain):
        for job in chain.jobs:
            self.visit(job)

    def visit_group(self, group):
        for job in group.jobs:
            self.visit(job)
@inherit_docs
class Fullname(JobTreeVisitor):
    """Construct names of the jobs.

    The name of a job is the names of all its enclosing groups and chains,
    followed by its own name, joined with ``':'``.

    The constructor will call `visit`.

    Parameters
    ----------
    jobtree : job tree
        Tree of jobs to construct names for.

    Attributes
    ----------
    names : dict
        Maps jobs to their names.
    """

    def __init__(self, jobtree):
        super(Fullname, self).__init__()
        # Names of the enclosing groups/chains, each followed by ':'.
        self._prefix = ''
        self.names = {}
        self.visit(jobtree)

    def visit_job(self, job):
        self.names[job] = self._prefix + job.name

    def visit_chain(self, chain):
        self.visit_group(chain)

    def visit_group(self, group):
        self.names[group] = self._prefix + group.name
        old_prefix = self._prefix
        self._prefix += group.name + ':'
        try:
            for job in group.jobs:
                self.visit(job)
        finally:
            self._prefix = old_prefix
@inherit_docs
class Uptodate(JobTreeVisitor):
    """Determines the up-to-date status of jobs.

    A job is up-to-date if all of its targets exist and are at least as new
    as the newest of its existing dependencies. Jobs currently queued with
    the task's scheduler count as up-to-date.

    Within a chain, the chain's dependencies are those of its first job.
    Jobs are handled by their position relative to the last job whose
    targets are up-to-date with respect to those dependencies:

    * that job and all jobs before it count as up-to-date,
    * the job right after it is checked on its own,
    * all later jobs count as outdated.

    The constructor will call `visit` and `post_visit`.

    Parameters
    ----------
    jobtree : job tree
        Tree of jobs to determine the up-to-date status for.
    names : dict
        Maps jobs to their names. (Can be obtained with `Fullname`.)
    task : TaskDef
        Task that generated the job tree.

    Attributes
    ----------
    names : dict
        Maps jobs to their names.
    task : TaskDef
        Task that generated the job tree.
    status : dict
        Maps jobs to their up-to-date status.
    any_queued : bool
        Whether any visited job was found queued with the scheduler.
    outdated : bool
        Whether any checked set of targets was found not up-to-date.
    """

    def __init__(self, jobtree, names, task):
        super(Uptodate, self).__init__()
        self.names = names
        self.task = task
        self.status = {}
        # Status forced onto jobs inside a chain; None means "check files".
        self._clamp = None
        self.any_queued = False
        self.outdated = False
        self.visit(jobtree)
        self.post_visit()

    def post_visit(self):
        """Called after `visit`.

        If some jobs are still queued while others are outdated, issues a
        `JobsRunningWarning` and marks *all* jobs as up-to-date, so that no
        new jobs get started while others are still running.
        """
        if self.any_queued and self.outdated:
            warnings.warn(JobsRunningWarning(self.task.name))
            for k in self.status:
                self.status[k] = True

    def visit_job(self, job):
        if self.is_job_queued(job):
            self.status[job] = True
        elif self._clamp is None:
            tref = self._get_tref(job.dependencies)
            self.status[job] = self.files_uptodate(tref, job.targets)
        else:
            self.status[job] = self._clamp
        return self.status[job]

    def visit_chain(self, chain):
        if self._clamp is None:
            # An empty chain has no dependencies (and counts as up-to-date).
            first_deps = chain.jobs[0].dependencies if chain.jobs else []
            tref = self._get_tref(first_deps)

            # Index of the last job whose targets are newer than the chain's
            # dependencies (-1 if there is none).
            last_uptodate = -1
            for i, job in enumerate(reversed(chain.jobs)):
                if self.files_uptodate(tref, job.targets):
                    last_uptodate = len(chain.jobs) - i - 1
                    break

            for i, job in enumerate(chain.jobs):
                if i <= last_uptodate:
                    self._clamp = True
                elif i == last_uptodate + 1:
                    self._clamp = None
                else:
                    self._clamp = False
                self.visit(job)

            self.status[chain] = last_uptodate + 1 == len(chain.jobs)
            self._clamp = None
        else:
            # Nested inside a clamped chain: propagate the forced status.
            for job in chain.jobs:
                self.visit(job)
            self.status[chain] = self._clamp
        return self.status[chain]

    def visit_group(self, group):
        subtask_status = [self.visit(j) for j in group.jobs]
        self.status[group] = all(subtask_status)
        return self.status[group]

    def is_job_queued(self, job):
        """Checks whether *job* is queued with the task's scheduler.

        Also records in `any_queued` whether a queued job was found.
        """
        # NOTE(review): this queries the scheduler once per visited job;
        # consider caching the queued names if this becomes slow.
        job_names = [
            self.task.scheduler.get_status(j).name
            for j in self.task.scheduler.get_jobs()]
        is_queued = self.names[job] in job_names
        self.any_queued |= is_queued
        return is_queued

    def files_uptodate(self, tref, targets):
        """Checks that all *targets* exist and are not older than *tref*.

        Also records in `outdated` whether this check failed.
        """
        uptodate = all(
            self._is_newer_than_tref(target, tref) for target in targets)
        self.outdated |= not uptodate
        return uptodate

    def _get_tref(self, dependencies):
        """Return the newest mtime of the existing *dependencies* (0 if none)."""
        mtimes = [
            os.stat(d).st_mtime for d in dependencies if os.path.exists(d)]
        return max(mtimes) if mtimes else 0

    def _is_newer_than_tref(self, filename, tref):
        """Whether *filename* exists and was modified at or after *tref*."""
        return os.path.exists(filename) and os.stat(filename).st_mtime >= tref