"""
YellowDog Airflow sensors for work requirements, worker pools and
compute requirements. The sensors test for when one of a specified
set of states is reached.
"""
from collections.abc import Callable, Sequence
from enum import Enum
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.context import Context
from jinja2 import Environment
from yellowdog_client import PlatformClient
from yellowdog_client.model import (
ComputeRequirement,
ComputeRequirementStatus,
ConfiguredWorkerPool,
ProvisionedWorkerPool,
WorkerPool,
WorkerPoolStatus,
WorkRequirement,
WorkRequirementStatus,
)
from yellowdog_provider.hooks.yellowdog_hooks import YellowDogHook
from yellowdog_provider.utils.yellowdog_utils import (
get_compute_requirement_by_id_or_name,
get_work_requirement_by_id_or_name,
get_worker_pool_by_id_or_name,
)
# XCom key under which each sensor writes out the ID of the object it watches
XCOM_KEY = "return_value"
class ObjectName(Enum):
    """
    Human-readable names of the YellowDog object kinds handled by the
    sensors. The values are interpolated into the sensors' log messages.
    """

    # Each value is the lower-case display name of the object kind
    WORK_REQUIREMENT = "work requirement"
    WORKER_POOL = "worker pool"
    COMPUTE_REQUIREMENT = "compute requirement"
class YellowDogSensor(BaseSensorOperator):
    """
    Common base class for the YellowDog state sensors.

    Holds the connection ID and supplies ``check_status``, the shared
    comparison of an object's current state against a set of target states
    that each concrete sensor calls from its ``poke`` method.

    :param connection_id: connection to run the operator with (templated) or a
        Callable that generates the connection ID
    :type connection_id: str | Callable
    """

    template_fields: Sequence[str] = ("connection_id",)

    def __init__(
        self,
        connection_id: Callable[[Context, Environment], str] | str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.connection_id = connection_id
        # Flipped to True once this instance has pushed the object ID to XCom
        self.xcom_written = False

    def check_status(
        self,
        context: Context,
        object_id: str,
        object_type: ObjectName,
        current_state: (
            WorkRequirementStatus | WorkerPoolStatus | ComputeRequirementStatus
        ),
        target_states: (
            list[WorkRequirementStatus]
            | list[WorkerPoolStatus]
            | list[ComputeRequirementStatus]
        ),
        target_states_names: list[str],
    ) -> bool:
        """
        Compare the current state with the target state(s).

        On the first call made on this instance the object ID is also pushed
        to XCom under ``XCOM_KEY``; later calls skip the push.

        :param context: the Airflow task context; its ``task_instance`` entry
            is used for the XCom push
        :param object_id: the ID of the object being watched
        :param object_type: the kind of object, used in the log messages
        :param current_state: the state the object is in now
        :param target_states: the states that satisfy the sensor
        :param target_states_names: the target state names, used for logging only
        :return: True if ``current_state`` is one of ``target_states``,
            otherwise False
        """
        kind = object_type.value

        # Publish the object ID once only per sensor instance
        if not self.xcom_written:
            self.log.info(f"Writing {kind} ID to XCom key '{XCOM_KEY}'")
            context["task_instance"].xcom_push(key=XCOM_KEY, value=object_id)
            self.xcom_written = True

        self.log.info(f"Checking for {kind} status in {target_states_names}")

        if current_state not in target_states:
            self.log.info(
                f"Current status is '{current_state.value}' ... "
                f"waiting for {self.poke_interval} seconds"
            )
            return False

        self.log.info(
            f"{str(kind).capitalize()} has reached status "
            f"'{current_state.value}'"
        )
        return True
class WorkRequirementStateSensor(YellowDogSensor):
    """
    Sensor that succeeds once a work requirement is in one of a specified
    set of states. Either the work_requirement_id or both namespace and
    work_requirement_name must be supplied.

    Emits the work requirement ID to XCom using key 'return_value'.

    :param task_id: the Airflow task ID
    :type task_id: str
    :param connection_id: connection to run the operator with (templated) or a
        Callable that generates the connection ID
    :type connection_id: str | Callable
    :param target_states: the list of WorkRequirementStatus states to test for
    :type target_states: list[WorkRequirementStatus]
    :param work_requirement_id: the ID of the work requirement (templated) or
        a Callable that generates the ID
    :type work_requirement_id: str | Callable | None
    :param namespace: the namespace of the work requirement (templated) or a
        Callable that returns the namespace
    :type namespace: str | Callable | None
    :param work_requirement_name: the name of a work requirement (templated) or a
        Callable that generates a work requirement name
    :type work_requirement_name: str | Callable | None
    :param poke_interval: the time between sensor checks in seconds (default = 60)
    :type poke_interval: float
    """

    template_fields: Sequence[str] = (
        *YellowDogSensor.template_fields,
        "work_requirement_id",
        "namespace",
        "work_requirement_name",
    )

    def __init__(
        self,
        task_id: str,
        connection_id: Callable[[Context, Environment], str] | str,
        target_states: list[WorkRequirementStatus],
        work_requirement_id: Callable[[Context, Environment], str] | str | None = None,
        namespace: Callable[[Context, Environment], str] | str | None = None,
        work_requirement_name: (
            Callable[[Context, Environment], str] | str | None
        ) = None,
        poke_interval: float = 60,  # Seconds
        **kwargs,
    ):
        super().__init__(
            task_id=task_id,
            connection_id=connection_id,
            poke_interval=poke_interval,
            **kwargs,
        )
        # What to wait for; the names are kept alongside for log output
        self.target_states = target_states
        self.target_states_names = [state.value for state in target_states]
        # How to locate the work requirement
        self.work_requirement_id = work_requirement_id
        self.namespace = namespace
        self.work_requirement_name = work_requirement_name

    def poke(self, context: Context) -> bool:
        """
        Fetch the work requirement and test its status against the target
        states.

        :param context: the Airflow task context
        :return: True if the work requirement is in one of the target states
        """
        hook = YellowDogHook(self.connection_id)
        client: PlatformClient = hook.get_conn()
        work_requirement: WorkRequirement = get_work_requirement_by_id_or_name(
            client,
            self.log,
            self.work_requirement_id,
            self.namespace,
            self.work_requirement_name,
        )
        return self.check_status(
            context=context,
            object_id=work_requirement.id,
            object_type=ObjectName.WORK_REQUIREMENT,
            current_state=work_requirement.status,
            target_states=self.target_states,
            target_states_names=self.target_states_names,
        )
class WorkerPoolStateSensor(YellowDogSensor):
    """
    Sensor that tests if a worker pool has reached one of a specified
    set of states. Either the worker_pool_id or both the namespace and
    worker_pool_name must be supplied.

    Emits the worker pool ID to XCom using key 'return_value'.

    :param task_id: the Airflow task ID
    :type task_id: str
    :param connection_id: connection to run the operator with (templated) or a
        Callable that generates the connection ID
    :type connection_id: str | Callable
    :param target_states: the list of WorkerPoolStatus states to test for
    :type target_states: list[WorkerPoolStatus]
    :param worker_pool_id: the ID of the worker pool (templated) or
        a Callable that generates the ID
    :type worker_pool_id: str | Callable | None
    :param namespace: the namespace of the worker pool (templated) or a
        Callable that returns the namespace
    :type namespace: str | Callable | None
    :param worker_pool_name: the name of a worker pool (templated) or a
        Callable that generates a worker pool name
    :type worker_pool_name: str | Callable | None
    :param poke_interval: the time between sensor checks in seconds (default = 60)
    :type poke_interval: float
    """

    template_fields: Sequence[str] = (
        *YellowDogSensor.template_fields,
        "worker_pool_id",
        "namespace",
        "worker_pool_name",
    )

    def __init__(
        self,
        task_id: str,
        connection_id: Callable[[Context, Environment], str] | str,
        target_states: list[WorkerPoolStatus],
        worker_pool_id: Callable[[Context, Environment], str] | str | None = None,
        namespace: Callable[[Context, Environment], str] | str | None = None,
        worker_pool_name: Callable[[Context, Environment], str] | str | None = None,
        poke_interval: float = 60,  # Seconds
        **kwargs,
    ):
        super().__init__(
            task_id=task_id,
            connection_id=connection_id,
            poke_interval=poke_interval,
            **kwargs,
        )
        self.worker_pool_id = worker_pool_id
        self.namespace = namespace
        self.worker_pool_name = worker_pool_name
        self.target_states = target_states
        # State names are kept alongside the states for use in log messages
        self.target_states_names = [x.value for x in target_states]

    def poke(self, context: Context) -> bool:
        """
        Fetch the worker pool and test its status against the target states.

        :param context: the Airflow task context
        :return: True if the worker pool is in one of the target states;
            False otherwise, including when the fetched object is not a
            provisioned or configured worker pool (a warning is logged)
        """
        client: PlatformClient = YellowDogHook(self.connection_id).get_conn()
        worker_pool: WorkerPool = get_worker_pool_by_id_or_name(
            client,
            self.log,
            self.worker_pool_id,
            self.namespace,
            self.worker_pool_name,
        )
        # Narrow to the concrete pool types before reading id/status
        # (keeps the type system happy)
        if isinstance(worker_pool, (ProvisionedWorkerPool, ConfiguredWorkerPool)):
            return self.check_status(
                context,
                worker_pool.id,
                ObjectName.WORKER_POOL,
                worker_pool.status,
                self.target_states,
                self.target_states_names,
            )
        # Not expected to be reached. If it is, say so in the task log rather
        # than polling silently until the sensor times out.
        self.log.warning(
            f"Worker pool object has unexpected type "
            f"'{type(worker_pool).__name__}'; unable to check its status"
        )
        return False
class ComputeRequirementStateSensor(YellowDogSensor):
    """
    Sensor that succeeds once a compute requirement is in one of a specified
    set of states. Either the compute_requirement_id or both the namespace and
    compute_requirement_name must be supplied.

    Emits the compute requirement ID to XCom using key 'return_value'.

    :param task_id: the Airflow task ID
    :type task_id: str
    :param connection_id: connection to run the operator with (templated) or a
        Callable that generates the connection ID
    :type connection_id: str | Callable
    :param target_states: the list of ComputeRequirementStatus states to test for
    :type target_states: list[ComputeRequirementStatus]
    :param compute_requirement_id: the ID of the compute requirement (templated) or a
        Callable that returns the ID
    :type compute_requirement_id: str | Callable | None
    :param namespace: the namespace of the compute requirement (templated) or a
        Callable that returns the namespace
    :type namespace: str | Callable | None
    :param compute_requirement_name: the name of the compute requirement (templated)
        or a Callable that returns the compute requirement name
    :type compute_requirement_name: str | Callable | None
    :param poke_interval: the time between sensor checks in seconds (default = 60)
    :type poke_interval: float
    """

    template_fields: Sequence[str] = (
        *YellowDogSensor.template_fields,
        "compute_requirement_id",
        "namespace",
        "compute_requirement_name",
    )

    def __init__(
        self,
        task_id: str,
        connection_id: Callable[[Context, Environment], str] | str,
        target_states: list[ComputeRequirementStatus],
        compute_requirement_id: (
            Callable[[Context, Environment], str] | str | None
        ) = None,
        namespace: Callable[[Context, Environment], str] | str | None = None,
        compute_requirement_name: (
            Callable[[Context, Environment], str] | str | None
        ) = None,
        poke_interval: float = 60,  # Seconds
        **kwargs,
    ):
        super().__init__(
            task_id=task_id,
            connection_id=connection_id,
            poke_interval=poke_interval,
            **kwargs,
        )
        # What to wait for; the names are kept alongside for log output
        self.target_states = target_states
        self.target_states_names = [state.value for state in target_states]
        # How to locate the compute requirement
        self.compute_requirement_id = compute_requirement_id
        self.namespace = namespace
        self.compute_requirement_name = compute_requirement_name

    def poke(self, context: Context) -> bool:
        """
        Fetch the compute requirement and test its status against the target
        states.

        :param context: the Airflow task context
        :return: True if the compute requirement is in one of the target states
        """
        hook = YellowDogHook(self.connection_id)
        client: PlatformClient = hook.get_conn()
        compute_requirement: ComputeRequirement = get_compute_requirement_by_id_or_name(
            client,
            self.log,
            self.compute_requirement_id,
            self.namespace,
            self.compute_requirement_name,
        )
        return self.check_status(
            context=context,
            object_id=compute_requirement.id,
            object_type=ObjectName.COMPUTE_REQUIREMENT,
            current_state=compute_requirement.status,
            target_states=self.target_states,
            target_states_names=self.target_states_names,
        )