From 150541143717335ccad4e589690a1460edcc47c0 Mon Sep 17 00:00:00 2001
From: Stefano Alberto Russo <stefano.russo@gmail.com>
Date: Sat, 7 Oct 2023 02:41:42 +0200
Subject: [PATCH] Fixed wrong tunnel setup for ssh-based computing resources.

---
 .../webapp/code/rosetta/core_app/utils.py     | 120 +++++++++---------
 1 file changed, 63 insertions(+), 57 deletions(-)

diff --git a/services/webapp/code/rosetta/core_app/utils.py b/services/webapp/code/rosetta/core_app/utils.py
index 74c8897..259b6b3 100644
--- a/services/webapp/code/rosetta/core_app/utils.py
+++ b/services/webapp/code/rosetta/core_app/utils.py
@@ -32,9 +32,6 @@ color_map = ["#440154", "#440558", "#450a5c", "#450e60", "#451465", "#461969",
              "#97d73e", "#9ed93a", "#a8db34", "#b0dd31", "#b8de30", "#c3df2e",
              "#cbe02d", "#d6e22b", "#e1e329", "#eae428", "#f5e626", "#fde725"]
 
-#======================
-#  Utility functions
-#======================
 
 def booleanize(*args, **kwargs):
     # Handle both single value and kwargs to get arg name
@@ -265,10 +262,6 @@ def get_md5(string):
     return md5
 
 
-#=========================
-#   Time 
-#=========================
-
 def timezonize(timezone):
     '''Convert a string representation of a timezone to its pytz object or do nothing if the argument is already a pytz timezone'''
     
@@ -283,14 +276,17 @@ def timezonize(timezone):
         timezone = pytz.timezone(timezone)
     return timezone
 
+
 def now_t():
     '''Return the current time in epoch seconds'''
     return now_s()
 
+
 def now_s():
     '''Return the current time in epoch seconds'''
     return calendar.timegm(now_dt().utctimetuple())
 
+
 def now_dt(tzinfo='UTC'):
     '''Return the current time in datetime format'''
     if tzinfo != 'UTC':
@@ -335,10 +331,12 @@ def dt(*args, **kwargs):
 
     return  time_dt
 
+
 def get_tz_offset_s(time_dt):
     '''Get the time zone offset in seconds'''
     return s_from_dt(time_dt.replace(tzinfo=pytz.UTC)) - s_from_dt(time_dt)
 
+
 def check_dt_consistency(date_dt):
     '''Check that the timezone is consistent with the datetime (some conditions in Python lead to have summertime set in winter)'''
 
@@ -355,6 +353,7 @@ def check_dt_consistency(date_dt):
         else:
             return True
 
+
 def correct_dt_dst(datetime_obj):
     '''Check that the dst is correct and if not change it'''
 
@@ -374,14 +373,17 @@ def correct_dt_dst(datetime_obj):
               datetime_obj.microsecond,
               tzinfo=datetime_obj.tzinfo)
 
+
 def change_tz(dt, tz):
     return dt.astimezone(timezonize(tz))
 
+
 def dt_from_t(timestamp_s, tz=None):
     '''Create a datetime object from an epoch timestamp in seconds. If no timezone is given, UTC is assumed'''
     # TODO: check if uniform everything on this one or not.
     return dt_from_s(timestamp_s=timestamp_s, tz=tz)
-    
+
+
 def dt_from_s(timestamp_s, tz=None):
     '''Create a datetime object from an epoch timestamp in seconds. If no timezone is given, UTC is assumed'''
 
@@ -397,6 +399,7 @@ def dt_from_s(timestamp_s, tz=None):
     
     return timestamp_dt
 
+
 def s_from_dt(dt):
     '''Returns seconds with floating point for milliseconds/microseconds.'''
     if not (isinstance(dt, datetime.datetime)):
@@ -404,6 +407,7 @@ def s_from_dt(dt):
     microseconds_part = (dt.microsecond/1000000.0) if dt.microsecond else 0
     return  ( calendar.timegm(dt.utctimetuple()) + microseconds_part)
 
+
 def dt_from_str(string, timezone=None):
 
     # Supported formats on UTC
@@ -458,10 +462,12 @@ def dt_from_str(string, timezone=None):
     
     return dt(year, month, day, hour, minute, second, usecond, offset_s=offset_s)
 
+
 def dt_to_str(dt):
     '''Return the ISO representation of the datetime as argument'''
     return dt.isoformat()
 
+
 class dt_range(object):
 
     def __init__(self, from_dt, to_dt, timeSlotSpan):
@@ -489,20 +495,18 @@ class dt_range(object):
         return self.__next__()
 
 
-#================================
-#  Others
-#================================
-
 def debug_param(**kwargs):
     for item in kwargs:
         logger.critical('Param "{}": "{}"'.format(item, kwargs[item]))
 
+
 def get_my_ip():
     import socket
     hostname = socket.gethostname()
     my_ip = socket.gethostbyname(hostname)
     return my_ip
 
+
 def get_webapp_conn_string():
     webapp_ssl  = booleanize(os.environ.get('ROSETTA_WEBAPP_SSL', False))
     webapp_host = os.environ.get('ROSETTA_WEBAPP_HOST', get_my_ip())
@@ -513,32 +517,68 @@ def get_webapp_conn_string():
         webapp_conn_string = 'http://{}:{}'.format(webapp_host, webapp_port)
     return webapp_conn_string
 
+
 def get_platform_registry():
     platform_registry_host = os.environ.get('PLATFORM_REGISTRY_HOST', 'proxy')
     platform_registry_port = os.environ.get('PLATFORM_REGISTRY_PORT', '5000')
     platform_registry_conn_string = '{}:{}'.format(platform_registry_host, platform_registry_port)
     return platform_registry_conn_string
-    
+
+  
 def get_rosetta_tasks_tunnel_host():
     # Importing here instead of on top avoids circular dependencies problems when loading booleanize in settings
     from django.conf import settings
     tunnel_host = os.environ.get('ROSETTA_TASKS_TUNNEL_HOST', settings.ROSETTA_HOST)
     return tunnel_host
 
+
 def get_rosetta_tasks_proxy_host():
     # Importing here instead of on top avoids circular dependencies problems when loading booleanize in settings
     from django.conf import settings
     proxy_host = os.environ.get('ROSETTA_TASKS_PROXY_HOST', settings.ROSETTA_HOST)
     return proxy_host
 
+
 def hash_string_to_int(string):
     return int(hashlib.sha1(string.encode('utf8')).hexdigest(), 16)
 
 
+def get_ssh_access_mode_credentials(computing, user):
+    '''Return the (computing_user, computing_host, computing_port, computing_keys) tuple used to reach a computing resource over SSH'''
+    from .models import KeyPair
+
+    # Get computing host (mandatory: a ValueError is raised if it is missing from the conf)
+    try:
+        computing_host = computing.conf.get('host')
+    except AttributeError:
+        computing_host = None
+    if not computing_host:
+        raise ValueError('No computing host?!')
+
+    # Get computing (SSH) port, falling back to 22 if there is no conf or no port set in it
+    try:
+        computing_port = computing.conf.get('port')
+    except AttributeError:
+        computing_port = 22
+    if not computing_port:
+        computing_port = 22
+
+    # Get computing user and keys, depending on the auth mode of the computing resource
+    if computing.auth_mode == 'user_keys':
+        computing_user = user.profile.get_extra_conf('computing_user', computing)
+        if not computing_user:
+            raise ValueError('No \'computing_user\' parameter found for computing resource \'{}\' in user profile'.format(computing.name))
+        # Get user key (the default key pair of this user)
+        computing_keys = KeyPair.objects.get(user=user, default=True)
+    elif computing.auth_mode == 'platform_keys':
+        computing_user = computing.conf.get('user')
+        computing_keys = KeyPair.objects.get(user=None, default=True)
+    else:
+        raise NotImplementedError('Auth modes other than user_keys and platform_keys not supported.')
+    if not computing_user:
+        raise ValueError('No \'user\' parameter found for computing resource \'{}\' in its configuration'.format(computing.name))
+    return (computing_user, computing_host, computing_port, computing_keys)
 
-#================================
-#  Tunnel (and proxy) setup
-#================================
 
 def setup_tunnel_and_proxy(task):
 
@@ -602,7 +642,13 @@ def setup_tunnel_and_proxy(task):
             tunnel_command= 'ssh -4 -i {} -o StrictHostKeyChecking=no -nNT -L 0.0.0.0:{}:{}:{} {}@{} & '.format(user_keys.private_key_file, task.tcp_tunnel_port, task.interface_ip, task.interface_port, first_user, first_host)
 
         else:
-            tunnel_command= 'ssh -4 -o StrictHostKeyChecking=no -nNT -L 0.0.0.0:{}:{}:{} localhost & '.format(task.tcp_tunnel_port, task.interface_ip, task.interface_port)
+            
+            if task.computing.access_mode.startswith('ssh'):
+                computing_user, computing_host, computing_port, computing_keys = get_ssh_access_mode_credentials(task.computing, task.user)
+                tunnel_command  = 'ssh -p {} -o LogLevel=ERROR -i {} -4 -o StrictHostKeyChecking=no -o ConnectTimeout=10 '.format(computing_port, computing_keys.private_key_file)
+                tunnel_command += '-nNT -L 0.0.0.0:{}:{}:{} {}@{}'.format(task.tcp_tunnel_port, task.interface_ip, task.interface_port, computing_user, computing_host)
+            else:
+                tunnel_command= 'ssh -4 -o StrictHostKeyChecking=no -nNT -L 0.0.0.0:{}:{}:{} localhost & '.format(task.tcp_tunnel_port, task.interface_ip, task.interface_port)
         
         background_tunnel_command = 'nohup {} >/dev/null 2>&1 &'.format(tunnel_command)
 
@@ -713,46 +759,6 @@ Listen '''+str(task.tcp_tunnel_port)+'''
                 raise ErrorMessage('Something went wrong when loading the task proxy conf')        
             
 
-
-
-def get_ssh_access_mode_credentials(computing, user):
-    
-    from .models import KeyPair
-    
-    # Get computing host
-    try:
-        computing_host = computing.conf.get('host')
-    except AttributeError:
-        computing_host = None
-    if not computing_host:
-        raise ValueError('No computing host?!')
-
-    # Get computing (SSH) port
-    try:
-        computing_port = computing.conf.get('port')
-    except AttributeError:
-        computing_port = 22
-    if not computing_host:
-        computing_port = 22
-      
-    # Get computing user and keys
-    if computing.auth_mode == 'user_keys':
-        computing_user = user.profile.get_extra_conf('computing_user', computing)
-        if not computing_user:
-            raise ValueError('No \'computing_user\' parameter found for computing resource \'{}\' in user profile'.format(computing.name))
-        # Get user key
-        computing_keys = KeyPair.objects.get(user=user, default=True)
-    elif computing.auth_mode == 'platform_keys':        
-        computing_user = computing.conf.get('user')
-        computing_keys = KeyPair.objects.get(user=None, default=True)
-    else:
-        raise NotImplementedError('Auth modes other than user_keys and platform_keys not supported.')
-    if not computing_user:
-            raise ValueError('No \'user\' parameter found for computing resource \'{}\' in its configuration'.format(computing.name))
-    return (computing_user, computing_host, computing_port, computing_keys)
-
-
-
 def sanitize_container_env_vars(env_vars):
     
     for env_var in env_vars:
-- 
GitLab