Source code for yellowdog_ray.utils.ray_ssh_tunnels

"""
Simple class to set up SSH tunnel port-forwarding from a Ray
head node to the local host.
"""

from dataclasses import dataclass

from sshtunnel import SSHTunnelForwarder


[docs] @dataclass class SSHTunnelSpec: """ Defines the parameters of an SSH tunnel. """ local_bind_hostname: str local_bind_ip_address: int remote_bind_hostname: str remote_bind_ip_address: int
[docs] class RayTunnels: """ Set up SSH port forwarding for Ray. """
[docs] def __init__( self, ray_head_ip_address: str, ssh_user: str, password: str | None = None, private_key_file: str | None = None, private_key_password: str | None = None, ray_tunnel_specs: list[SSHTunnelSpec] | None = None, ): """ Set SSH tunnel parameters. One of private_key_file or password must be supplied. :param ray_head_ip_address: the network-accessible IP address for the Ray head node :param ssh_user: the SSH user to use :param password: the user's password :param private_key_password: the password for the private key :param private_key_file: the file containing the private key to use :param ray_tunnel_specs: the ports on which to establish tunnels as a list of SSHTunnelSpec objects """ self._ray_ip = ray_head_ip_address self._ssh_user = ssh_user self._private_key_file = private_key_file self._password = password self._private_key_password = private_key_password self._ray_tunnel_specs = ( [ self.basic_port_forward(10001), self.basic_port_forward(8265), ] if ray_tunnel_specs is None else ray_tunnel_specs ) self._ray_tunnels: list[SSHTunnelForwarder] = []
[docs] @staticmethod def basic_port_forward(port: int) -> SSHTunnelSpec: """ Create a basic port-forwarding tunnel specification. :param port: the port to forward :return: An SSHTunnelSpec object """ localhost = "localhost" return SSHTunnelSpec(localhost, port, localhost, port)
[docs] def start_tunnels(self): """ Start the SSH tunnels. """ for tunnel_spec in self._ray_tunnel_specs: tunnel = SSHTunnelForwarder( self._ray_ip, ssh_username=self._ssh_user, ssh_password=self._password, ssh_pkey=self._private_key_file, ssh_private_key_password=self._private_key_password, local_bind_address=( tunnel_spec.local_bind_hostname, tunnel_spec.local_bind_ip_address, ), remote_bind_address=( tunnel_spec.remote_bind_hostname, tunnel_spec.remote_bind_ip_address, ), ) tunnel.start() self._ray_tunnels.append(tunnel)
[docs] def stop_tunnels(self, ignore_errors: bool = True): """ Stop the SSH tunnels. :param ignore_errors: suppress exceptions and proceed """ for tunnel in self._ray_tunnels: try: tunnel.stop(force=True) except: if ignore_errors: pass