Skip to content
226 changes: 226 additions & 0 deletions config_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
Configuration adapters for mapping native specifications from DRM to DRMAA API
"""

from __future__ import annotations
from typing import (List, ClassVar, Union, Optional, TYPE_CHECKING)

from dataclasses import dataclass, asdict, fields, InitVar
from abc import ABC, abstractmethod
import re

if TYPE_CHECKING:
from drmaa import JobTemplate

# DRMAA specific fields, anything else should be put into native spec
DRMAA_FIELDS = [
"email", "deadlineTime", "errorPath", "hardRunDurationLimit",
"hardWallclockTimeLimit", "inputPath", "outputPath", "jobCategory",
"jobName", "outputPath", "workingDirectory", "transferFiles",
"remoteCommand", "args", "jobName", "jobCategory", "blockEmail"
]

TIMESTR_VALIDATE = re.compile("^(\\d+:)?[0-9][0-9]:[0-9][0-9]$")


@dataclass
class DRMAACompatible(ABC):
'''
Abstract dataclass for mapping DRM specific configuration to a
DRMAA compatible specification

Properties:
_mapped_fields: List of DRM specific keys to re-map onto
the DRMAA specification if used. Preferably users will
use the DRMAA variant of these specifications rather than
the corresponding native specification
'''

_mapped_fields: ClassVar[List[str]]

def __str__(self):
'''
Display formatted configuration for executor
'''
attrs = asdict(self)
drmaa_fields = "\n".join([
f"{field}:\t{attrs.get(field)}" for field in DRMAA_FIELDS
if attrs.get(field) is not None
])

drm_fields = "\n".join([
f"{field}:\t{attrs.get(field)}" for field in self._native_fields()
if attrs.get(field) is not None
])

return ("DRMAA Config:\n" + drmaa_fields + "\nNative Specification\n" +
drm_fields)

def get_drmaa_config(self, jt: JobTemplate) -> JobTemplate:
'''
Apply settings onto DRMAA JobTemplate
'''

for field in DRMAA_FIELDS:
value = getattr(self, field, None)
if value is not None:
setattr(jt, field, value)

jt.nativeSpecification = self.drm2drmaa()
return jt

@abstractmethod
def drm2drmaa(self) -> str:
'''
Build native specification from DRM-specific fields
'''

def _native_fields(self):
return [
f for f in asdict(self).keys()
if (f not in self._mapped_fields) and (f not in DRMAA_FIELDS)
]

def set_fields(self, **drmaa_kwargs):
for field, value in drmaa_kwargs.items():
if field not in DRMAA_FIELDS:
raise AttributeError(
"Malformed adapter class! Cannot map field"
f" {field} to a DRMAA-compliant field")

setattr(self, field, value)


@dataclass
class DRMAAConfig(DRMAACompatible):
def drm2drmaa(self):
return


@dataclass
class SlurmConfig(DRMAACompatible):
'''
Transform SLURM resource specification into DRMAA-compliant inputs

References:
See https://github.com/natefoo/slurm-drmaa for native specification
details
'''

_mapped_fields: ClassVar[List[str]] = {
"error", "output", "job_name", "time"
}

job_name: InitVar[str]
time: InitVar[str]
error: InitVar[str] = None
output: InitVar[str] = None

account: Optional[str] = None
acctg_freq: Optional[str] = None
comment: Optional[str] = None
constraint: Optional[List] = None
cpus_per_task: Optional[int] = None
contiguous: Optional[bool] = None
dependency: Optional[List] = None
exclusive: Optional[bool] = None
gres: Optional[Union[List[str], str]] = None
no_kill: Optional[bool] = None
licenses: Optional[List[str]] = None
clusters: Optional[Union[List[str], str]] = None
mail_type: Optional[str] = None
mem: Optional[int] = None
mincpus: Optional[int] = None
nodes: Optional[int] = None
ntasks: Optional[int] = None
no_requeue: Optional[bool] = None
ntasks_per_node: Optional[int] = None
partition: Optional[int] = None
qos: Optional[str] = None
requeue: Optional[bool] = None
reservation: Optional[str] = None
share: Optional[bool] = None
tmp: Optional[str] = None
nodelist: Optional[Union[List[str], str]] = None
exclude: Optional[Union[List[str], str]] = None

def __post_init__(self, job_name, time, error, output):
'''
Transform Union[List[str]] --> comma-delimited str
'''

_validate_timestr(time, "time")
super().set_fields(jobName=job_name,
hardWallclockTimeLimit=time,
errorPath=error,
outputPath=output)

self.job_name = job_name
self.time = time
self.error = error
self.output = output

for field in fields(self):
value = getattr(self, field.name)
if field.type == Union[List[str], str] and isinstance(value, list):
setattr(self, field.name, ",".join(value))

def drm2drmaa(self) -> str:
return self._transform_attrs()

def _transform_attrs(self) -> str:
'''
Remap named attributes to "-" form, excludes renaming
DRMAA-compliant fields (set in __post_init__()) then join
attributes into a nativeSpecification string
'''

out = []
for field in self._native_fields():

value = getattr(self, field)
if value is None:
continue

field_fmtd = field.replace("_", "-")
if isinstance(value, bool):
out.append(f"--{field_fmtd}")
else:
out.append(f"--{field_fmtd}={value}")
return " ".join(out)


def _timestr_to_sec(timestr: str) -> int:
'''
Transform a time-string (D-HH:MM:SS) --> seconds
'''

days = 0
if "-" in timestr:
day_str, timestr = timestr.split('-')
days = int(day_str)

seconds = (24 * days) * (60**2)
for exp, unit in enumerate(reversed(timestr.split(":"))):
seconds += int(unit) * (60**exp)

return seconds


def _validate_timestr(timestr: str, field_name: str) -> str:
'''
Validate timestring to make sure it meets
expected format.
'''

if not isinstance(timestr, str):
raise TypeError(f"Expected {field_name} to be of type string "
f"but received {type(timestr)}!")

result = TIMESTR_VALIDATE.match(timestr)
if not result:
raise ValueError(f"Expected {field_name} to be of format "
"X...XX:XX:XX or XX:XX! "
f"but received {timestr}")

return timestr
56 changes: 56 additions & 0 deletions drmaa_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
'''
Patches on DRMAA-python module
'''

from drmaa import JobTemplate, Session
from drmaa.helpers import Attribute, IntConverter


#TODO: Make sure this is actually correct?
# Works for SLURM
CORRECT_TO_STRING = [
"hardWallclockTimeLimit"
]


class PatchedIntConverter():
'''
Helper class to correctly encode Integer values
as little-endian bytes for Python 3

Info:
The standard IntConverter class attempts to convert
integer values to bytes using `bytes(value)` which
results in a zero'd byte-array of length `value`.
'''
@staticmethod
def to_drmaa(value: int) -> bytes:
return value.to_bytes(8, byteorder="little")

@staticmethod
def from_drmaa(value: bytes) -> int:
return int.from_bytes(value, byteorder="little")


class PatchedJobTemplate(JobTemplate):
def __init__(self):
'''
Dynamically patch attributes using IntConverter
'''
super(PatchedJobTemplate, self).__init__()
for attr, value in vars(JobTemplate).items():
if isinstance(value, Attribute):
if attr in CORRECT_TO_STRING:
setattr(value, "converter", None)
elif value.converter is IntConverter:
setattr(value, "converter", PatchedIntConverter)


class PatchedSession(Session):
'''
Override createJobTemplate method to return
Patched version
'''
@staticmethod
def createJobTemplate(self) -> PatchedJobTemplate:
return PatchedJobTemplate()
15 changes: 15 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
apache-airflow==2.1.4
apache-airflow-providers-ftp==2.0.1
apache-airflow-providers-http==2.0.1
apache-airflow-providers-imap==2.0.1
apache-airflow-providers-sqlite==2.0.1
apache-airflow-providers-ssh==2.2.0
coverage==5.3.1
flake8==3.8.4
pytest==6.2.5
pytest-cov==2.11.0
pytest-forked==1.3.0
pytest-mock==3.6.1
pytest-xdist==2.2.0
Sphinx==3.4.3
toml==0.10.2
93 changes: 93 additions & 0 deletions tests/test_config_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Tests for config_adapters.py to ensure that mapping
from DRM-specific configuration to DRMAA spec works
correctly
"""

import pytest
from drmaa_executor_plugin.drmaa_patches import (PatchedJobTemplate as
JobTemplate)
from drmaa_executor_plugin.config_adapters import SlurmConfig


@pytest.fixture()
def job_template():
jt = JobTemplate()
yield jt
jt.delete()


def test_slurm_config_transforms_to_drmaa(job_template):
'''
Check whether SLURM adapter class correctly
transforms SLURM specs to DRMAA attributes
'''

error = "TEST_VALUE"
output = "TEST_VALUE"
time = "10:00:00"
job_name = "FAKE_JOB"

expected_drmaa_attrs = {
"errorPath": error,
"outputPath": output,
"hardWallclockTimeLimit": "10:00:00",
"jobName": job_name
}

slurm_config = SlurmConfig(error=error,
output=output,
time=time,
job_name=job_name)

jt = slurm_config.get_drmaa_config(job_template)

# Test attributes matches what is expected
for k, v in expected_drmaa_attrs.items():
assert getattr(jt, k) == v


def test_slurm_config_native_spec_transforms_correctly(job_template):
'''
Test whether scheduler-specific configuration is transformed
into nativeSpecification correctly
'''

job_name = "TEST"
time = "01:00"
account = "TEST"
cpus_per_task = 5
slurm_config = SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)

jt = slurm_config.get_drmaa_config(job_template)
for spec in ['account=TEST', 'cpus-per-task=5']:
assert spec in jt.nativeSpecification


def test_invalid_timestr_fails():
job_name = "TEST"
time = "FAILURE"
account = "TEST"
cpus_per_task = 10

with pytest.raises(ValueError):
SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)


def test_timestr_not_string_fails():
job_name = "TEST"
time = 10
account = "TEST"
cpus_per_task = 10

with pytest.raises(TypeError):
SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)