From eb7d970a0e88bbabda2a99bfc34b9a75f68102ae Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Sat, 19 Aug 2023 09:43:12 +1000 Subject: [PATCH 01/43] Working Async socket read/write --- database.py => lib/database.py | 1 - diameter.py => lib/diameter.py | 47 +++++++++--------- lib/messaging.py | 25 ++++++++++ services/diameterService.py | 88 ++++++++++++++++++++++++++++++++++ services/georedService.py | 0 services/prometheusService.py | 0 services/webhookService.py | 0 7 files changed, 135 insertions(+), 26 deletions(-) rename database.py => lib/database.py (99%) rename diameter.py => lib/diameter.py (99%) create mode 100644 lib/messaging.py create mode 100644 services/diameterService.py create mode 100644 services/georedService.py create mode 100644 services/prometheusService.py create mode 100644 services/webhookService.py diff --git a/database.py b/lib/database.py similarity index 99% rename from database.py rename to lib/database.py index b4773a5..bc8a6aa 100755 --- a/database.py +++ b/lib/database.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import sessionmaker, relationship, Session, class_mapper from sqlalchemy.orm.attributes import History, get_history import sys, os -sys.path.append(os.path.realpath('lib')) from functools import wraps import json import datetime, time diff --git a/diameter.py b/lib/diameter.py similarity index 99% rename from diameter.py rename to lib/diameter.py index 26dbee5..57adb13 100644 --- a/diameter.py +++ b/lib/diameter.py @@ -9,29 +9,25 @@ import os import random import ipaddress -sys.path.append(os.path.realpath('lib')) -import S6a_crypt - import jinja2 -import yaml -import time -with open("config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) -#Setup Logging -import logtool -from logtool import * -logtool = logtool.LogTool() -logtool.setup_logger('DiameterLogger', yaml_config['logging']['logfiles']['diameter_logging_file'], level=yaml_config['logging']['level']) -DiameterLogger = logging.getLogger('DiameterLogger') +# with open("config.yaml", 'r') as stream: +# yaml_config = (yaml.safe_load(stream)) + +# #Setup Logging +# import logtool +# from logtool import * +# logtool = logtool.LogTool() +# logtool.setup_logger('DiameterLogger', yaml_config['logging']['logfiles']['diameter_logging_file'], level=yaml_config['logging']['level']) +# DiameterLogger = logging.getLogger('DiameterLogger') -DiameterLogger.info("Initialised Diameter Logger, importing database") -import database -DiameterLogger.info("Imported database") +# DiameterLogger.info("Initialised Diameter Logger, importing database") +# import database +# DiameterLogger.info("Imported database") -if yaml_config['redis']['enabled'] == True: - DiameterLogger.debug("Redis support enabled") - import redis +# if yaml_config['redis']['enabled'] == True: +# DiameterLogger.debug("Redis support enabled") +# import redis class Diameter: @@ -213,12 +209,13 @@ def TBCD_decode(self, input): return output #Hexify the vars we got when initializing the class - def __init__(self, OriginHost, OriginRealm, ProductName, MNC, MCC): - self.OriginHost = self.string_to_hex(OriginHost) - self.OriginRealm = self.string_to_hex(OriginRealm) - self.ProductName = self.string_to_hex(ProductName) - self.MNC = str(MNC) - self.MCC = str(MCC) + #@@@Fixme + def __init__(self): + self.OriginHost = self.string_to_hex("OriginHost") + self.OriginRealm = self.string_to_hex("OriginRealm") + self.ProductName = self.string_to_hex("ProductName") + self.MNC = str(505) + self.MCC = str(52) DiameterLogger.info("Initialized Diameter for 
" + str(OriginHost) + " at Realm " + str(OriginRealm) + " serving as Product Name " + str(ProductName)) DiameterLogger.info("PLMN is " + str(MCC) + "/" + str(MNC)) diff --git a/lib/messaging.py b/lib/messaging.py new file mode 100644 index 0000000..e6e42e4 --- /dev/null +++ b/lib/messaging.py @@ -0,0 +1,25 @@ +from redis import Redis + +class RedisMessaging(): + """ + PyHSS Redis Message Service + A class for sending and receiving redis messages. + """ + + def __init__(self, host: str='localhost', port: int=6379): + self.redisClient = Redis(host=host, port=port) + pass + + def sendMessage(self, queue: str, message: str) -> str: + self.redisClient.rpush(queue, message) + + def getMessage(self, queue: str) -> str: + message = self.redisClient.lpop(queue) + if message is None: + message = '' + else: + try: + message = message.decode() + except (UnicodeDecodeError, AttributeError): + pass + return message \ No newline at end of file diff --git a/services/diameterService.py b/services/diameterService.py new file mode 100644 index 0000000..32db996 --- /dev/null +++ b/services/diameterService.py @@ -0,0 +1,88 @@ +import asyncio +import sctp, socket +import sys, os, binascii +import time +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from diameter import Diameter + + +class DiameterService(): + """ + PyHSS Diameter Service + A class for handling diameter requests and replies on Port 3868, via TCP or SCTP. + """ + + def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.diameterLibrary = Diameter() + pass + + def validateDiameterRequest(self, requestData) -> bool: + try: + packetVars, avps = self.diameterLibrary.decode_diameter_packet(requestData) + originHost = self.diameterLibrary.get_avp_data(avps, 264)[0] + originHost = binascii.unhexlify(originHost).decode("utf-8") + except Exception as e: + return False + return True + + async def readRequestData(self, reader, clientAddress: str, clientPort: str) -> bool: + requestQueueName = f"{clientAddress}-{clientPort}-requests" + print("In readRequestData") + + while True: + requestData = await reader.read(1024) + if len(requestData) > 0: + print(f"Received data from {clientAddress} on port {clientPort}") + print(f"Data: {binascii.hexlify(requestData)}") + + if not self.validateDiameterRequest(requestData): + print(f"Invalid Diameter Request.") + break + + requestHexString = binascii.hexlify(requestData) + print(requestHexString) + self.redisMessaging.sendMessage(queue=requestQueueName, message=requestHexString) + + async def writeResponseData(self, writer, clientAddress: str, clientPort: str) -> bool: + responseQueueName = f"{clientAddress}-{clientPort}-responses" + print("In writeResponseData") + + while True: + responseHexString = self.redisMessaging.getMessage(queue=responseQueueName) + if not len(responseHexString) > 0: + await asyncio.sleep(0.005) + continue + + diameterResponse = f'Received diameter request successfully.' 
+ print(f"Sending: {diameterResponse}") + writer.write(diameterResponse) + await writer.drain() + + async def handleConnection(self, reader, writer): + (clientAddress, clientPort) = writer.get_extra_info('peername') + if not await asyncio.gather(self.readRequestData(reader=reader, clientAddress=clientAddress, clientPort=clientPort), + self.writeResponseData(writer=writer, clientAddress=clientAddress, clientPort=clientPort)): + print("Closing Connection") + writer.close() + return + + async def startServer(self, host: str='0.0.0.0', port: int=3868, type: str='TCP'): + if type.upper() == 'TCP': + server = await asyncio.start_server(self.handleConnection, host, port) + elif type.upper() == 'SCTP': + sctpSocket = sctp.sctpsocket_tcp(socket.AF_INET) + server = await asyncio.start_server(self.handleConnection, host, port, socket=sctpSocket) + else: + return False + servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) + print(f'Serving on {servingAddresses}') + + async with server: + await server.serve_forever() + + +if __name__ == '__main__': + diameterService = DiameterService() + asyncio.run(diameterService.startServer()) \ No newline at end of file diff --git a/services/georedService.py b/services/georedService.py new file mode 100644 index 0000000..e69de29 diff --git a/services/prometheusService.py b/services/prometheusService.py new file mode 100644 index 0000000..e69de29 diff --git a/services/webhookService.py b/services/webhookService.py new file mode 100644 index 0000000..e69de29 From 6b354b50ed64c5202b94d7bd949efc533c14730c Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 22 Aug 2023 15:20:15 +1000 Subject: [PATCH 02/43] Working diameterService and hssService --- config.yaml | 9 - hss.py | 1012 +---------------------------------- lib/S6a_crypt.py | 7 +- lib/banners.py | 74 +++ lib/database.py | 259 ++++----- lib/diameter.py | 758 ++++++++++++++------------ lib/logtool.py | 252 +-------- lib/messaging.py | 74 ++- lib/messagingAsync.py | 84 +++ lib/milenage.py | 2 - lib/old.logtool.py | 243 +++++++++ log/.gitkeep | 0 old.hss.py | 1012 +++++++++++++++++++++++++++++++++++ services/diameterService.py | 105 ++-- services/hssService.py | 71 +++ 15 files changed, 2175 insertions(+), 1787 deletions(-) create mode 100644 lib/banners.py create mode 100644 lib/messagingAsync.py create mode 100644 lib/old.logtool.py create mode 100644 log/.gitkeep create mode 100644 old.hss.py create mode 100644 services/hssService.py diff --git a/config.yaml b/config.yaml index 75fb78a..1cebdc0 100644 --- a/config.yaml +++ b/config.yaml @@ -30,18 +30,9 @@ hss: #IMSI of Test Subscriber for Unit Checks (Optional) test_sub_imsi: '001021234567890' - #Device Watchdog Request Interval (In Seconds - If set to 0 disabled) - device_watchdog_request_interval: 0 - - #Async Queue Check Interval (In Seconds - If set to 0 disabled) - async_check_interval: 0 - #The maximum time to wait, in seconds, before disconnecting a client when no data is received. 
client_socket_timeout: 120 - #The maximum amount of times a failed diameter response/query should be resent before considering the peer offline and terminating their connection - diameter_max_retries: 1 - #Prevent updates from being performed without a valid 'Provisioning-Key' in the header lock_provisioning: False diff --git a/hss.py b/hss.py index f32c53e..72d741f 100644 --- a/hss.py +++ b/hss.py @@ -1,1012 +1,8 @@ -# PyHSS -# This serves as a basic 3GPP Home Subscriber Server implimenting a EIR & IMS HSS functionality -import logging -import yaml -import os -import sys -import socket -import socketserver -import binascii -import time -import _thread -import threading -import sctp -import traceback -import pprint -import diameter as DiameterLib -import systemd.daemon -from threading import Thread, Lock -from logtool import * -import contextlib -import queue - - -class ThreadJoiner: - def __init__(self, threads, thread_event): - self.threads = threads - self.thread_event = thread_event - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is not None: - self.thread_event.set() - for thread in self.threads: - while thread.is_alive(): - try: - thread.join(timeout=1) - except Exception as e: - print( - f"ThreadJoiner Exception: failed to join thread {thread}: {e}" - ) - break - +import os, sys, json, yaml class PyHSS: + def __init__(self): - # Load config from yaml file - try: - with open("config.yaml", "r") as config_stream: - self.yaml_config = yaml.safe_load(config_stream) - except: - print(f"config.yaml not found, exiting PyHSS.") - quit() - - # Setup logging - self.logtool = LogTool(HSS_Init=True) - self.logtool.setup_logger( - "HSS_Logger", - self.yaml_config["logging"]["logfiles"]["hss_logging_file"], - level=self.yaml_config["logging"]["level"], - ) - self.logger = logging.getLogger("HSS_Logger") - if self.yaml_config["logging"]["log_to_terminal"]: - logging.getLogger().addHandler(logging.StreamHandler()) - - # Setup Diameter - self.diameter_instance = DiameterLib.Diameter( - str(self.yaml_config["hss"].get("OriginHost", "")), - str(self.yaml_config["hss"].get("OriginRealm", "")), - str(self.yaml_config["hss"].get("ProductName", "")), - str(self.yaml_config["hss"].get("MNC", "")), - str(self.yaml_config["hss"].get("MCC", "")), - ) - - self.max_diameter_retries = int( - self.yaml_config["hss"].get("diameter_max_retries", 1) - ) - - - - try: - assert(self.yaml_config['prometheus']['enabled'] == True) - assert(self.yaml_config['prometheus']['async_subscriber_count'] == True) - - self.logger.info("Enabling Prometheus Async Sub thread") - #Add Prometheus Async Calls - prom_async_thread = threading.Thread( - target=self.prom_async_function, - name=f"prom_async_function", - args=(), - ) - prom_async_thread.start() - except: - self.logger.info("Prometheus Async Sub Count thread disabled") - - - - def terminate_connection(self, clientsocket, client_address, thread_event): - thread_event.set() - clientsocket.close() - self.logtool.Manage_Diameter_Peer(client_address, client_address, "remove") - - def handle_new_connection(self, clientsocket, client_address): - # Create our threading event, accessible by sibling threads in this connection. 
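The threaded handler being deleted here enforced client_socket_timeout with clientsocket.settimeout() (visible a few lines below); the new asyncio diameterService does not apply the retained setting yet. A sketch of an equivalent guard for the async read loop, assuming the value is loaded from config.yaml and that readWithTimeout is a hypothetical helper name:

    import asyncio

    async def readWithTimeout(reader: asyncio.StreamReader, socketTimeout: int = 120) -> bytes:
        # Read as diameterService.readRequestData does, but give up after
        # client_socket_timeout seconds of silence.
        try:
            return await asyncio.wait_for(reader.read(1024), timeout=socketTimeout)
        except asyncio.TimeoutError:
            # Treat a quiet peer like a closed one: return no data so the caller disconnects.
            return b''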
- socket_close_event = threading.Event() - try: - send_queue = queue.Queue() - self.logger.debug(f"New connection from {client_address}") - if ( - "client_socket_timeout" not in self.yaml_config["hss"] - or self.yaml_config["hss"]["client_socket_timeout"] == 0 - ): - self.yaml_config["hss"]["client_socket_timeout"] = 120 - clientsocket.settimeout( - self.yaml_config["hss"].get("client_socket_timeout", 120) - ) - - send_data_thread = threading.Thread( - target=self.send_data, - name=f"send_data_thread", - args=(clientsocket, send_queue, socket_close_event), - ) - self.logger.debug("handle_new_connection: Starting send_data thread") - send_data_thread.start() - - self.logtool.Manage_Diameter_Peer(client_address, client_address, "add") - manage_client_thread = threading.Thread( - target=self.manage_client, - name=f"manage_client_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug("handle_new_connection: Starting manage_client thread") - manage_client_thread.start() - - threads_to_join = [manage_client_thread] - threads_to_join.append(send_data_thread) - - # If Redis is enabled, start manage_client_async and manage_client_dwr threads. - if self.yaml_config["redis"]["enabled"]: - if ( - "async_check_interval" not in self.yaml_config["hss"] - or self.yaml_config["hss"]["async_check_interval"] == 0 - ): - self.yaml_config["hss"]["async_check_interval"] = 10 - manage_client_async_thread = threading.Thread( - target=self.manage_client_async, - name=f"manage_client_async_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug( - "handle_new_connection: Starting manage_client_async thread" - ) - manage_client_async_thread.start() - - manage_client_dwr_thread = threading.Thread( - target=self.manage_client_dwr, - name=f"manage_client_dwr_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug( - "handle_new_connection: Starting manage_client_dwr thread" - ) - manage_client_dwr_thread.start() - - threads_to_join.append(manage_client_async_thread) - threads_to_join.append(manage_client_dwr_thread) - - self.logger.debug( - f"handle_new_connection: Total PyHSS Active Threads: {threading.active_count()}" - ) - for thread in threading.enumerate(): - if "dummy" not in thread.name.lower(): - self.logger.debug(f"Active Thread name: {thread.name}") - - with ThreadJoiner(threads_to_join, socket_close_event): - socket_close_event.wait() - self.terminate_connection( - clientsocket, client_address, socket_close_event - ) - self.logger.debug(f"Closing thread for client; {client_address}") - return - - except Exception as e: - self.logger.error(f"Exception for client {client_address}: {e}") - self.logger.error(f"Closing connection for {client_address}") - self.terminate_connection(clientsocket, client_address, socket_close_event) - return - - @prom_diam_response_time_diam.time() - def process_Diameter_request( - self, clientsocket, client_address, diameter, data, thread_event, send_queue - ): - packet_length = diameter.decode_diameter_packet_length( - data - ) # Calculate length of packet from start of packet - if packet_length <= 32: - self.logger.error("Received an invalid packet with length <= 32") - self.terminate_connection(clientsocket, 
client_address, thread_event) - return - - data_sum = data + clientsocket.recv( - packet_length - 32 - ) # Recieve remainder of packet from buffer - packet_vars, avps = diameter.decode_diameter_packet( - data_sum - ) # Decode packet into array of AVPs and Dict of Packet Variables (packet_vars) - try: - packet_vars["Source_IP"] = client_address[0] - except: - self.logger.debug("Failed to add Source_IP to packet_vars") - - start_time = time.time() - origin_host = diameter.get_avp_data(avps, 264)[0] # Get OriginHost from AVP - origin_host = binascii.unhexlify(origin_host).decode("utf-8") # Format it - - # label_values = str(packet_vars['ApplicationId']), str(packet_vars['command_code']), origin_host, 'request' - prom_diam_request_count.labels( - str(packet_vars["ApplicationId"]), - str(packet_vars["command_code"]), - origin_host, - "request", - ).inc() - - - self.logger.info( - "\n\nNew request with Command Code: " - + str(packet_vars["command_code"]) - + ", ApplicationID: " - + str(packet_vars["ApplicationId"]) - + ", flags " - + str(packet_vars["flags"]) - + ", e2e ID: " - + str(packet_vars["end-to-end-identifier"]) - ) - - # Gobble up any Response traffic that is sent to us: - if packet_vars["flags_bin"][0:1] == "0": - self.logger.info("Got a Response, not a request - dropping it.") - self.logger.info(packet_vars) - return - - # Send Capabilities Exchange Answer (CEA) response to Capabilites Exchange Request (CER) - elif ( - packet_vars["command_code"] == 257 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 257 (CER) from {origin_host}" - + "\n\tSending response (CEA)" - ) - try: - response = diameter.Answer_257( - packet_vars, avps, str(self.yaml_config["hss"]["bind_ip"][0]) - ) # Generate Diameter packet - # prom_diam_response_count_successful.inc() - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - # prom_diam_response_count_fail.inc() - self.logger.info("Generated CEA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") - prom_diam_connected_peers.labels(origin_host).set(1) - - # Send Credit Control Answer (CCA) response to Credit Control Request (CCR) - elif ( - packet_vars["command_code"] == 272 - and packet_vars["ApplicationId"] == 16777238 - ): - self.logger.info( - f"Received 3GPP Credit-Control-Request from {origin_host}" - + "\n\tGenerating (CCA)" - ) - try: - response = diameter.Answer_16777238_272( - packet_vars, avps - ) # Generate Diameter packet - except Exception as E: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error(f"Failed to generate response {str(E)}") - self.logger.info("Generated CCA") - - # Send Device Watchdog Answer (DWA) response to Device Watchdog Requests (DWR) - elif ( - packet_vars["command_code"] == 280 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 280 (DWR) from {origin_host}" - + "\n\tSending response (DWA)" - ) - self.logger.debug(f"Total PyHSS Active Threads: {threading.active_count()}") - try: - response = diameter.Answer_280( - packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - 
self.logger.info("Generated DWA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") - - # Send Disconnect Peer Answer (DPA) response to Disconnect Peer Request (DPR) - elif ( - packet_vars["command_code"] == 282 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 282 (DPR) from {origin_host}" - + "\n\tForwarding request..." - ) - response = diameter.Answer_282( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated DPA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "remove") - prom_diam_connected_peers.labels(origin_host).set(0) - - # S6a Authentication Information Answer (AIA) response to Authentication Information Request (AIR) - elif ( - packet_vars["command_code"] == 318 - and packet_vars["ApplicationId"] == 16777251 - and packet_vars["flags"] == "c0" - ): - self.logger.info( - f"Received Request with command code 318 (3GPP Authentication-Information-Request) from {origin_host}" - + "\n\tGenerating (AIA)" - ) - try: - response = diameter.Answer_16777251_318( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated AIR") - except Exception as e: - self.logger.info("Failed to generate Diameter Response for AIR") - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated DIAMETER_USER_DATA_NOT_AVAILABLE AIR") - - # S6a Update Location Answer (ULA) response to Update Location Request (ULR) - elif ( - packet_vars["command_code"] == 316 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 316 (3GPP Update Location-Request) from {origin_host}" - + "\n\tGenerating (ULA)" - ) - try: - response = diameter.Answer_16777251_316( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated ULA") - except Exception as e: - self.logger.info("Failed to generate Diameter Response for ULR") - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated error DIAMETER_USER_DATA_NOT_AVAILABLE ULA") - - # Send ULA data & clear tx buffer - clientsocket.sendall(bytes.fromhex(response)) - response = "" - if "Insert_Subscriber_Data_Force" in yaml_config["hss"]: - if yaml_config["hss"]["Insert_Subscriber_Data_Force"] == True: - self.logger.debug("ISD triggered after ULA") - # Generate Insert Subscriber Data Request - response = diameter.Request_16777251_319( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated IDR") - # Send ISD data - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent IDR") - return - # S6a inbound Insert-Data-Answer in response to our IDR - elif ( - packet_vars["command_code"] == 319 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received response with command code 319 (3GPP Insert-Subscriber-Answer) from {origin_host}" - ) - return - # S6a Purge UE Answer (PUA) response to Purge UE Request (PUR) - elif ( - packet_vars["command_code"] == 321 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 321 (3GPP Purge UE Request) from {origin_host}" - + "\n\tGenerating (PUA)" - ) - try: - response = diameter.Answer_16777251_321( - 
packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error("Failed to generate PUA") - self.logger.info("Generated PUA") - # S6a Notify Answer (NOA) response to Notify Request (NOR) - elif ( - packet_vars["command_code"] == 323 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 323 (3GPP Notify Request) from {origin_host}" - + "\n\tGenerating (NOA)" - ) - try: - response = diameter.Answer_16777251_323( - packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error("Failed to generate NOA") - self.logger.info("Generated NOA") - # S6a Cancel Location Answer eater - elif ( - packet_vars["command_code"] == 317 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info("Received Response with command code 317 (3GPP Cancel Location Request) from " + str(origin_host)) - - # Cx Authentication Answer - elif ( - packet_vars["command_code"] == 300 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 300 (3GPP Cx User Authentication Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_300( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Auth Answer" - ) - self.logger.info(e) - self.logger.info(traceback.print_exc()) - self.logger.info( - type(e).__name__, # TypeError - __file__, # /tmp/example.py - e.__traceback__.tb_lineno # 2 - ) - - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Auth Answer") - - # Cx Server Assignment Answer - elif ( - packet_vars["command_code"] == 301 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 301 (3GPP Cx Server Assignemnt Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_301( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Server Assignment Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Server Assignment Answer") - - # Cx Location Information Answer - elif ( - packet_vars["command_code"] == 302 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 302 (3GPP Cx Location Information Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_302( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Location Information Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Location Information Answer") - - # Cx Multimedia Authentication Answer - elif ( - packet_vars["command_code"] == 303 - and 
packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 303 (3GPP Cx Multimedia Authentication Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_303( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Multimedia Authentication Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Multimedia Authentication Answer") - - # Sh User-Data-Answer - elif ( - packet_vars["command_code"] == 306 - and packet_vars["ApplicationId"] == 16777217 - ): - self.logger.info( - f"Received Request with command code 306 (3GPP Sh User-Data Request) from {origin_host}" - ) - try: - response = diameter.Answer_16777217_306( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Sh User-Data Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 5001 - ) # DIAMETER_ERROR_USER_UNKNOWN - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent negative response") - return - self.logger.info("Generated Sh User-Data Answer") - - # Sh Profile-Update-Answer - elif ( - packet_vars["command_code"] == 307 - and packet_vars["ApplicationId"] == 16777217 - ): - self.logger.info( - f"Received Request with command code 307 (3GPP Sh Profile-Update Request) from {origin_host}" - ) - try: - response = diameter.Answer_16777217_307( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Sh User-Data Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 5001 - ) # DIAMETER_ERROR_USER_UNKNOWN - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent negative response") - return - self.logger.info("Generated Sh Profile-Update Answer") - - # S13 ME-Identity-Check Answer - elif ( - packet_vars["command_code"] == 324 - and packet_vars["ApplicationId"] == 16777252 - ): - self.logger.info( - f"Received Request with command code 324 (3GPP S13 ME-Identity-Check Request) from {origin_host}" - + "\n\tGenerating (MICA)" - ) - try: - response = diameter.Answer_16777252_324( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for S13 ME-Identity Check Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated S13 ME-Identity Check Answer") - - # SLh LCS-Routing-Info-Answer - elif ( - packet_vars["command_code"] == 8388622 - and packet_vars["ApplicationId"] == 16777291 - ): - self.logger.info( - f"Received Request with command code 324 (3GPP SLh LCS-Routing-Info-Answer Request) from {origin_host}" - + "\n\tGenerating (MICA)" - ) - try: - response = diameter.Answer_16777291_8388622( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for SLh LCS-Routing-Info-Answer" - ) - self.logger.info(e) - 
traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated SLh LCS-Routing-Info-Answer") - - # Handle Responses generated by the Async functions - elif packet_vars["flags"] == "00": - self.logger.info( - "Got response back with command code " - + str(packet_vars["command_code"]) - ) - self.logger.info("response packet_vars: " + str(packet_vars)) - self.logger.info("response avps: " + str(avps)) - response = "" - else: - self.logger.error( - "\n\nRecieved unrecognised request with Command Code: " - + str(packet_vars["command_code"]) - + ", ApplicationID: " - + str(packet_vars["ApplicationId"]) - + " and flags " - + str(packet_vars["flags"]) - ) - for keys in packet_vars: - self.logger.error(keys) - self.logger.error("\t" + str(packet_vars[keys])) - self.logger.error(avps) - self.logger.error("Sending negative response") - response = diameter.Respond_ResultCode( - packet_vars, avps, 3001 - ) # Generate Diameter response with "Command Unsupported" (3001) - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) # Send it - - prom_diam_response_time_method.labels( - str(packet_vars["ApplicationId"]), - str(packet_vars["command_code"]), - origin_host, - "request", - ).observe(time.time() - start_time) - - # Diameter Transmission - retries = 0 - while retries < self.max_diameter_retries: - try: - send_queue.put(bytes.fromhex(response)) - break - except socket.error as e: - self.logger.error(f"Socket error for client {client_address}: {e}") - retries += 1 - if retries > self.max_diameter_retries: - self.logger.error( - f"Max retries reached for client {client_address}. Closing connection." - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - break - time.sleep(1) # Wait for 1 second before retrying - except Exception as e: - self.logger.info("Failed to send Diameter Response") - self.logger.debug(f"Diameter Response Body: {str(response)}") - self.logger.info(e) - traceback.print_exc() - self.terminate_connection(clientsocket, client_address, thread_event) - self.logger.info("Thread terminated to " + str(client_address)) - break - - def manage_client( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - while True: - try: - data = clientsocket.recv(32) - if not data: - self.logger.info( - f"manage_client: Connection closed by {str(client_address)}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - self.process_Diameter_request( - clientsocket, - client_address, - diameter, - data, - thread_event, - send_queue, - ) - - except socket.timeout: - self.logger.warning( - f"manage_client: Socket timeout for client: {client_address}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except socket.error as e: - self.logger.error( - f"manage_client: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except KeyboardInterrupt: - # Clean up the connection on keyboard interrupt - response = ( - diameter.Request_282() - ) # Generate Disconnect Peer Request Diameter packet - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) # Send it - self.terminate_connection(clientsocket, client_address, thread_event) - self.logger.info( - "manage_client: Connection closed nicely due to keyboard interrupt" - ) - sys.exit() - - except 
Exception as manage_client_exception: - self.logger.error( - f"manage_client: Exception in manage_client: {manage_client_exception}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - def manage_client_async( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - # # Sleep for 10 seconds to wait for the connection to come up - time.sleep(10) - self.logger.debug("manage_client_async: Getting ActivePeerDict") - self.logger.debug( - f"manage_client_async: Total PyHSS Active Threads: {threading.active_count()}" - ) - ActivePeerDict = self.logtool.GetDiameterPeers() - self.logger.debug( - f"manage_client_async: Got Active Peer dict in Async Thread: {str(ActivePeerDict)}" - ) - if client_address[0] in ActivePeerDict: - self.logger.debug( - "manage_client_async: This is host: " - + str(ActivePeerDict[str(client_address[0])]["DiameterHostname"]) - ) - DiameterHostname = str( - ActivePeerDict[str(client_address[0])]["DiameterHostname"] - ) - else: - self.logger.debug("manage_client_async: No matching Diameter Host found.") - return - - while True: - try: - if thread_event.is_set(): - self.logger.debug( - f"manage_client_async: Closing manage_client_async thread for client: {client_address}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - time.sleep(self.yaml_config["hss"]["async_check_interval"]) - self.logger.debug( - f"manage_client_async: Sleep interval expired for Diameter Peer {str(DiameterHostname)}" - ) - if int(self.yaml_config["hss"]["async_check_interval"]) == 0: - self.logger.error( - f"manage_client_async: No async_check_interval Timer set - Not checking Async Queue for host connection {str(DiameterHostname)}" - ) - return - try: - self.logger.debug( - "manage_client_async: Reading from request queue '" - + str(DiameterHostname) - + "_request_queue'" - ) - data_to_send = self.logtool.RedisHMGET( - str(DiameterHostname) + "_request_queue" - ) - for key in data_to_send: - data = data_to_send[key].decode("utf-8") - send_queue.put(bytes.fromhex(data)) - self.logtool.RedisHDEL( - str(DiameterHostname) + "_request_queue", key - ) - except Exception as redis_exception: - self.logger.error( - f"manage_client_async: Redis exception in manage_client_async: {redis_exception}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - - except socket.timeout: - self.logger.warning( - f"manage_client_async: Socket timeout for client: {client_address}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except socket.error as e: - self.logger.error( - f"manage_client_async: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - except Exception: - self.logger.error( - f"manage_client_async: Terminating for host connection {str(DiameterHostname)}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - def manage_client_dwr( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - while True: - try: - if thread_event.is_set(): - self.logger.debug( - f"Closing manage_client_dwr thread for client: {client_address}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - if ( - int(self.yaml_config["hss"]["device_watchdog_request_interval"]) - != 0 - ): - time.sleep( - self.yaml_config["hss"]["device_watchdog_request_interval"] - ) - else: - 
self.logger.info("DWR Timer to set to 0 - Not sending DWRs") - return - - except: - self.logger.error( - "No DWR Timer set - Not sending Device Watchdog Requests" - ) - return - try: - self.logger.debug("Sending Keepalive to " + str(client_address) + "...") - request = diameter.Request_280() - send_queue.put(bytes.fromhex(request)) - # clientsocket.sendall(bytes.fromhex(request)) # Send it - self.logger.debug("Sent Keepalive to " + str(client_address) + "...") - except socket.error as e: - self.logger.error( - f"manage_client_dwr: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - except Exception as e: - self.logger.error( - f"manage_client_dwr: General exception for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - - def get_socket_family(self): - if ":" in self.yaml_config["hss"]["bind_ip"][0]: - self.logger.info("IPv6 Address Specified") - return socket.AF_INET6 - else: - self.logger.info("IPv4 Address Specified") - return socket.AF_INET - - def send_data(self, clientsocket, send_queue, thread_event): - while not thread_event.is_set(): - try: - data = send_queue.get(timeout=1) - # Check if data is bytes, otherwise convert it using bytes.fromhex() - if not isinstance(data, bytes): - data = bytes.fromhex(data) - - clientsocket.sendall(data) - except ( - queue.Empty - ): # Catch the Empty exception when the queue is empty and the timeout has expired - continue - except Exception as e: - self.logger.error(f"send_data_thread: Exception: {e}") - return - - def start_server(self): - if self.yaml_config["hss"]["transport"] == "SCTP": - self.logger.debug("Using SCTP for Transport") - # Create a SCTP socket - sock = sctp.sctpsocket_tcp(self.get_socket_family()) - sock.initparams.num_ostreams = 64 - # Loop through the possible Binding IPs from the config and bind to each for Multihoming - server_addresses = [] - - # Prepend each entry into list, so the primary IP is bound first - for host in self.yaml_config["hss"]["bind_ip"]: - self.logger.info("Seting up SCTP binding on IP address " + str(host)) - this_IP_binding = [ - (str(host), int(self.yaml_config["hss"]["bind_port"])) - ] - server_addresses = this_IP_binding + server_addresses - - print("server_addresses are: " + str(server_addresses)) - sock.bindx(server_addresses) - self.logger.info("PyHSS listening on SCTP port " + str(server_addresses)) - systemd.daemon.notify("READY=1") - # Listen for up to 20 incoming SCTP connections - sock.listen(20) - elif self.yaml_config["hss"]["transport"] == "TCP": - self.logger.debug("Using TCP socket") - # Create a TCP/IP socket - sock = socket.socket(self.get_socket_family(), socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # Bind the socket to the port - server_address = ( - str(self.yaml_config["hss"]["bind_ip"][0]), - int(self.yaml_config["hss"]["bind_port"]), - ) - sock.bind(server_address) - self.logger.debug( - "PyHSS listening on TCP port " - + str(self.yaml_config["hss"]["bind_ip"][0]) - ) - systemd.daemon.notify("READY=1") - # Listen for up to 20 incoming TCP connections - sock.listen(20) - else: - self.logger.error("No valid transports found (No SCTP or TCP) - Exiting") - quit() - - while True: - # Wait for a connection - self.logger.info("Waiting for a connection...") - connection, client_address = sock.accept() - _thread.start_new_thread( - self.handle_new_connection, - ( - connection, - client_address, - ), - ) - - - def 
prom_async_function(self): - while True: - self.logger.debug("Running prom_async_function") - self.diameter_instance.Generate_Prom_Stats() - time.sleep(120) - + pass -if __name__ == "__main__": - pyHss = PyHSS() - pyHss.start_server() + \ No newline at end of file diff --git a/lib/S6a_crypt.py b/lib/S6a_crypt.py index 0a489b3..c1ab38f 100755 --- a/lib/S6a_crypt.py +++ b/lib/S6a_crypt.py @@ -2,17 +2,16 @@ import binascii import base64 import logging -import logtool import os import sys sys.path.append(os.path.realpath('../')) import yaml -with open("config.yaml", 'r') as stream: +with open("../config.yaml", 'r') as stream: yaml_config = (yaml.safe_load(stream)) -logtool = logtool.LogTool() -logtool.setup_logger('CryptoLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) +# logtool = logtool.LogTool() +# logtool.setup_logger('CryptoLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) CryptoLogger = logging.getLogger('CryptoLogger') CryptoLogger.info("Initialised Diameter Logger, importing database") diff --git a/lib/banners.py b/lib/banners.py new file mode 100644 index 0000000..0c3f51b --- /dev/null +++ b/lib/banners.py @@ -0,0 +1,74 @@ +class Banners: + + def diameterService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Diameter Service + +""" + return bannerText + + + def hssService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + HSS Service + +""" + return bannerText + + def georedService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Geographic Redundancy Service + +""" + return bannerText + + def metricService(self) -> str: + bannerText = """ + + ###### ## ## ##### ##### + ## ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## + ###### ## ## ####### ##### ##### + ## ## ## ## ## ## ## + ## ## ## ## ## ## ## ## ## + ## ##### ## ## ##### ##### + ## + #### + + Metric Service + +""" + return bannerText \ No newline at end of file diff --git a/lib/database.py b/lib/database.py index bc8a6aa..e27583d 100755 --- a/lib/database.py +++ b/lib/database.py @@ -17,9 +17,7 @@ import traceback from contextlib import contextmanager import logging -import logtool import pprint -from logtool import * from construct import Default import S6a_crypt import requests @@ -29,16 +27,16 @@ import threading import yaml -with open("config.yaml", 'r') as stream: +with open("../config.yaml", 'r') as stream: yaml_config = (yaml.safe_load(stream)) -logtool = logtool.LogTool() -logtool.setup_logger('DBLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) +# logtool = logtool.LogTool() +# logtool.setup_logger('DBLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) DBLogger = logging.getLogger('DBLogger') DBLogger.info("DB Log Initialised.") db_string = 'mysql://' + str(yaml_config['database']['username']) + ':' + 
str(yaml_config['database']['password']) + '@' + str(yaml_config['database']['server']) + '/' + str(yaml_config['database']['database'] + "?autocommit=true") -print(db_string) +# print(db_string) engine = create_engine( db_string, echo = yaml_config['logging'].get('sqlalchemy_sql_echo', True), @@ -291,39 +289,41 @@ class SUBSCRIBER_ATTRIBUTES(Base): DBLogger.debug("Database already created") def load_IMEI_database_into_Redis(): - try: - DBLogger.info("Reading IMEI TAC database CSV from " + str(yaml_config['eir']['tac_database_csv'])) - csvfile = open(str(yaml_config['eir']['tac_database_csv'])) - DBLogger.info("This may take a few seconds to buffer into Redis...") - except: - DBLogger.error("Failed to read CSV file of IMEI TAC database") - return - try: - count = 0 - for line in csvfile: - line = line.replace('"', '') #Strip excess invered commas - line = line.replace("'", '') #Strip excess invered commas - line = line.rstrip() #Strip newlines - result = line.split(',') - tac_prefix = result[0] - name = result[1].lstrip() - model = result[2].lstrip() - if count == 0: - DBLogger.info("Checking to see if entries are already present...") - #DBLogger.info("Searching Redis for key " + str(tac_prefix) + " to see if data already provisioned") - redis_imei_result = logtool.RedisHMGET(key=str(tac_prefix)) - if len(redis_imei_result) != 0: - DBLogger.info("IMEI TAC Database already loaded into Redis - Skipping reading from file...") - break - else: - DBLogger.info("No data loaded into Redis, proceeding to load...") - imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} - logtool.RedisHMSET(key=str(tac_prefix), value_dict=imei_result) - count = count +1 - DBLogger.info("Loaded " + str(count) + " IMEI TAC entries into Redis") - except Exception as E: - DBLogger.error("Failed to load IMEI Database into Redis due to error: " + (str(E))) - return + return + #@@Fixme + # try: + # DBLogger.info("Reading IMEI TAC database CSV from " + str(yaml_config['eir']['tac_database_csv'])) + # csvfile = open(str(yaml_config['eir']['tac_database_csv'])) + # DBLogger.info("This may take a few seconds to buffer into Redis...") + # except: + # DBLogger.error("Failed to read CSV file of IMEI TAC database") + # return + # try: + # count = 0 + # for line in csvfile: + # line = line.replace('"', '') #Strip excess invered commas + # line = line.replace("'", '') #Strip excess invered commas + # line = line.rstrip() #Strip newlines + # result = line.split(',') + # tac_prefix = result[0] + # name = result[1].lstrip() + # model = result[2].lstrip() + # if count == 0: + # DBLogger.info("Checking to see if entries are already present...") + # #DBLogger.info("Searching Redis for key " + str(tac_prefix) + " to see if data already provisioned") + # redis_imei_result = logtool.RedisHMGET(key=str(tac_prefix)) + # if len(redis_imei_result) != 0: + # DBLogger.info("IMEI TAC Database already loaded into Redis - Skipping reading from file...") + # break + # else: + # DBLogger.info("No data loaded into Redis, proceeding to load...") + # imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} + # logtool.RedisHMSET(key=str(tac_prefix), value_dict=imei_result) + # count = count +1 + # DBLogger.info("Loaded " + str(count) + " IMEI TAC entries into Redis") + # except Exception as E: + # DBLogger.error("Failed to load IMEI Database into Redis due to error: " + (str(E))) + # return #Load IMEI TAC database into Redis if enabled if ('tac_database_csv' in yaml_config['eir']) and (yaml_config['redis']['enabled'] == 
True): @@ -880,71 +880,73 @@ def get_last_operation_log(existingSession=None): def GeoRed_Push_Request(remote_hss, json_data, transaction_id, url=None): headers = {"Content-Type": "application/json", "Transaction-Id": str(transaction_id)} DBLogger.debug("transaction_id: " + str(transaction_id) + " pushing update to " + str(remote_hss).replace('http://', '')) - try: - session = requests.Session() - # Create a Retry object with desired parameters - retries = Retry(total=3, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]) - - # Create an HTTPAdapter and pass the Retry object - adapter = HTTPAdapter(max_retries=retries) - - session.mount('http://', adapter) - if url == None: - endpoint = 'geored' - r = session.patch(str(remote_hss) + '/geored/', data=json.dumps(json_data), headers=headers) - else: - endpoint = url.split('/', 1)[0] - r = session.patch(url, data=json.dumps(json_data), headers=headers) - DBLogger.debug("transaction_id: " + str(transaction_id) + " updated on " + str(remote_hss).replace('http://', '') + " with status code " + str(r.status_code)) - if str(r.status_code).startswith('2'): - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code=str(r.status_code), - error="" - ).inc() - else: - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code=str(r.status_code), - error=str(r.reason) - ).inc() - except ConnectionError as e: - error_message = str(e) - if "Name or service not known" in error_message: - DBLogger.error("transaction_id: " + str(transaction_id) + " name or service not known") - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error="No matching DNS entry found" - ).inc() - else: - print("Other ConnectionError:", error_message) - DBLogger.error("transaction_id: " + str(transaction_id) + " " + str(error_message)) - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error="Connection Refused" - ).inc() - except Timeout: - DBLogger.error("transaction_id: " + str(transaction_id) + " timed out connecting to peer " + str(remote_hss).replace('http://', '')) - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error="Timeout" - ).inc() - except Exception as e: - DBLogger.error("transaction_id: " + str(transaction_id) + " unexpected error " + str(e) + " when connecting to peer " + str(remote_hss).replace('http://', '')) - prom_http_geored.labels( - geored_host=str(remote_hss).replace('http://', ''), - endpoint=endpoint, - http_response_code='000', - error=str(e) - ).inc() + #@@Fixme + # try: + # session = requests.Session() + # # Create a Retry object with desired parameters + # retries = Retry(total=3, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]) + + # # Create an HTTPAdapter and pass the Retry object + # adapter = HTTPAdapter(max_retries=retries) + + # session.mount('http://', adapter) + # if url == None: + # endpoint = 'geored' + # r = session.patch(str(remote_hss) + '/geored/', data=json.dumps(json_data), headers=headers) + # else: + # endpoint = url.split('/', 1)[0] + # r = session.patch(url, data=json.dumps(json_data), headers=headers) + # DBLogger.debug("transaction_id: " + str(transaction_id) + " updated on " + str(remote_hss).replace('http://', '') + " with status code " + 
str(r.status_code)) + # if str(r.status_code).startswith('2'): + # prom_http_geored.labels( + # geored_host=str(remote_hss).replace('http://', ''), + # endpoint=endpoint, + # http_response_code=str(r.status_code), + # error="" + # ).inc() + # else: + # prom_http_geored.labels( + # geored_host=str(remote_hss).replace('http://', ''), + # endpoint=endpoint, + # http_response_code=str(r.status_code), + # error=str(r.reason) + # ).inc() + # except ConnectionError as e: + # error_message = str(e) + # if "Name or service not known" in error_message: + # DBLogger.error("transaction_id: " + str(transaction_id) + " name or service not known") + # prom_http_geored.labels( + # geored_host=str(remote_hss).replace('http://', ''), + # endpoint=endpoint, + # http_response_code='000', + # error="No matching DNS entry found" + # ).inc() + # else: + # print("Other ConnectionError:", error_message) + # DBLogger.error("transaction_id: " + str(transaction_id) + " " + str(error_message)) + # prom_http_geored.labels( + # geored_host=str(remote_hss).replace('http://', ''), + # endpoint=endpoint, + # http_response_code='000', + # error="Connection Refused" + # ).inc() + # except Timeout: + # DBLogger.error("transaction_id: " + str(transaction_id) + " timed out connecting to peer " + str(remote_hss).replace('http://', '')) + # prom_http_geored.labels( + # geored_host=str(remote_hss).replace('http://', ''), + # endpoint=endpoint, + # http_response_code='000', + # error="Timeout" + # ).inc() + # except Exception as e: + # DBLogger.error("transaction_id: " + str(transaction_id) + " unexpected error " + str(e) + " when connecting to peer " + str(remote_hss).replace('http://', '')) + # prom_http_geored.labels( + # geored_host=str(remote_hss).replace('http://', ''), + # endpoint=endpoint, + # http_response_code='000', + # error=str(e) + # ).inc() + return @@ -1951,18 +1953,19 @@ def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): try: device_info = get_device_info_from_TAC(imei=str(imei)) DBLogger.debug("Got Device Info: " + str(device_info)) - prom_eir_devices.labels( - imei_prefix=device_info['tac_prefix'], - device_type=device_info['name'], - device_name=device_info['model'] - ).inc() + #@@Fixme + # prom_eir_devices.labels( + # imei_prefix=device_info['tac_prefix'], + # device_type=device_info['name'], + # device_name=device_info['model'] + # ).inc() except Exception as E: DBLogger.debug("Failed to get device info from TAC") - prom_eir_devices.labels( - imei_prefix=str(imei)[0:8], - device_type='Unknown', - device_name='Unknown' - ).inc() + # prom_eir_devices.labels( + # imei_prefix=str(imei)[0:8], + # device_type='Unknown', + # device_name='Unknown' + # ).inc() else: DBLogger.debug("No TAC database configured, skipping device info lookup") @@ -2104,23 +2107,27 @@ def get_device_info_from_TAC(imei): #Try 8 digit TAC try: DBLogger.debug("Trying to match on 8 Digit IMEI") - imei_result = logtool.RedisHMGET(str(imei[0:8])) - print("Got back: " + str(imei_result)) - imei_result = dict_bytes_to_dict_string(imei_result) - assert(len(imei_result) != 0) - DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) - return imei_result + #@@Fixme + # imei_result = logtool.RedisHMGET(str(imei[0:8])) + # print("Got back: " + str(imei_result)) + # imei_result = dict_bytes_to_dict_string(imei_result) + # assert(len(imei_result) != 0) + # DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) + # return imei_result + return "0" except: 
DBLogger.debug("Failed to match on 8 digit IMEI") try: DBLogger.debug("Trying to match on 6 Digit IMEI") - imei_result = logtool.RedisHMGET(str(imei[0:6])) - print("Got back: " + str(imei_result)) - imei_result = dict_bytes_to_dict_string(imei_result) - assert(len(imei_result) != 0) - DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) - return imei_result + #@@Fixme + # imei_result = logtool.RedisHMGET(str(imei[0:6])) + # print("Got back: " + str(imei_result)) + # imei_result = dict_bytes_to_dict_string(imei_result) + # assert(len(imei_result) != 0) + # DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) + # return imei_result + return "0" except: DBLogger.debug("Failed to match on 6 digit IMEI") diff --git a/lib/diameter.py b/lib/diameter.py index 57adb13..1be7c4c 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -10,29 +10,31 @@ import random import ipaddress import jinja2 - -# with open("config.yaml", 'r') as stream: -# yaml_config = (yaml.safe_load(stream)) +import traceback +import database +import yaml # #Setup Logging # import logtool # from logtool import * # logtool = logtool.LogTool() -# logtool.setup_logger('DiameterLogger', yaml_config['logging']['logfiles']['diameter_logging_file'], level=yaml_config['logging']['level']) -# DiameterLogger = logging.getLogger('DiameterLogger') - -# DiameterLogger.info("Initialised Diameter Logger, importing database") -# import database -# DiameterLogger.info("Imported database") +# logtool.setup_logger('DiameterLogger', self.yaml_config['logging']['logfiles']['diameter_logging_file'], level=self.yaml_config['logging']['level']) -# if yaml_config['redis']['enabled'] == True: -# DiameterLogger.debug("Redis support enabled") -# import redis +class Diameter: + def __init__(self, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999"): + with open("../config.yaml", 'r') as stream: + self.yaml_config = (yaml.safe_load(stream)) -class Diameter: - ##Function Definitions + self.OriginHost = self.string_to_hex(originHost) + self.OriginRealm = self.string_to_hex(originRealm) + self.ProductName = self.string_to_hex(productName) + self.MNC = str(mnc) + self.MCC = str(mcc) + self.diameterLibLogger = logging.getLogger('DiameterLibLogger') + self.diameterLibLogger.info("Initialized Diameter for " + str(self.OriginHost) + " at Realm " + str(self.OriginRealm) + " serving as Product Name " + str(self.ProductName)) + self.diameterLibLogger.info("PLMN is " + str(self.MCC) + "/" + str(self.MNC)) #Generates rounding for calculating padding def myround(self, n, base=4): @@ -56,7 +58,7 @@ def ip_to_hex(self, ip): else: ip_hex = "0002" #IPv6 ip_hex += format(ipaddress.IPv6Address(ip), 'X') - #DiameterLogger.debug("Converted IP to hex - Input: " + str(ip) + " output: " + str(ip_hex)) + #self.diameterLibLogger.debug("Converted IP to hex - Input: " + str(ip) + " output: " + str(ip_hex)) return ip_hex def hex_to_int(self, hex): @@ -106,12 +108,12 @@ def Reverse(self, str): return (slicedString) def DecodePLMN(self, plmn): - DiameterLogger.debug("Decoded PLMN: " + str(plmn)) + self.diameterLibLogger.debug("Decoded PLMN: " + str(plmn)) mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') - DiameterLogger.debug("Decoded MCC: " + mcc) + self.diameterLibLogger.debug("Decoded MCC: " + mcc) mnc = self.Reverse(plmn[4:6]) - DiameterLogger.debug("Decoded MNC: " + mnc) + 
self.diameterLibLogger.debug("Decoded MNC: " + mnc) return mcc, mnc def EncodePLMN(self, mcc, mnc): @@ -126,50 +128,50 @@ def EncodePLMN(self, mcc, mnc): plmn = '' for bits in plmn_list: plmn = plmn + bits - DiameterLogger.debug("Encoded PLMN: " + str(plmn)) + self.diameterLibLogger.debug("Encoded PLMN: " + str(plmn)) return plmn def TBCD_special_chars(self, input): - DiameterLogger.debug("Special character possible in " + str(input)) + self.diameterLibLogger.debug("Special character possible in " + str(input)) if input == "*": - DiameterLogger.debug("Found * - Returning 1010") + self.diameterLibLogger.debug("Found * - Returning 1010") return "1010" elif input == "#": - DiameterLogger.debug("Found # - Returning 1011") + self.diameterLibLogger.debug("Found # - Returning 1011") return "1011" elif input == "a": - DiameterLogger.debug("Found a - Returning 1100") + self.diameterLibLogger.debug("Found a - Returning 1100") return "1100" elif input == "b": - DiameterLogger.debug("Found b - Returning 1101") + self.diameterLibLogger.debug("Found b - Returning 1101") return "1101" elif input == "c": - DiameterLogger.debug("Found c - Returning 1100") + self.diameterLibLogger.debug("Found c - Returning 1100") return "1100" else: binform = "{:04b}".format(int(input)) - DiameterLogger.debug("input " + str(input) + " is not a special char, converted to bin: " + str(binform)) + self.diameterLibLogger.debug("input " + str(input) + " is not a special char, converted to bin: " + str(binform)) return (binform) def TBCD_encode(self, input): - DiameterLogger.debug("TBCD_encode input value is " + str(input)) + self.diameterLibLogger.debug("TBCD_encode input value is " + str(input)) offset = 0 output = '' matches = ['*', '#', 'a', 'b', 'c'] while offset < len(input): if len(input[offset:offset+2]) == 2: - DiameterLogger.debug("processing bits " + str(input[offset:offset+2]) + " at position offset " + str(offset)) + self.diameterLibLogger.debug("processing bits " + str(input[offset:offset+2]) + " at position offset " + str(offset)) bit = input[offset:offset+2] #Get two digits at a time bit = bit[::-1] #Reverse them #Check if *, #, a, b or c if any(x in bit for x in matches): - DiameterLogger.debug("Special char in bit " + str(bit)) + self.diameterLibLogger.debug("Special char in bit " + str(bit)) new_bit = '' new_bit = new_bit + str(self.TBCD_special_chars(bit[0])) new_bit = new_bit + str(self.TBCD_special_chars(bit[1])) - DiameterLogger.debug("Final bin output of new_bit is " + str(new_bit)) + self.diameterLibLogger.debug("Final bin output of new_bit is " + str(new_bit)) bit = hex(int(new_bit, 2))[2:] #Get Hex value - DiameterLogger.debug("Formatted as Hex this is " + str(bit)) + self.diameterLibLogger.debug("Formatted as Hex this is " + str(bit)) output = output + bit offset = offset + 2 else: @@ -177,23 +179,23 @@ def TBCD_encode(self, input): last_digit = str(input[offset:offset+2]) #Check if *, #, a, b or c if any(x in last_digit for x in matches): - DiameterLogger.debug("Special char in bit " + str(bit)) + self.diameterLibLogger.debug("Special char in bit " + str(bit)) new_bit = '' new_bit = new_bit + '1111' #Add the F first #Encode the symbol into binary and append it to the new_bit var new_bit = new_bit + str(self.TBCD_special_chars(last_digit)) - DiameterLogger.debug("Final bin output of new_bit is " + str(new_bit)) + self.diameterLibLogger.debug("Final bin output of new_bit is " + str(new_bit)) bit = hex(int(new_bit, 2))[2:] #Get Hex value - DiameterLogger.debug("Formatted as Hex this is " + str(bit)) + 
self.diameterLibLogger.debug("Formatted as Hex this is " + str(bit)) else: bit = "f" + last_digit offset = offset + 2 output = output + bit - DiameterLogger.debug("TBCD_encode final output value is " + str(output)) + self.diameterLibLogger.debug("TBCD_encode final output value is " + str(output)) return output def TBCD_decode(self, input): - DiameterLogger.debug("TBCD_decode Input value is " + str(input)) + self.diameterLibLogger.debug("TBCD_decode Input value is " + str(input)) offset = 0 output = '' while offset < len(input): @@ -205,21 +207,9 @@ def TBCD_decode(self, input): else: #If f in bit strip it bit = input[offset:offset+2] output = output + bit[1] - DiameterLogger.debug("TBCD_decode output value is " + str(output)) + self.diameterLibLogger.debug("TBCD_decode output value is " + str(output)) return output - #Hexify the vars we got when initializing the class - #@@@Fixme - def __init__(self): - self.OriginHost = self.string_to_hex("OriginHost") - self.OriginRealm = self.string_to_hex("OriginRealm") - self.ProductName = self.string_to_hex("ProductName") - self.MNC = str(505) - self.MCC = str(52) - - DiameterLogger.info("Initialized Diameter for " + str(OriginHost) + " at Realm " + str(OriginRealm) + " serving as Product Name " + str(ProductName)) - DiameterLogger.info("PLMN is " + str(MCC) + "/" + str(MNC)) - #Generates an AVP with inputs provided (AVP Code, AVP Flags, AVP Content, Padding) #AVP content must already be in HEX - This can be done with binascii.hexlify(avp_content.encode()) def generate_avp(self, avp_code, avp_flags, avp_content): @@ -257,8 +247,8 @@ def generate_vendor_avp(self, avp_code, avp_flags, avp_vendorid, avp_content): avp_padding = '' else: #Not multiple of 4 - Padding needed rounded_value = self.myround(avp_length) - DiameterLogger.debug("Rounded value is " + str(rounded_value)) - DiameterLogger.debug("Has " + str( int( rounded_value - avp_length)) + " bytes of padding") + self.diameterLibLogger.debug("Rounded value is " + str(rounded_value)) + self.diameterLibLogger.debug("Has " + str( int( rounded_value - avp_length)) + " bytes of padding") avp_padding = format(0,"x").zfill(int( rounded_value - avp_length) * 2) @@ -283,7 +273,6 @@ def generate_diameter_packet(self, packet_version, packet_flags, packet_command_ return packet_hex def decode_diameter_packet(self, data): - print(data) packet_vars = {} avps = [] @@ -364,7 +353,7 @@ def decode_avp_packet(self, data): logging.debug("AVP length 0 error v2") pass else: - DiameterLogger.debug("failed to decode sub-avp - error: " + str(e)) + self.diameterLibLogger.debug("failed to decode sub-avp - error: " + str(e)) pass @@ -388,6 +377,52 @@ def decode_diameter_packet_length(self, data): else: return False + def generateDiameterResponse(self, requestBinaryData: str) -> str: + packet_vars, avps = self.decode_diameter_packet(requestBinaryData) + origin_host = self.get_avp_data(avps, 264)[0] + origin_host = binascii.unhexlify(origin_host).decode("utf-8") + response = '' + + diameterList = [ + {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 
280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, + {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, + {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, + {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, + {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, + {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, + {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, + {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"}, + {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, + {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 ,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, + {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, + {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, + {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, + {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": 
"LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, + ] + + self.diameterLibLogger.debug(f"Generating a diameter response") + + # Drop packet if it's a response packet: + if packet_vars["flags_bin"][0:1] == "0": + self.diameterLibLogger.debug("Got a Response, not a request - dropping it.") + self.diameterLibLogger.debug(packet_vars) + return + + for diameterApplication in diameterList: + try: + assert(packet_vars["command_code"] == diameterApplication["commandCode"]) + assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) + if 'flags' in diameterApplication: + assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) + response = diameterApplication["responseMethod"](packet_vars, avps) + self.diameterLibLogger.debug(f"[diameter.py] Successfully generated response: {response}") + except Exception as e: + continue + + return response + def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body for avp_dicts in avps: if avp_dicts['avp_code'] == 278: @@ -397,23 +432,23 @@ def AVP_278_Origin_State_Incriment(self, avps): return origin_state_incriment_hex def Charging_Rule_Generator(self, ChargingRules, ue_ip): - DiameterLogger.debug("Called Charging_Rule_Generator") + self.diameterLibLogger.debug("Called Charging_Rule_Generator") #Install Charging Rules - DiameterLogger.info("Naming Charging Rule") + self.diameterLibLogger.info("Naming Charging Rule") Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii')) - DiameterLogger.info("Named Charging Rule") + self.diameterLibLogger.info("Named Charging Rule") #Populate all Flow Information AVPs Flow_Information = '' for tft in ChargingRules['tft']: - DiameterLogger.info(tft) + self.diameterLibLogger.info(tft) #If {{ UE_IP }} in TFT splice in the real UE IP Value try: tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip)) tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip)) - DiameterLogger.info("Spliced in UE IP into TFT: " + str(tft['tft_string'])) + self.diameterLibLogger.info("Spliced in UE IP into TFT: " + str(tft['tft_string'])) except Exception as E: - DiameterLogger.error("Failed to splice in UE IP into flow description") + self.diameterLibLogger.error("Failed to splice in UE IP into flow description") #Valid Values for Flow_Direction: 0- Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4)) @@ -421,90 +456,91 @@ def Charging_Rule_Generator(self, ChargingRules, ue_ip): Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description) Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4)) - DiameterLogger.info("Defined Flow_Status: " + str(Flow_Status)) + self.diameterLibLogger.info("Defined Flow_Status: " + str(Flow_Status)) - DiameterLogger.info("Defining QoS information") + self.diameterLibLogger.info("Defining QoS information") #QCI QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4)) #ARP - DiameterLogger.info("Defining ARP information") + self.diameterLibLogger.info("Defining ARP information") AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, 
self.int_to_hex(int(ChargingRules['arp_preemption_capability']), 4)) AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_vulnerability']), 4)) ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - DiameterLogger.info("Defining MBR information") + self.diameterLibLogger.info("Defining MBR information") #Max Requested Bandwidth Bandwidth_info = '' Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4)) Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4)) - DiameterLogger.info("Defining GBR information") + self.diameterLibLogger.info("Defining GBR information") #GBR if int(ChargingRules['gbr_ul']) != 0: Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4)) if int(ChargingRules['gbr_dl']) != 0: Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4)) - DiameterLogger.info("Defined Bandwith Info: " + str(Bandwidth_info)) + self.diameterLibLogger.info("Defined Bandwidth Info: " + str(Bandwidth_info)) #Populate QoS Information QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info) - DiameterLogger.info("Defined QoS_Information: " + str(QoS_Information)) + self.diameterLibLogger.info("Defined QoS_Information: " + str(QoS_Information)) #Precedence - DiameterLogger.info("Defining Precedence information") + self.diameterLibLogger.info("Defining Precedence information") Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4)) - DiameterLogger.info("Defined Precedence " + str(Precedence)) + self.diameterLibLogger.info("Defined Precedence " + str(Precedence)) #Rating Group - DiameterLogger.info("Defining Rating Group information") + self.diameterLibLogger.info("Defining Rating Group information") if ChargingRules['rating_group'] != None: RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID else: RatingGroup = '' - DiameterLogger.info("Defined Rating Group " + str(ChargingRules['rating_group'])) + self.diameterLibLogger.info("Defined Rating Group " + str(ChargingRules['rating_group'])) #Complete Charging Rule Defintion - DiameterLogger.info("Collating ChargingRuleDef") + self.diameterLibLogger.info("Collating ChargingRuleDef") ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef) #Charging Rule Install - DiameterLogger.info("Collating ChargingRuleDef") + self.diameterLibLogger.info("Collating ChargingRuleDef") return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef) def Get_IMS_Subscriber_Details_from_AVP(self, username): #Feed the Username AVP with Tel URI, SIP URI and either MSISDN or IMSI and this returns user data username = binascii.unhexlify(username).decode('utf-8') - DiameterLogger.info("Username AVP is present, value is " + str(username)) + self.diameterLibLogger.info("Username AVP is present, value is " + str(username)) username = username.split('@')[0] #Strip Domain to get User part username = username[4:] #Strip tel: or sip: prefix #Determine if dealing with IMSI or MSISDN if (len(username) == 15) or (len(username) == 16): - 
DiameterLogger.debug("We have an IMSI: " + str(username)) + self.diameterLibLogger.debug("We have an IMSI: " + str(username)) ims_subscriber_details = database.Get_IMS_Subscriber(imsi=username) else: - DiameterLogger.debug("We have an msisdn: " + str(username)) + self.diameterLibLogger.debug("We have an msisdn: " + str(username)) ims_subscriber_details = database.Get_IMS_Subscriber(msisdn=username) - DiameterLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) + self.diameterLibLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) return ims_subscriber_details def Generate_Prom_Stats(self): - DiameterLogger.debug("Called Generate_Prom_Stats") - try: - prom_ims_subs_value = len(database.Get_Served_IMS_Subscribers(get_local_users_only=True)) - prom_ims_subs.set(prom_ims_subs_value) - prom_mme_subs_value = len(database.Get_Served_Subscribers(get_local_users_only=True)) - prom_mme_subs.set(prom_mme_subs_value) - prom_pcrf_subs_value = len(database.Get_Served_PCRF_Subscribers(get_local_users_only=True)) - prom_pcrf_subs.set(prom_pcrf_subs_value) - except Exception as e: - DiameterLogger.debug("Failed to generate Prometheus Stats for IMS Subscribers") - DiameterLogger.debug(e) - DiameterLogger.debug("Generated Prometheus Stats for IMS Subscribers") + self.diameterLibLogger.debug("Called Generate_Prom_Stats") + #@@ Fixme + # try: + # prom_ims_subs_value = len(database.Get_Served_IMS_Subscribers(get_local_users_only=True)) + # prom_ims_subs.set(prom_ims_subs_value) + # prom_mme_subs_value = len(database.Get_Served_Subscribers(get_local_users_only=True)) + # prom_mme_subs.set(prom_mme_subs_value) + # prom_pcrf_subs_value = len(database.Get_Served_PCRF_Subscribers(get_local_users_only=True)) + # prom_pcrf_subs.set(prom_pcrf_subs_value) + # except Exception as e: + # self.diameterLibLogger.debug("Failed to generate Prometheus Stats for IMS Subscribers") + # self.diameterLibLogger.debug(e) + # self.diameterLibLogger.debug("Generated Prometheus Stats for IMS Subscribers") return @@ -512,7 +548,7 @@ def Generate_Prom_Stats(self): #### Diameter Answers #### #Capabilities Exchange Answer - def Answer_257(self, packet_vars, avps, recv_ip): + def Answer_257(self, packet_vars, avps): avp = '' #Initiate empty var AVP avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -520,7 +556,7 @@ def Answer_257(self, packet_vars, avps, recv_ip): for avps_to_check in avps: #Only include AVP 278 (Origin State) if inital request included it if avps_to_check['avp_code'] == 278: avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incrimented (Handled by AVP_278_Origin_State_Incriment)) - for host in yaml_config['hss']['bind_ip']: #Loop through all IPs from Config and add to response + for host in self.yaml_config['hss']['bind_ip']: #Loop through all IPs from Config and add to response avp += self.generate_avp(257, 40, self.ip_to_hex(host)) #Host-IP-Address (For this to work on Linux this is the IP defined in the hostsfile for localhost) avp += self.generate_avp(266, 40, "00000000") #Vendor-Id avp += self.generate_avp(269, "00", self.ProductName) #Product-Name @@ -547,7 +583,7 @@ def Answer_257(self, packet_vars, avps, recv_ip): avp += self.generate_avp(265, 40, format(int(13019),"x").zfill(8)) #Supported-Vendor-ID 13019 (ETSI) response = self.generate_diameter_packet("01", "00", 257, 0, packet_vars['hop-by-hop-identifier'], 
packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated CEA") + self.diameterLibLogger.debug("Successfully Generated CEA") return response #Device Watchdog Answer @@ -561,7 +597,7 @@ def Answer_280(self, packet_vars, avps): if avps_to_check['avp_code'] == 278: avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incrimented (Handled by AVP_278_Origin_State_Incriment)) response = self.generate_diameter_packet("01", "00", 280, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated DWA") + self.diameterLibLogger.debug("Successfully Generated DWA") orignHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP orignHost = binascii.unhexlify(orignHost).decode('utf-8') #Format it return response @@ -573,7 +609,7 @@ def Answer_282(self, packet_vars, avps): avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(268, 40, "000007d1") #Result Code (DIAMETER_SUCCESS (2001)) response = self.generate_diameter_packet("01", "00", 282, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated DPA") + self.diameterLibLogger.debug("Successfully Generated DPA") return response #3GPP S6a/S6d Update Location Answer @@ -605,20 +641,20 @@ def Answer_16777251_316(self, packet_vars, avps): imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI try: subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details - DiameterLogger.debug("Got back subscriber_details: " + str(subscriber_details)) + self.diameterLibLogger.debug("Got back subscriber_details: " + str(subscriber_details)) except ValueError as e: - DiameterLogger.error("failed to get data backfrom database for imsi " + str(imsi)) - DiameterLogger.error("Error is " + str(e)) - DiameterLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN") + self.diameterLibLogger.error("failed to get data back from database for imsi " + str(imsi)) + self.diameterLibLogger.error("Error is " + str(e)) + self.diameterLibLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN") avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN") + self.diameterLibLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN") return response except Exception as ex: template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - DiameterLogger.critical(message) - DiameterLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) + self.diameterLibLogger.critical(message) + self.diameterLibLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise #Store MME Location into Database @@ -626,7 +662,7 @@ def Answer_16777251_316(self, packet_vars, avps): OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it - DiameterLogger.debug("Subscriber is served by MME " + str(OriginHost) + " at realm " + str(OriginRealm)) + self.diameterLibLogger.debug("Subscriber is served by MME " + str(OriginHost) + " at realm " + str(OriginRealm)) #Find Remote Peer we need to address CLRs through try: #Check if we have a record-route set as that's where we'll need to send the response @@ -634,8 +670,8 @@ def Answer_16777251_316(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) + remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) + self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm) @@ -671,34 +707,34 @@ def Answer_16777251_316(self, packet_vars, avps): #Split the APN list into a list apn_list = subscriber_details['apn_list'].split(',') - DiameterLogger.debug("Current APN List: " + str(apn_list)) + self.diameterLibLogger.debug("Current APN List: " + str(apn_list)) #Remove the default APN from the list try: apn_list.remove(str(subscriber_details['default_apn'])) except: - DiameterLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") + self.diameterLibLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") pass #Add default APN in first position apn_list.insert(0, str(subscriber_details['default_apn'])) - DiameterLogger.debug("APN list: " + str(apn_list)) + self.diameterLibLogger.debug("APN list: " + str(apn_list)) APN_context_identifer_count = 1 for apn_id in apn_list: #Per APN Setup - DiameterLogger.debug("Processing APN ID " + str(apn_id)) + self.diameterLibLogger.debug("Processing APN ID " + str(apn_id)) try: apn_data = database.Get_APN(apn_id) except: - DiameterLogger.error("Failed to get APN " + str(apn_id)) + self.diameterLibLogger.error("Failed to get APN " + str(apn_id)) continue APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_data['apn']))) - DiameterLogger.debug("Setting APN Configuration Profile") + self.diameterLibLogger.debug("Setting APN Configuration Profile") #Sub AVPs of APN Configuration Profile APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4)) APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(int(apn_data['ip_version']), 4)) - DiameterLogger.debug("Setting APN AMBR") + self.diameterLibLogger.debug("Setting APN AMBR") #AMBR AMBR 
= '' #Initiate empty var AVP for AMBR apn_ambr_ul = int(apn_data['apn_ambr_ul']) @@ -707,7 +743,7 @@ def Answer_16777251_316(self, packet_vars, avps): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - DiameterLogger.debug("Setting APN Allocation-Retention-Priority") + self.diameterLibLogger.debug("Setting APN Allocation-Retention-Priority") #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) @@ -719,24 +755,24 @@ def Answer_16777251_316(self, packet_vars, avps): #Try static IP allocation try: subscriber_routing_dict = database.Get_SUBSCRIBER_ROUTING(subscriber_id=subscriber_details['subscriber_id'], apn_id=apn_id) #Get subscriber details - DiameterLogger.info("Got static UE IP " + str(subscriber_routing_dict)) - DiameterLogger.debug("Found static IP for UE " + str(subscriber_routing_dict['ip_address'])) + self.diameterLibLogger.info("Got static UE IP " + str(subscriber_routing_dict)) + self.diameterLibLogger.debug("Found static IP for UE " + str(subscriber_routing_dict['ip_address'])) Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(subscriber_routing_dict['ip_address'])) except Exception as E: - DiameterLogger.debug("Error getting static UE IP: " + str(E)) + self.diameterLibLogger.debug("Error getting static UE IP: " + str(E)) Served_Party_Address = "" #if 'PDN_GW_Allocation_Type' in apn_profile: - # DiameterLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type'])) + # self.diameterLibLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type'])) # PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4)) - # DiameterLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type)) + # self.diameterLibLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type)) # else: # PDN_GW_Allocation_Type = '' # if 'VPLMN_Dynamic_Address_Allowed' in apn_profile: - # DiameterLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed'])) + # self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed'])) # VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4)) - # DiameterLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed)) + # self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed)) # else: # VPLMN_Dynamic_Address_Allowed = '' PDN_GW_Allocation_Type = '' @@ -744,7 +780,7 @@ def Answer_16777251_316(self, packet_vars, avps): #If static SMF / PGW-C defined if apn_data['pgw_address'] is not None: - DiameterLogger.info("MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address'])) + self.diameterLibLogger.info("MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address'])) MIP_Home_Agent_Address = self.generate_avp(334, '40', self.ip_to_hex(apn_data['pgw_address'])) MIP6_Agent_Info = 
self.generate_avp(486, '40', MIP_Home_Agent_Address) else: @@ -757,40 +793,40 @@ def Answer_16777251_316(self, packet_vars, avps): #Incriment Context Identifier Count to keep track of how many APN Profiles returned APN_context_identifer_count = APN_context_identifer_count + 1 - DiameterLogger.debug("Completed processing APN ID " + str(apn_id)) + self.diameterLibLogger.debug("Completed processing APN ID " + str(apn_id)) subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_Configuration_Profile + APN_Configuration) try: - DiameterLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA") + self.diameterLibLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA") msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(subscriber_details['msisdn']))) #MSISDN - DiameterLogger.debug(msisdn_avp) + self.diameterLibLogger.debug(msisdn_avp) subscription_data += msisdn_avp except Exception as E: - DiameterLogger.error("Failed to populate MSISDN in ULA due to error " + str(E)) + self.diameterLibLogger.error("Failed to populate MSISDN in ULA due to error " + str(E)) if 'RAT_freq_priorityID' in subscriber_details: - DiameterLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") + self.diameterLibLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID - DiameterLogger.debug("Adding rat_freq_priorityID: " + str(rat_freq_priorityID)) + self.diameterLibLogger.debug("Adding rat_freq_priorityID: " + str(rat_freq_priorityID)) subscription_data += rat_freq_priorityID if 'charging_characteristics' in subscriber_details: - DiameterLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['charging_characteristics']) + " - Adding in ULA") + self.diameterLibLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['charging_characteristics']) + " - Adding in ULA") _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, str(subscriber_details['charging_characteristics'])) subscription_data += _3gpp_charging_characteristics - DiameterLogger.debug("Adding _3gpp_charging_characteristics: " + str(_3gpp_charging_characteristics)) + self.diameterLibLogger.debug("Adding _3gpp_charging_characteristics: " + str(_3gpp_charging_characteristics)) #ToDo - Fix this # if 'APN_OI_replacement' in subscriber_details: - # DiameterLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") + # self.diameterLibLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") # subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement']))) avp += self.generate_vendor_avp(1400, "c0", 10415, subscription_data) #Subscription-Data response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated ULA") + self.diameterLibLogger.debug("Successfully Generated ULA") return response #3GPP S6a/S6d Authentication Information Answer @@ -802,17 +838,18 @@ def Answer_16777251_318(self, packet_vars, avps): try: subscriber_details = 
database.Get_Subscriber(imsi=imsi) #Get subscriber details except ValueError as e: - DiameterLogger.info("Minor getting subscriber details for IMSI " + str(imsi)) - DiameterLogger.info(e) + self.diameterLibLogger.info("Minor error getting subscriber details for IMSI " + str(imsi)) + self.diameterLibLogger.info(e) #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - prom_diam_auth_event_count.labels( - diameter_application_id = 16777251, - diameter_cmd_code = 318, - event='Unknown User', - imsi_prefix = str(imsi[0:6]), - ).inc() - - DiameterLogger.info("Subscriber " + str(imsi) + " is unknown in database") + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777251, + # diameter_cmd_code = 318, + # event='Unknown User', + # imsi_prefix = str(imsi[0:6]), + # ).inc() + + self.diameterLibLogger.info("Subscriber " + str(imsi) + " is unknown in database") avp = '' session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set @@ -832,8 +869,8 @@ def Answer_16777251_318(self, packet_vars, avps): except Exception as ex: template = "An exception of type {0} occurred. Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - DiameterLogger.critical(message) - DiameterLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) + self.diameterLibLogger.critical(message) + self.diameterLibLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise @@ -841,19 +878,20 @@ def Answer_16777251_318(self, packet_vars, avps): requested_vectors = 1 for avp in avps: if avp['avp_code'] == 1408: - DiameterLogger.debug("AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP") + self.diameterLibLogger.debug("AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP") EUTRAN_Authentication_Info = avp['misc_data'] - DiameterLogger.debug("EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info)) + self.diameterLibLogger.debug("EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info)) for sub_avp in EUTRAN_Authentication_Info: #If resync request if sub_avp['avp_code'] == 1411: - DiameterLogger.debug("Re-Synchronization required - SQN is out of sync") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777251, - diameter_cmd_code = 318, - event='Resync', - imsi_prefix = str(imsi[0:6]), - ).inc() + self.diameterLibLogger.debug("Re-Synchronization required - SQN is out of sync") + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777251, + # diameter_cmd_code = 318, + # event='Resync', + # imsi_prefix = str(imsi[0:6]), + # ).inc() auts = str(sub_avp['misc_data'])[32:] rand = str(sub_avp['misc_data'])[:32] rand = binascii.unhexlify(rand) @@ -862,16 +900,16 @@ def Answer_16777251_318(self, packet_vars, avps): #Get number of requested vectors if sub_avp['avp_code'] == 1410: - DiameterLogger.debug("Raw value of requested vectors is " + str(sub_avp['misc_data'])) + self.diameterLibLogger.debug("Raw value of requested vectors is " + str(sub_avp['misc_data'])) requested_vectors = int(sub_avp['misc_data'], 16) if requested_vectors >= 32: - DiameterLogger.info("Client has requested " + str(requested_vectors) + " vectors, limiting this to 32") + self.diameterLibLogger.info("Client has requested " + str(requested_vectors) + " vectors, limiting this to 32") requested_vectors = 32 - DiameterLogger.debug("Generating " + 
str(requested_vectors) + " vectors as requested") + self.diameterLibLogger.debug("Generating " + str(requested_vectors) + " vectors as requested") eutranvector_complete = '' while requested_vectors != 0: - DiameterLogger.debug("Generating vector number " + str(requested_vectors)) + self.diameterLibLogger.debug("Generating vector number " + str(requested_vectors)) plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from request vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn) eutranvector = '' #This goes into the payload of AVP 10415 (Authentication info) @@ -896,8 +934,8 @@ def Answer_16777251_318(self, packet_vars, avps): #avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated AIA") - DiameterLogger.debug(response) + self.diameterLibLogger.debug("Successfully Generated AIA") + self.diameterLibLogger.debug(response) return response #Purge UE Answer (PUA) @@ -931,7 +969,7 @@ def Answer_16777251_321(self, packet_vars, avps): database.Update_Serving_MME(imsi, None) - DiameterLogger.debug("Successfully Generated PUA") + self.diameterLibLogger.debug("Successfully Generated PUA") return response #Notify Answer (NOA) @@ -952,7 +990,7 @@ def Answer_16777251_323(self, packet_vars, avps): SupportedFeatures += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8)) #Auth-Application-ID Relay avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP response = self.generate_diameter_packet("01", "40", 323, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.debug("Successfully Generated PUA") + self.diameterLibLogger.debug("Successfully Generated PUA") return response #3GPP Gx Credit Control Answer @@ -960,9 +998,9 @@ def Answer_16777238_272(self, packet_vars, avps): CC_Request_Type = self.get_avp_data(avps, 416)[0] CC_Request_Number = self.get_avp_data(avps, 415)[0] #Called Station ID - DiameterLogger.debug("Attempting to find APN in CCR") + self.diameterLibLogger.debug("Attempting to find APN in CCR") apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8') - DiameterLogger.debug("CCR for APN " + str(apn)) + self.diameterLibLogger.debug("CCR for APN " + str(apn)) OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it @@ -975,8 +1013,8 @@ def Answer_16777238_272(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) + self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) + remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -989,36 +1027,36 @@ def Answer_16777238_272(self, packet_vars, avps): #Get Subscriber info from Subscription ID for SubscriptionIdentifier in self.get_avp_data(avps, 
443): for UniqueSubscriptionIdentifier in SubscriptionIdentifier: - DiameterLogger.debug("Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI") + self.diameterLibLogger.debug("Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI") if UniqueSubscriptionIdentifier['avp_code'] == 444: imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') - DiameterLogger.debug("Found IMSI " + str(imsi)) + self.diameterLibLogger.debug("Found IMSI " + str(imsi)) - DiameterLogger.info("SubscriptionID: " + str(self.get_avp_data(avps, 443))) + self.diameterLibLogger.info("SubscriptionID: " + str(self.get_avp_data(avps, 443))) try: - DiameterLogger.info("Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database") #Get subscriber details + self.diameterLibLogger.info("Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database") #Get subscriber details ChargingRules = database.Get_Charging_Rules(imsi=imsi, apn=apn) - DiameterLogger.info("Got Charging Rules: " + str(ChargingRules)) + self.diameterLibLogger.info("Got Charging Rules: " + str(ChargingRules)) except Exception as E: #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - DiameterLogger.debug(E) - DiameterLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists") + self.diameterLibLogger.debug(E) + self.diameterLibLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists") if int(CC_Request_Type) == 1: - DiameterLogger.info("Request type for CCA is 1 - Initial") + self.diameterLibLogger.info("Request type for CCA is 1 - Initial") #Get UE IP try: ue_ip = self.get_avp_data(avps, 8)[0] ue_ip = str(self.hex_to_ip(ue_ip)) except Exception as E: - DiameterLogger.error("Failed to get UE IP") - DiameterLogger.error(E) + self.diameterLibLogger.error("Failed to get UE IP") + self.diameterLibLogger.error(E) ue_ip = 'Failed to Decode / Get UE IP' #Store PGW location into Database - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) + remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) #Supported-Features(628) (Gx feature list) @@ -1027,7 +1065,7 @@ def Answer_16777238_272(self, packet_vars, avps): #Default EPS Beaerer QoS (From database with fallback source CCR-I) try: apn_data = ChargingRules['apn_data'] - DiameterLogger.debug("Setting APN AMBR") + self.diameterLibLogger.debug("Setting APN AMBR") #AMBR AMBR = '' #Initiate empty var AVP for AMBR apn_ambr_ul = int(apn_data['apn_ambr_ul']) @@ -1036,7 +1074,7 @@ def Answer_16777238_272(self, packet_vars, avps): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - DiameterLogger.debug("Setting APN Allocation-Retention-Priority") + self.diameterLibLogger.debug("Setting APN Allocation-Retention-Priority") #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) 
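For reference, every QoS and ARP attribute assembled in these handlers goes through generate_vendor_avp(); below is a minimal standalone sketch of the vendor-specific AVP framing it emits, assuming the standard RFC 6733 layout (build_vendor_avp and its parameter names are illustrative only, not part of the patch):

def build_vendor_avp(avp_code: int, avp_flags: int, vendor_id: int, payload_hex: str) -> str:
    # 4-byte AVP code, 1-byte flags, 3-byte length, 4-byte Vendor-Id, then padded payload
    payload = bytes.fromhex(payload_hex)
    avp_length = 12 + len(payload)          # header is 12 bytes when the V (vendor) bit is set
    padding = (-avp_length) % 4             # AVPs are padded out to a 32-bit boundary
    avp = avp_code.to_bytes(4, 'big')
    avp += avp_flags.to_bytes(1, 'big')
    avp += avp_length.to_bytes(3, 'big')    # the length field excludes the padding itself
    avp += vendor_id.to_bytes(4, 'big')
    avp += payload + b'\x00' * padding
    return avp.hex()

# Priority-Level (1046) for 3GPP (vendor 10415) with value 4 as a 4-byte integer:
print(build_vendor_avp(1046, 0x80, 10415, format(4, 'x').zfill(8)))
# -> 0000041680000010000028af00000004

Because the length field counts header plus payload but never the padding, generate_vendor_avp() rounds with myround() before appending the pad bytes, as seen earlier in this diff.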
AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) @@ -1045,13 +1083,13 @@ def Answer_16777238_272(self, packet_vars, avps): AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) except Exception as E: - DiameterLogger.error(E) - DiameterLogger.error("Failed to populate default_EPS_QoS from DB for sub " + str(imsi)) + self.diameterLibLogger.error(E) + self.diameterLibLogger.error("Failed to populate default_EPS_QoS from DB for sub " + str(imsi)) default_EPS_QoS = self.get_avp_data(avps, 1049)[0][8:] avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - DiameterLogger.info("Creating QoS Information") + self.diameterLibLogger.info("Creating QoS Information") #QoS-Information try: apn_data = ChargingRules['apn_data'] @@ -1059,38 +1097,38 @@ def Answer_16777238_272(self, packet_vars, avps): apn_ambr_dl = int(apn_data['apn_ambr_dl']) QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - DiameterLogger.info("Created both QoS AVPs from data from Database") - DiameterLogger.info("Populated QoS_Information") + self.diameterLibLogger.info("Created both QoS AVPs from data from Database") + self.diameterLibLogger.info("Populated QoS_Information") avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) except Exception as E: - DiameterLogger.error("Failed to get QoS information dynamically for sub " + str(imsi)) - DiameterLogger.error(E) + self.diameterLibLogger.error("Failed to get QoS information dynamically for sub " + str(imsi)) + self.diameterLibLogger.error(E) QoS_Information = '' for AMBR_Part in self.get_avp_data(avps, 1016)[0]: - DiameterLogger.debug(AMBR_Part) + self.diameterLibLogger.debug(AMBR_Part) AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:]) QoS_Information += AMBR_AVP - DiameterLogger.debug("QoS_Information added " + str(AMBR_AVP)) + self.diameterLibLogger.debug("QoS_Information added " + str(AMBR_AVP)) avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) - DiameterLogger.debug("QoS information set statically") + self.diameterLibLogger.debug("QoS information set statically") - DiameterLogger.info("Added to AVP List") - DiameterLogger.debug("QoS Information: " + str(QoS_Information)) + self.diameterLibLogger.info("Added to AVP List") + self.diameterLibLogger.debug("QoS Information: " + str(QoS_Information)) #If database returned an existing ChargingRule defintion add ChargingRule to CCA-I if ChargingRules and ChargingRules['charging_rules'] is not None: try: - DiameterLogger.debug(ChargingRules) + self.diameterLibLogger.debug(ChargingRules) for individual_charging_rule in ChargingRules['charging_rules']: - DiameterLogger.debug("Processing Charging Rule: " + str(individual_charging_rule)) + self.diameterLibLogger.debug("Processing Charging Rule: " + str(individual_charging_rule)) avp += self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) except Exception as E: - DiameterLogger.debug("Error in populating dynamic charging rules: " + str(E)) + self.diameterLibLogger.debug("Error in populating dynamic charging rules: " + str(E)) elif int(CC_Request_Type) == 3: - DiameterLogger.info("Request type for CCA is 3 - Termination") + 
self.diameterLibLogger.info("Request type for CCA is 3 - Termination") database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None) avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -1121,26 +1159,27 @@ def Answer_16777216_300(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) + self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) try: - DiameterLogger.info("Checking if username present") + self.diameterLibLogger.info("Checking if username present") username = self.get_avp_data(avps, 1)[0] username = binascii.unhexlify(username).decode('utf-8') - DiameterLogger.info("Username AVP is present, value is " + str(username)) + self.diameterLibLogger.info("Username AVP is present, value is " + str(username)) imsi = username.split('@')[0] #Strip Domain domain = username.split('@')[1] #Get Domain Part - DiameterLogger.debug("Extracted imsi: " + str(imsi) + " now checking backend for this IMSI") + self.diameterLibLogger.debug("Extracted imsi: " + str(imsi) + " now checking backend for this IMSI") ims_subscriber_details = database.Get_IMS_Subscriber(imsi=imsi) except Exception as E: - DiameterLogger.error("Threw Exception: " + str(E)) - DiameterLogger.error("No known MSISDN or IMSI in Answer_16777216_300() input") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 300, - event='Unknown User', - imsi_prefix = str(imsi[0:6]), - ).inc() + self.diameterLibLogger.error("Threw Exception: " + str(E)) + self.diameterLibLogger.error("No known MSISDN or IMSI in Answer_16777216_300() input") + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777216, + # diameter_cmd_code = 300, + # event='Unknown User', + # imsi_prefix = str(imsi[0:6]), + # ).inc() result_code = 5001 #IMS User Unknown #Experimental Result AVP avp_experimental_result = '' @@ -1155,9 +1194,9 @@ def Answer_16777216_300(self, packet_vars, avps): if user_authorization_type_avp_data: try: User_Authorization_Type = int(user_authorization_type_avp_data[0]) - DiameterLogger.debug("User_Authorization_Type is: " + str(User_Authorization_Type)) + self.diameterLibLogger.debug("User_Authorization_Type is: " + str(User_Authorization_Type)) if (User_Authorization_Type == 1): - DiameterLogger.debug("This is Deregister") + self.diameterLibLogger.debug("This is Deregister") database.Update_Serving_CSCF(imsi, serving_cscf=None) #Populate S-CSCF Address avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii')) @@ -1166,28 +1205,28 @@ def Answer_16777216_300(self, packet_vars, avps): return response except Exception as E: - DiameterLogger.debug("Failed to get User_Authorization_Type AVP & Update_Serving_CSCF error: " + str(E)) - DiameterLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) + self.diameterLibLogger.debug("Failed to get User_Authorization_Type AVP & Update_Serving_CSCF error: " + str(E)) + self.diameterLibLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) if ims_subscriber_details['scscf'] != None: - DiameterLogger.debug("Already has SCSCF Assigned from DB: " + str(ims_subscriber_details['scscf'])) + self.diameterLibLogger.debug("Already 
has SCSCF Assigned from DB: " + str(ims_subscriber_details['scscf'])) avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii')) experimental_avp = '' experimental_avp += experimental_avp + self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID experimental_avp = experimental_avp + self.generate_avp(298, 40, format(int(2002),"x").zfill(8)) #DIAMETER_SUBSEQUENT_REGISTRATION (2002) avp += self.generate_avp(297, 40, experimental_avp) #Expermental-Result else: - DiameterLogger.debug("No SCSCF Assigned from DB") - if 'scscf_pool' in yaml_config['hss']: + self.diameterLibLogger.debug("No SCSCF Assigned from DB") + if 'scscf_pool' in self.yaml_config['hss']: try: - scscf = random.choice(yaml_config['hss']['scscf_pool']) - DiameterLogger.debug("Randomly picked SCSCF address " + str(scscf) + " from pool") + scscf = random.choice(self.yaml_config['hss']['scscf_pool']) + self.diameterLibLogger.debug("Randomly picked SCSCF address " + str(scscf) + " from pool") avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated S-CSCF Address as failed to source from list due to " + str(E)) + self.diameterLibLogger.info("Using generated S-CSCF Address as failed to source from list due to " + str(E)) else: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated S-CSCF Address as none set in scscf_pool in config") + self.diameterLibLogger.info("Using generated S-CSCF Address as none set in scscf_pool in config") experimental_avp = '' experimental_avp += experimental_avp + self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID experimental_avp = experimental_avp + self.generate_avp(298, 40, format(int(2001),"x").zfill(8)) #DIAMETER_FIRST_REGISTRATION (2001) @@ -1220,18 +1259,18 @@ def Answer_16777216_301(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - DiameterLogger.debug("Remote Peer is " + str(remote_peer)) + self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) try: - DiameterLogger.info("Checking if username present") + self.diameterLibLogger.info("Checking if username present") username = self.get_avp_data(avps, 601)[0] ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) - DiameterLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) + self.diameterLibLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) imsi = ims_subscriber_details['imsi'] domain = "ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org" except Exception as E: - DiameterLogger.error("Threw Exception: " + str(E)) - DiameterLogger.error("No known MSISDN or IMSI in Answer_16777216_301() input") + self.diameterLibLogger.error("Threw Exception: " + str(E)) + self.diameterLibLogger.error("No known MSISDN or IMSI in Answer_16777216_301() input") result_code = 5005 #Experimental Result AVP 
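The Experimental-Result grouping used throughout these Cx answers nests a Vendor-Id AVP (266) and an Experimental-Result-Code AVP (298) inside grouped AVP 297. A rough standalone illustration of that assembly using the class's own generate_avp() helper (diameter is assumed to be an instance of this Diameter class, and build_experimental_result is an illustrative name):

def build_experimental_result(diameter, result_code: int) -> str:
    # Vendor-Id (266) carrying 10415 (3GPP), then the Experimental-Result-Code (298)
    vendor_id = diameter.generate_avp(266, 40, format(10415, 'x').zfill(8))
    exp_result_code = diameter.generate_avp(298, 40, format(result_code, 'x').zfill(8))
    # Both sub-AVPs are concatenated and wrapped in the grouped Experimental-Result AVP (297)
    return diameter.generate_avp(297, 40, vendor_id + exp_result_code)

# e.g. build_experimental_result(diameter, 5001) for DIAMETER_ERROR_USER_UNKNOWN,
# or 5005 (missing/invalid user data) as set just below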
avp_experimental_result = '' @@ -1247,7 +1286,7 @@ def Answer_16777216_301(self, packet_vars, avps): #This loads a Jinja XML template as the default iFC templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) - DiameterLogger.debug("Loading iFC from path " + str(ims_subscriber_details['ifc_path'])) + self.diameterLibLogger.debug("Loading iFC from path " + str(ims_subscriber_details['ifc_path'])) template = templateEnv.get_template(ims_subscriber_details['ifc_path']) #These variables are passed to the template for use @@ -1264,16 +1303,16 @@ def Answer_16777216_301(self, packet_vars, avps): #Determine SAR Type & Store Server_Assignment_Type_Hex = self.get_avp_data(avps, 614)[0] Server_Assignment_Type = self.hex_to_int(Server_Assignment_Type_Hex) - DiameterLogger.debug("Server-Assignment-Type is: " + str(Server_Assignment_Type)) + self.diameterLibLogger.debug("Server-Assignment-Type is: " + str(Server_Assignment_Type)) ServingCSCF = self.get_avp_data(avps, 602)[0] #Get OriginHost from AVP ServingCSCF = binascii.unhexlify(ServingCSCF).decode('utf-8') #Format it - DiameterLogger.debug("Subscriber is served by S-CSCF " + str(ServingCSCF)) + self.diameterLibLogger.debug("Subscriber is served by S-CSCF " + str(ServingCSCF)) if (Server_Assignment_Type == 1) or (Server_Assignment_Type == 2): - DiameterLogger.debug("SAR is Register / Re-Restister") - remote_peer = remote_peer + ";" + str(yaml_config['hss']['OriginHost']) + self.diameterLibLogger.debug("SAR is Register / Re-Register") + remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer) else: - DiameterLogger.debug("SAR is not Register") + self.diameterLibLogger.debug("SAR is not Register") database.Update_Serving_CSCF(imsi, serving_cscf=None) avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) @@ -1294,36 +1333,37 @@ def Answer_16777216_302(self, packet_vars, avps): try: - DiameterLogger.info("Checking if username present") + self.diameterLibLogger.info("Checking if username present") username = self.get_avp_data(avps, 601)[0] ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) if ims_subscriber_details['scscf'] != None: - DiameterLogger.debug("Got SCSCF on record for Sub") + self.diameterLibLogger.debug("Got SCSCF on record for Sub") #Strip double sip prefix avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(str(ims_subscriber_details['scscf']))),'ascii')) else: - DiameterLogger.debug("No SCSF assigned - Using SCSCF Pool") - if 'scscf_pool' in yaml_config['hss']: + self.diameterLibLogger.debug("No SCSCF assigned - Using SCSCF Pool") + if 'scscf_pool' in self.yaml_config['hss']: try: - scscf = random.choice(yaml_config['hss']['scscf_pool']) - DiameterLogger.debug("Randomly picked SCSCF address " + str(scscf) + " from pool") + scscf = random.choice(self.yaml_config['hss']['scscf_pool']) + self.diameterLibLogger.debug("Randomly picked SCSCF address " + str(scscf) + " from pool") avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated iFC as failed to source from list 
due to " + str(E)) + self.diameterLibLogger.info("Using generated iFC as failed to source from list due to " + str(E)) else: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - DiameterLogger.info("Using generated iFC") + self.diameterLibLogger.info("Using generated iFC") except Exception as E: - DiameterLogger.error("Threw Exception: " + str(E)) - DiameterLogger.error("No known MSISDN or IMSI in Answer_16777216_302() input") + self.diameterLibLogger.error("Threw Exception: " + str(E)) + self.diameterLibLogger.error("No known MSISDN or IMSI in Answer_16777216_302() input") result_code = 5001 - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 302, - event='Unknown User', - imsi_prefix = str(username[0:6]), - ).inc() + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777216, + # diameter_cmd_code = 302, + # event='Unknown User', + # imsi_prefix = str(username[0:6]), + # ).inc() #Experimental Result AVP avp_experimental_result = '' avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID @@ -1341,12 +1381,12 @@ def Answer_16777216_302(self, packet_vars, avps): def Answer_16777216_303(self, packet_vars, avps): public_identity = self.get_avp_data(avps, 601)[0] public_identity = binascii.unhexlify(public_identity).decode('utf-8') - DiameterLogger.debug("Got MAR for public_identity : " + str(public_identity)) + self.diameterLibLogger.debug("Got MAR for public_identity : " + str(public_identity)) username = self.get_avp_data(avps, 1)[0] username = binascii.unhexlify(username).decode('utf-8') imsi = username.split('@')[0] #Strip Domain domain = username.split('@')[1] #Get Domain Part - DiameterLogger.debug("Got MAR username: " + str(username)) + self.diameterLibLogger.debug("Got MAR username: " + str(username)) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1360,13 +1400,14 @@ def Answer_16777216_303(self, packet_vars, avps): subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details except: #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - DiameterLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for MAA") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 303, - event='Unknown User', - imsi_prefix = str(username[0:6]), - ).inc() + self.diameterLibLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for MAA") + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777216, + # diameter_cmd_code = 303, + # event='Unknown User', + # imsi_prefix = str(username[0:6]), + # ).inc() experimental_result = self.generate_avp(298, 40, self.int_to_hex(5001, 4)) #Result Code (DIAMETER ERROR - User Unknown) experimental_result = experimental_result + self.generate_vendor_avp(266, 40, 10415, "") #Experimental Result (297) @@ -1374,7 +1415,7 @@ def Answer_16777216_303(self, packet_vars, avps): response = self.generate_diameter_packet("01", "40", 303, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response - DiameterLogger.debug("Got subscriber data for MAA OK") + self.diameterLibLogger.debug("Got subscriber data for MAA OK") mcc, mnc = imsi[0:3], imsi[3:5] plmn = self.EncodePLMN(mcc, mnc) @@ -1382,31 
+1423,32 @@ def Answer_16777216_303(self, packet_vars, avps): #Determine if SQN Resync is required & auth type to use for sub_avp_612 in self.get_avp_data(avps, 612)[0]: if sub_avp_612['avp_code'] == 610: - DiameterLogger.info("SQN in HSS is out of sync - Performing resync") + self.diameterLibLogger.info("SQN in HSS is out of sync - Performing resync") auts = str(sub_avp_612['misc_data'])[32:] rand = str(sub_avp_612['misc_data'])[:32] rand = binascii.unhexlify(rand) database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) - DiameterLogger.debug("Resynced SQN in DB") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 302, - event='ReAuth', - imsi_prefix = str(imsi[0:6]), - ).inc() + self.diameterLibLogger.debug("Resynced SQN in DB") + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777216, + # diameter_cmd_code = 302, + # event='ReAuth', + # imsi_prefix = str(imsi[0:6]), + # ).inc() if sub_avp_612['avp_code'] == 608: - DiameterLogger.info("Auth mechansim requested: " + str(sub_avp_612['misc_data'])) + self.diameterLibLogger.info("Auth mechanism requested: " + str(sub_avp_612['misc_data'])) auth_scheme = binascii.unhexlify(sub_avp_612['misc_data']).decode('utf-8') - DiameterLogger.info("Auth mechansim requested: " + str(auth_scheme)) + self.diameterLibLogger.info("Auth mechanism requested: " + str(auth_scheme)) - DiameterLogger.debug("IMSI is " + str(imsi)) + self.diameterLibLogger.debug("IMSI is " + str(imsi)) avp += self.generate_vendor_avp(601, "c0", 10415, str(binascii.hexlify(str.encode(public_identity)),'ascii')) #Public Identity (IMSI) avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi + "@" + domain)),'ascii')) #Username #Determine Vectors to Generate if auth_scheme == "Digest-MD5": - DiameterLogger.debug("Generating MD5 Challenge") + self.diameterLibLogger.debug("Generating MD5 Challenge") vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "Digest-MD5", username=imsi, plmn=plmn) avp_SIP_Item_Number = self.generate_vendor_avp(613, "c0", 10415, format(int(0),"x").zfill(8)) avp_SIP_Authentication_Scheme = self.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-MD5'),'ascii')) @@ -1416,7 +1458,7 @@ def Answer_16777216_303(self, packet_vars, avps): avp_SIP_Authorization = self.generate_vendor_avp(610, "c0", 10415, str(binascii.hexlify(str.encode(vector_dict['SIP_Authenticate'])),'ascii')) auth_data_item = avp_SIP_Item_Number + avp_SIP_Authentication_Scheme + avp_SIP_Authenticate + avp_SIP_Authorization else: - DiameterLogger.debug("Generating AKA-MD5 Auth Challenge") + self.diameterLibLogger.debug("Generating AKA-MD5 Auth Challenge") vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "sip_auth", plmn=plmn)
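For reference, the MAA handling in the hunks above builds the grouped SIP-Auth-Data-Item (AVP 612) by concatenating hex-encoded sub-AVPs, as the auth_data_item assignment shows. A minimal illustrative sketch of that pattern, assuming (as these calls suggest) that generate_vendor_avp() returns the encoded AVP as a hex string, where diameter is an instance of the Diameter class:

import binascii

# Hypothetical usage - the order of the concatenated sub-AVPs determines
# their on-the-wire order inside the grouped AVP.
item_number = diameter.generate_vendor_avp(613, "c0", 10415, format(0, "x").zfill(8))                        #SIP-Item-Number
auth_scheme = diameter.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-MD5'), 'ascii'))  #SIP-Authentication-Scheme
auth_data_item = diameter.generate_vendor_avp(612, "c0", 10415, item_number + auth_scheme)                   #Grouped SIP-Auth-Data-Item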
@@ -1456,7 +1498,7 @@ def Respond_ResultCode(self, packet_vars, avps, result_code): session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID except: - DiameterLogger.info("Failed to add SessionID into error") + self.diameterLibLogger.info("Failed to add SessionID into error") for avps_to_check in avps: #Only include AVP 260 (Vendor-Specific-Application-ID) if inital request included it if avps_to_check['avp_code'] == 260: concat_subavp = '' @@ -1480,9 +1522,9 @@ def Answer_16777216_304(self, packet_vars, avps): session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID vendor_id = self.generate_avp(266, 40, str(binascii.hexlify('10415'),'ascii')) - DiameterLogger.debug("vendor_id avp: " + str(vendor_id)) + self.diameterLibLogger.debug("vendor_id avp: " + str(vendor_id)) auth_application_id = self.generate_avp(248, 40, self.int_to_hex(16777252, 8)) - DiameterLogger.debug("auth_application_id: " + auth_application_id) + self.diameterLibLogger.debug("auth_application_id: " + auth_application_id) avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx avp += self.generate_avp(268, 40, "000007d1") #Result Code - DIAMETER_SUCCESS avp += self.generate_avp(277, 40, "00000001") #Auth Session State @@ -1508,29 +1550,30 @@ def Answer_16777217_306(self, packet_vars, avps): try: user_identity_avp = self.get_avp_data(avps, 700)[0] msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request - DiameterLogger.info("Got raw MSISDN with value " + str(msisdn)) + self.diameterLibLogger.info("Got raw MSISDN with value " + str(msisdn)) msisdn = self.TBCD_decode(msisdn) - DiameterLogger.info("Got MSISDN with value " + str(msisdn)) + self.diameterLibLogger.info("Got MSISDN with value " + str(msisdn)) except: - DiameterLogger.error("No MSISDN") + self.diameterLibLogger.error("No MSISDN") if msisdn is not None: - DiameterLogger.debug("Getting susbcriber IMS info based on MSISDN") + self.diameterLibLogger.debug("Getting subscriber IMS info based on MSISDN") subscriber_ims_details = database.Get_IMS_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details)) - DiameterLogger.debug("Getting susbcriber info based on MSISDN") + self.diameterLibLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details)) + self.diameterLibLogger.debug("Getting subscriber info based on MSISDN") subscriber_details = database.Get_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber details: " + str(subscriber_details)) + self.diameterLibLogger.debug("Got subscriber details: " + str(subscriber_details)) subscriber_details = {**subscriber_details, **subscriber_ims_details} - DiameterLogger.debug("Merged subscriber details: " + str(subscriber_details)) + self.diameterLibLogger.debug("Merged subscriber details: " + str(subscriber_details)) else: - DiameterLogger.error("No MSISDN or IMSI in Answer_16777217_306() input") - prom_diam_auth_event_count.labels( - diameter_application_id = 16777216, - diameter_cmd_code = 306, - event='Unknown User', - imsi_prefix = str(username[0:6]), - ).inc() + self.diameterLibLogger.error("No MSISDN or IMSI in Answer_16777217_306() input") + #@@Fixme + # prom_diam_auth_event_count.labels( + # diameter_application_id = 16777216, + # diameter_cmd_code = 306, + # event='Unknown User', + # imsi_prefix = str(username[0:6]), + # ).inc() result_code = 5005 #Experimental Result AVP avp_experimental_result = ''
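The next hunk renders Sh-User-Data through a Jinja2 template, using the stock FileSystemLoader pattern. A self-contained sketch of that pattern with a hypothetical template path and illustrative values (the real path comes from self.yaml_config['hss']['Default_Sh_UserData']):

import jinja2

templateLoader = jinja2.FileSystemLoader(searchpath="./")
templateEnv = jinja2.Environment(loader=templateLoader)
template = templateEnv.get_template("sh_userdata.xml")   # hypothetical template file
# The dict is exposed to the template as 'Sh_template_vars', matching the render call used here.
xmlbody = template.render(Sh_template_vars={'msisdn': '123456789', 'mnc': '001', 'mcc': '001'})
print(xmlbody)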
@@ -1552,14 +1595,14 @@ def Answer_16777217_306(self, packet_vars, avps): #This loads a Jinja XML template containing the Sh-User-Data templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) - sh_userdata_template = yaml_config['hss']['Default_Sh_UserData'] - DiameterLogger.info("Using template " + str(sh_userdata_template) + " for SH user data") + sh_userdata_template = self.yaml_config['hss']['Default_Sh_UserData'] + self.diameterLibLogger.info("Using template " + str(sh_userdata_template) + " for SH user data") template = templateEnv.get_template(sh_userdata_template) #These variables are passed to the template for use subscriber_details['mnc'] = self.MNC.zfill(3) subscriber_details['mcc'] = self.MCC.zfill(3) - DiameterLogger.debug("Rendering template with values: " + str(subscriber_details)) + self.diameterLibLogger.debug("Rendering template with values: " + str(subscriber_details)) xmlbody = template.render(Sh_template_vars=subscriber_details) # this is where to put args to the template renderer avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii')) @@ -1581,7 +1624,7 @@ def Answer_16777217_307(self, packet_vars, avps): sh_user_data = self.get_avp_data(avps, 702)[0] #Get IMSI from User-Name AVP in request sh_user_data = binascii.unhexlify(sh_user_data).decode('utf-8') - DiameterLogger.debug("Got Sh User data: " + str(sh_user_data)) + self.diameterLibLogger.debug("Got Sh User data: " + str(sh_user_data)) #Push updated User Data into IMS Backend #Start with the Current User Data @@ -1612,17 +1655,17 @@ def Answer_16777252_324(self, packet_vars, avps): imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI #avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - DiameterLogger.info("Got IMSI with value " + str(imsi)) + self.diameterLibLogger.info("Got IMSI with value " + str(imsi)) except Exception as e: - DiameterLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request") - DiameterLogger.debug("Error was: " + str(e)) + self.diameterLibLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request") + self.diameterLibLogger.debug("Error was: " + str(e)) #Get IMEI for sub_avp in self.get_avp_data(avps, 1401)[0]: - DiameterLogger.debug("Evaluating sub_avp AVP " + str(sub_avp) + " to find IMSI") + self.diameterLibLogger.debug("Evaluating sub_avp AVP " + str(sub_avp) + " to find IMEI") if sub_avp['avp_code'] == 1402: imei = binascii.unhexlify(sub_avp['misc_data']).decode('utf-8') - DiameterLogger.debug("Found IMEI " + str(imei)) + self.diameterLibLogger.debug("Found IMEI " + str(imei)) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1640,7 +1683,8 @@ def Answer_16777252_324(self, packet_vars, avps): #Equipment-Status EquipmentStatus = database.Check_EIR(imsi=imsi, imei=imei) avp += self.generate_vendor_avp(1445, 'c0', 10415, self.int_to_hex(EquipmentStatus, 4)) - prom_diam_eir_event_count.labels(response=EquipmentStatus).inc() + # @@Fixme + # prom_diam_eir_event_count.labels(response=EquipmentStatus).inc() response = self.generate_diameter_packet("01", "40", 324, 16777252, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response
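The hunk above stubs out prom_diam_eir_event_count behind a @@Fixme marker, as this patch does for the other Prometheus counters. Re-enabling them only needs the metric objects back in scope; a minimal prometheus_client sketch reusing the counter definition that appears in old.logtool.py later in this patch:

from prometheus_client import Counter

prom_diam_eir_event_count = Counter('prom_diam_eir_event_count', 'Diameter EIR event related Counters', ['response'])

# 0 here stands in for the EquipmentStatus value returned by database.Check_EIR()
prom_diam_eir_event_count.labels(response=0).inc()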
@@ -1670,56 +1714,56 @@ def Answer_16777291_8388622(self, packet_vars, avps): #Try and get IMSI if present if 1 in present_avps: - DiameterLogger.info("IMSI AVP is present") + self.diameterLibLogger.info("IMSI AVP is present") try: imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - DiameterLogger.info("Got IMSI with value " + str(imsi)) + self.diameterLibLogger.info("Got IMSI with value " + str(imsi)) except Exception as e: - DiameterLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request") - DiameterLogger.debug("Error was: " + str(e)) + self.diameterLibLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request") + self.diameterLibLogger.debug("Error was: " + str(e)) elif 701 in present_avps: #Try and get MSISDN if present try: msisdn = self.get_avp_data(avps, 701)[0] #Get MSISDN from AVP in request - DiameterLogger.info("Got MSISDN with value " + str(msisdn)) + self.diameterLibLogger.info("Got MSISDN with value " + str(msisdn)) avp += self.generate_vendor_avp(701, 'c0', 10415, self.get_avp_data(avps, 701)[0]) #MSISDN - DiameterLogger.info("Got MSISDN with encoded value " + str(msisdn)) + self.diameterLibLogger.info("Got MSISDN with encoded value " + str(msisdn)) msisdn = self.TBCD_decode(msisdn) - DiameterLogger.info("Got MSISDN with decoded value " + str(msisdn)) + self.diameterLibLogger.info("Got MSISDN with decoded value " + str(msisdn)) except Exception as e: - DiameterLogger.debug("Failed to get MSISDN from LCS-Routing-Info-Request") - DiameterLogger.debug("Error was: " + str(e)) + self.diameterLibLogger.debug("Failed to get MSISDN from LCS-Routing-Info-Request") + self.diameterLibLogger.debug("Error was: " + str(e)) else: - DiameterLogger.error("No MSISDN or IMSI") + self.diameterLibLogger.error("No MSISDN or IMSI") try: if imsi is not None: - DiameterLogger.debug("Getting susbcriber location based on IMSI") + self.diameterLibLogger.debug("Getting subscriber location based on IMSI") subscriber_details = database.Get_Subscriber(imsi=imsi) - DiameterLogger.debug("Got subscriber_details from IMSI: " + str(subscriber_details)) + self.diameterLibLogger.debug("Got subscriber_details from IMSI: " + str(subscriber_details)) elif msisdn is not None: - DiameterLogger.debug("Getting susbcriber location based on MSISDN") + self.diameterLibLogger.debug("Getting subscriber location based on MSISDN") subscriber_details = database.Get_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber_details from MSISDN: " + str(subscriber_details)) + self.diameterLibLogger.debug("Got subscriber_details from MSISDN: " + str(subscriber_details)) except Exception as E: - DiameterLogger.error("No MSISDN or IMSI returned in Answer_16777291_8388622 input") - DiameterLogger.error("Error is " + str(E)) - DiameterLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN") + self.diameterLibLogger.error("No MSISDN or IMSI returned in Answer_16777291_8388622 input") + self.diameterLibLogger.error("Error is " + str(E)) + self.diameterLibLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN") avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - DiameterLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN") + self.diameterLibLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN") return response - DiameterLogger.info("Got subscriber_details for subscriber: " + str(subscriber_details)) + self.diameterLibLogger.info("Got subscriber_details for subscriber: " + str(subscriber_details)) if subscriber_details['serving_mme'] == None: #DB has no location on record for subscriber - DiameterLogger.info("No location on record for Subscriber") + self.diameterLibLogger.info("No location on record for Subscriber") result_code = 4201 #DIAMETER_ERROR_ABSENT_USER (4201) #This result code shall be sent by the HSS to indicate that the location of the targeted user is not known at this time to
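The MSISDN handling above relies on TBCD_decode(). TBCD packs two digits per octet with the nibbles swapped; a standalone sketch of the decode direction (an illustration of the format, not the class method itself):

def tbcd_decode(hexstr: str) -> str:
    # Each octet carries two swapped digits: hex '1654' encodes the digits '6145'
    digits = ''.join(hexstr[i + 1] + hexstr[i] for i in range(0, len(hexstr), 2))
    return digits.rstrip('fF')   # odd-length numbers are padded with a 0xF filler nibble

print(tbcd_decode('1654'))       # prints '6145'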
@@ -1739,7 +1783,7 @@ def Answer_16777291_8388622(self, packet_vars, avps): avp_serving_node = '' avp_serving_node += self.generate_vendor_avp(2402, "c0", 10415, self.string_to_hex(subscriber_details['serving_mme'])) #MME-Name avp_serving_node += self.generate_vendor_avp(2408, "c0", 10415, self.OriginRealm) #MME-Realm - avp_serving_node += self.generate_vendor_avp(2405, "c0", 10415, self.ip_to_hex(yaml_config['hss']['bind_ip'][0])) #GMLC-Address + avp_serving_node += self.generate_vendor_avp(2405, "c0", 10415, self.ip_to_hex(self.yaml_config['hss']['bind_ip'][0])) #GMLC-Address avp += self.generate_vendor_avp(2401, "c0", 10415, avp_serving_node) #Serving-Node AVP #Set Result-Code @@ -1818,7 +1862,7 @@ def Request_16777251_316(self, imsi, DestinationRealm): sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." + yaml_config['hss']['OriginHost'])),'ascii')) + avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." + self.yaml_config['hss']['OriginHost'])),'ascii')) avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(283, 40, self.string_to_hex(DestinationRealm)) #Destination Realm avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) @@ -1899,7 +1943,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags if 'GetLocation' in kwargs: - DiameterLogger.debug("Requsted Get Location ISD") + self.diameterLibLogger.debug("Requested Get Location ISD") #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP SupportedFeatures = '' SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID @@ -1910,23 +1954,23 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): try: user_identity_avp = self.get_avp_data(avps, 700)[0] - DiameterLogger.info(user_identity_avp) + self.diameterLibLogger.info(user_identity_avp) msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request msisdn = self.TBCD_decode(msisdn) - DiameterLogger.info("Got MSISDN with value " + str(msisdn)) + self.diameterLibLogger.info("Got MSISDN with value " + str(msisdn)) except: - DiameterLogger.error("No MSISDN present") + self.diameterLibLogger.error("No MSISDN present") return #Get Subscriber Location from Database subscriber_location = database.GetSubscriberLocation(msisdn=msisdn) - DiameterLogger.debug("Got subscriber location: " + subscriber_location) + self.diameterLibLogger.debug("Got subscriber location: " + subscriber_location) - DiameterLogger.info("Getting IMSI for MSISDN " + str(msisdn)) + self.diameterLibLogger.info("Getting IMSI for MSISDN " + str(msisdn)) imsi = database.Get_IMSI_from_MSISDN(msisdn) avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - DiameterLogger.info("Got back location data: " + str(subscriber_location)) + self.diameterLibLogger.info("Got back location data: " + str(subscriber_location)) #Populate Destination Host & Realm avp += self.generate_avp(293, 40, self.string_to_hex(subscriber_location)) #Destination Host #Destination-Host
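The GMLC-Address above is packed with ip_to_hex() from the configured bind_ip. A hedged sketch of what such a helper does, consistent with how it is used here (the actual implementation lives elsewhere in lib/diameter.py):

import binascii
import ipaddress

def ip_to_hex(ip: str) -> str:
    # Pack the address into network-order bytes, then hex-encode it
    return binascii.hexlify(ipaddress.ip_address(ip).packed).decode('ascii')

print(ip_to_hex('10.0.0.1'))   # prints '0a000001'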
@@ -1942,10 +1986,10 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): destinationHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP destinationHost = binascii.unhexlify(destinationHost).decode('utf-8') #Format it - DiameterLogger.debug("Received originHost to use as destinationHost is " + str(destinationHost)) + self.diameterLibLogger.debug("Received originHost to use as destinationHost is " + str(destinationHost)) destinationRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP destinationRealm = binascii.unhexlify(destinationRealm).decode('utf-8') #Format it - DiameterLogger.debug("Received originRealm to use as destinationRealm is " + str(destinationRealm)) + self.diameterLibLogger.debug("Received originRealm to use as destinationRealm is " + str(destinationRealm)) avp += self.generate_avp(293, 40, self.string_to_hex(destinationHost)) #Destination-Host avp += self.generate_avp(283, 40, self.string_to_hex(destinationRealm)) @@ -1954,14 +1998,14 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): try: subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details except ValueError as e: - DiameterLogger.error("failed to get data backfrom database for imsi " + str(imsi)) - DiameterLogger.error("Error is " + str(e)) + self.diameterLibLogger.error("failed to get data back from database for imsi " + str(imsi)) + self.diameterLibLogger.error("Error is " + str(e)) raise except Exception as ex: template = "An exception of type {0} occurred. Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - DiameterLogger.critical(message) - DiameterLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) + self.diameterLibLogger.critical(message) + self.diameterLibLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise @@ -1998,18 +2042,18 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): apn_list = subscriber_details['pdn'] - DiameterLogger.debug("APN list: " + str(apn_list)) + self.diameterLibLogger.debug("APN list: " + str(apn_list)) APN_context_identifer_count = 1 for apn_profile in apn_list: - DiameterLogger.debug("Processing APN profile " + str(apn_profile)) + self.diameterLibLogger.debug("Processing APN profile " + str(apn_profile)) APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_profile['apn']))) - DiameterLogger.debug("Setting APN Configuration Profile") + self.diameterLibLogger.debug("Setting APN Configuration Profile") #Sub AVPs of APN Configuration Profile APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4)) APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(0, 4)) - DiameterLogger.debug("Setting APN AMBR") + self.diameterLibLogger.debug("Setting APN AMBR") #AMBR AMBR = '' #Initiate empty var AVP for AMBR if 'AMBR' in apn_profile: @@ -2024,7 +2068,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - DiameterLogger.debug("Setting APN Allocation-Retention-Priority") + self.diameterLibLogger.debug("Setting APN Allocation-Retention-Priority") #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['priority_level']),
4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['pre_emption_capability']), 4)) @@ -2037,32 +2081,32 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): #If static UE IP is specified try: apn_ip = apn_profile['ue']['addr'] - DiameterLogger.debug("Found static IP for UE " + str(apn_ip)) + self.diameterLibLogger.debug("Found static IP for UE " + str(apn_ip)) Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(apn_ip)) except: Served_Party_Address = "" if 'MIP6-Agent-Info' in apn_profile: - DiameterLogger.info("MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info'])) + self.diameterLibLogger.info("MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info'])) MIP6_Destination_Host = self.generate_avp(293, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_HOST']))) MIP6_Destination_Realm = self.generate_avp(283, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_REALM']))) MIP6_Home_Agent_Host = self.generate_avp(348, '40', MIP6_Destination_Host + MIP6_Destination_Realm) MIP6_Agent_Info = self.generate_avp(486, '40', MIP6_Home_Agent_Host) - DiameterLogger.info("MIP6 value is " + str(MIP6_Agent_Info)) + self.diameterLibLogger.info("MIP6 value is " + str(MIP6_Agent_Info)) else: MIP6_Agent_Info = '' if 'PDN_GW_Allocation_Type' in apn_profile: - DiameterLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type'])) + self.diameterLibLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type'])) PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4)) - DiameterLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type)) + self.diameterLibLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type)) else: PDN_GW_Allocation_Type = '' if 'VPLMN_Dynamic_Address_Allowed' in apn_profile: - DiameterLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed'])) + self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed'])) VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4)) - DiameterLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed)) + self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed)) else: VPLMN_Dynamic_Address_Allowed = '' @@ -2073,7 +2117,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): #Incriment Context Identifier Count to keep track of how many APN Profiles returned APN_context_identifer_count = APN_context_identifer_count + 1 - DiameterLogger.debug("Processed APN profile " + str(apn_profile['apn'])) + self.diameterLibLogger.debug("Processed APN profile " + str(apn_profile['apn'])) subscription_data += self.generate_vendor_avp(1619, "80", 10415, self.int_to_hex(720, 4)) #Subscribed-Periodic-RAU-TAU-Timer (value 720) subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_context_identifer + \ @@ -2081,26 +2125,26 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): #If MSISDN is present include it in Subscription Data if 'msisdn' in subscriber_details: - DiameterLogger.debug("MSISDN is 
" + str(subscriber_details['msisdn']) + " - adding in ULA") + self.diameterLibLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA") msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, str(subscriber_details['msisdn'])) #MSISDN - DiameterLogger.debug(msisdn_avp) + self.diameterLibLogger.debug(msisdn_avp) subscription_data += msisdn_avp if 'RAT_freq_priorityID' in subscriber_details: - DiameterLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") + self.diameterLibLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID - DiameterLogger.debug(rat_freq_priorityID) + self.diameterLibLogger.debug(rat_freq_priorityID) subscription_data += rat_freq_priorityID if '3gpp-charging-characteristics' in subscriber_details: - DiameterLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['3gpp-charging-characteristics']) + " - Adding in ULA") + self.diameterLibLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['3gpp-charging-characteristics']) + " - Adding in ULA") _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, self.string_to_hex(str(subscriber_details['3gpp-charging-characteristics']))) subscription_data += _3gpp_charging_characteristics - DiameterLogger.debug(_3gpp_charging_characteristics) + self.diameterLibLogger.debug(_3gpp_charging_characteristics) if 'APN_OI_replacement' in subscriber_details: - DiameterLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") + self.diameterLibLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement']))) @@ -2153,7 +2197,7 @@ def Request_16777216_301(self, imsi, domain, server_assignment_type): avp = '' #Initiate empty var AVP #Session-ID sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session Session ID - avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." + yaml_config['hss']['OriginHost'])),'ascii')) #Origin Host + avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." 
+ self.yaml_config['hss']['OriginHost'])),'ascii')) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx @@ -2382,7 +2426,7 @@ def Request_16777238_258(self, sessionid, ChargingRules, ue_ip, Serving_PGW, Ser avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id set AVP #Setup Charging Rule - DiameterLogger.debug(ChargingRules) + self.diameterLibLogger.debug(ChargingRules) avp += self.Charging_Rule_Generator(ChargingRules=ChargingRules, ue_ip=ue_ip) @@ -2476,14 +2520,14 @@ def Request_16777217_307(self, msisdn): avp += self.generate_avp(283, 40, self.OriginRealm) #Destination Realm avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - DiameterLogger.debug("Getting susbcriber IMS info based on MSISDN") + self.diameterLibLogger.debug("Getting subscriber IMS info based on MSISDN") subscriber_ims_details = database.Get_IMS_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details)) - DiameterLogger.debug("Getting susbcriber info based on MSISDN") + self.diameterLibLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details)) + self.diameterLibLogger.debug("Getting subscriber info based on MSISDN") subscriber_details = database.Get_Subscriber(msisdn=msisdn) - DiameterLogger.debug("Got subscriber details: " + str(subscriber_details)) + self.diameterLibLogger.debug("Got subscriber details: " + str(subscriber_details)) subscriber_details = {**subscriber_details, **subscriber_ims_details} - DiameterLogger.debug("Merged subscriber details: " + str(subscriber_details)) + self.diameterLibLogger.debug("Merged subscriber details: " + str(subscriber_details)) avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(subscriber_details['imsi'])),'ascii')) #Username AVP @@ -2492,14 +2536,14 @@ def Request_16777217_307(self, msisdn): #This loads a Jinja XML template containing the Sh-User-Data templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) - sh_userdata_template = yaml_config['hss']['Default_Sh_UserData'] - DiameterLogger.info("Using template " + str(sh_userdata_template) + " for SH user data") + sh_userdata_template = self.yaml_config['hss']['Default_Sh_UserData'] + self.diameterLibLogger.info("Using template " + str(sh_userdata_template) + " for SH user data") template = templateEnv.get_template(sh_userdata_template) #These variables are passed to the template for use subscriber_details['mnc'] = self.MNC.zfill(3) subscriber_details['mcc'] = self.MCC.zfill(3) - DiameterLogger.debug("Rendering template with values: " + str(subscriber_details)) + self.diameterLibLogger.debug("Rendering template with values: " + str(subscriber_details)) xmlbody = template.render(Sh_template_vars=subscriber_details) # this is where to put args to the template renderer avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii')) diff --git a/lib/logtool.py b/lib/logtool.py index ac341a2..b6f0773 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -2,242 +2,28 @@ import logging.handlers as handlers import os import sys -import inspect sys.path.append(os.path.realpath('../')) -import yaml -from datetime import datetime as log_dt -with 
open("config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) - -import json -import pickle - -from prometheus_client import Counter, Gauge, Histogram, Summary - -from prometheus_client import start_http_server - -if yaml_config['prometheus']['enabled'] == True: - #Check if this is the HSS service, and if it's not increment the port before starting - print(sys.argv[0]) - if 'hss.py' in str(sys.argv[0]): - print("Starting Prometheus on port from config " + str(yaml_config['prometheus']['port'])) - else: - print("This is not the HSS stack so offsetting Prometheus port") - yaml_config['prometheus']['port'] += 1 - try: - start_http_server(yaml_config['prometheus']['port']) - print("Started Prometheus on port " + str(yaml_config['prometheus']['port'])) - except Exception as E: - print("Error loading Prometheus") - print(E) - - -tags = ['diameter_application_id', 'diameter_cmd_code', 'endpoint', 'type'] -prom_diam_request_count = Counter('prom_diam_request_count', 'Number of Diameter Requests', tags) -prom_diam_response_count_successful = Counter('prom_diam_response_count_successful', 'Number of Successful Diameter Responses', tags) -prom_diam_response_count_fail = Counter('prom_diam_response_count_fail', 'Number of Failed Diameter Responses', tags) -prom_diam_connected_peers = Gauge('prom_diam_connected_peers', 'Connected Diameter Peer Count', ['endpoint']) -prom_diam_connected_peers._metrics.clear() -prom_diam_response_time_diam = Histogram('prom_diam_response_time_diam', 'Diameter Response Times') -prom_diam_response_time_method = Histogram('prom_diam_response_time_method', 'Diameter Response Times', tags) -prom_diam_response_time_db = Summary('prom_diam_response_time_db', 'Diameter Response Times from Database') -prom_diam_response_time_h = Histogram('request_latency_seconds', 'Diameter Response Time Histogram') -prom_diam_auth_event_count = Counter('prom_diam_auth_event_count', 'Diameter Authentication related Counters', ['diameter_application_id', 'diameter_cmd_code', 'event', 'imsi_prefix']) -prom_diam_eir_event_count = Counter('prom_diam_eir_event_count', 'Diameter EIR event related Counters', ['response']) - -prom_eir_devices = Counter('prom_eir_devices', 'Profile of attached devices', ['imei_prefix', 'device_type', 'device_name']) - -prom_http_geored = Counter('prom_http_geored', 'Number of Geored Pushes', ['geored_host', 'endpoint', 'http_response_code', 'error']) -prom_flask_http_geored_endpoints = Counter('prom_flask_http_geored_endpoints', 'Number of Geored Pushes Received', ['geored_host', 'endpoint']) - - -prom_pcrf_subs = Gauge('prom_pcrf_subs', 'Number of attached PCRF Subscribers') -prom_mme_subs = Gauge('prom_mme_subs', 'Number of attached MME Subscribers') -prom_ims_subs = Gauge('prom_ims_subs', 'Number of attached IMS Subscribers') class LogTool: - def __init__(self, **kwargs): - print("Instantiating LogTool with Kwargs " + str(kwargs.items())) - if yaml_config['redis']['enabled'] == True: - print("Redis support enabled") - import redis - redis_store = redis.Redis(host=str(yaml_config['redis']['host']), port=str(yaml_config['redis']['port']), db=0) - self.redis_store = redis_store - try: - if "HSS_Init" in kwargs: - print("Called Init for HSS_Init") - redis_store.incr('restart_count') - if yaml_config['redis']['clear_stats_on_boot'] == True: - logging.debug("Clearing ActivePeerDict") - redis_store.delete('ActivePeerDict') - else: - logging.debug("Leaving prexisting Redis keys") - #Clear ActivePeerDict - redis_store.delete('ActivePeerDict') - - #Clear Async 
Keys - for key in redis_store.scan_iter("*_request_queue"): - print("Deleting Key: " + str(key)) - redis_store.delete(key) - logging.info("Connected to Redis server") - else: - logging.info("Init of Logtool but not from HSS_Init") - except: - logging.error("Failed to connect to Redis server - Disabling") - yaml_config['redis']['enabled'] == False - - #function for handling incrimenting Redis counters with error handling - def RedisIncrimenter(self, name): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.incr(name) - except: - logging.error("failed to incriment " + str(name)) - - def RedisStore(self, key, value): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.set(key, value) - except: - logging.error("failed to set Redis key " + str(key) + " to value " + str(value)) - - def RedisGet(self, key): - if yaml_config['redis']['enabled'] == True: - try: - return self.redis_store.get(key) - except: - logging.error("failed to set Redis key " + str(key)) - - def RedisHMSET(self, key, value_dict): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.hmset(key, value_dict) - except: - logging.error("failed to set hm Redis key " + str(key) + " to value " + str(value_dict)) - - def Async_SendRequest(self, request, DiameterHostname): - if yaml_config['redis']['enabled'] == True: - try: - import time - print("Writing request to Queue '" + str(DiameterHostname) + "_request_queue'") - self.redis_store.hset(str(DiameterHostname) + "_request_queue", "hss_Async_client_" + str(int(time.time())), request) - print("Written to Queue to send.") - except Exception as E: - logging.error("failed to run Async_SendRequest to " + str(DiameterHostname)) - - def RedisHMGET(self, key): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Getting HM Get from " + str(key)) - data = self.redis_store.hgetall(key) - logging.debug("Result: " + str(data)) - return data - except: - logging.error("failed to get hm Redis key " + str(key)) - - def RedisHDEL(self, key, item): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Removing item " + str(item) + " from key " + str(key)) - self.redis_store.hdel(key, item) - except: - logging.error("failed to hdel Redis key " + str(key) + " item " + str(item)) - - def RedisStoreDict(self, key, value): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.set(str(key), pickle.dumps(value)) - except: - logging.error("failed to set Redis dict " + str(key) + " to value " + str(value)) - - def RedisGetDict(self, key): - if yaml_config['redis']['enabled'] == True: - try: - read_dict = self.redis_store.get(key) - return pickle.loads(read_dict) - except: - logging.error("failed to hmget Redis key " + str(key)) - - def GetDiameterPeers(self): - if yaml_config['redis']['enabled'] == True: - try: - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - return ActivePeerDict - except: - logging.error("Failed to get ActivePeerDict") - - - def Manage_Diameter_Peer(self, peername, ip, action): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Managing Diameter peer to Redis with hostname" + str(peername) + " and IP " + str(ip)) - now = log_dt.now() - timestamp = str(now.strftime("%Y-%m-%d %H:%M:%S")) - - #Try and get IP and Port seperately - try: - ip = ip[0] - port = ip[1] - except: - pass - - if self.redis_store.exists('ActivePeerDict') == False: - #Initialise empty active peer dict in Redis - logging.debug("Populated new empty ActivePeerDict Redis 
key") - ActivePeerDict = {} - ActivePeerDict['internal_connection'] = {"connect_timestamp" : timestamp} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "add": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict) + " to add peer " + str(peername) + " with ip " + str(ip)) - - - #If key has already existed in dict due to disconnect / reconnect, get reconnection count - try: - reconnection_count = ActivePeerDict[str(ip)]['reconnection_count'] + 1 - except: - reconnection_count = 0 - - ActivePeerDict[str(ip)] = {"connect_timestamp" : timestamp, \ - "recv_ip_address" : str(ip), "DiameterHostname" : "Unknown - Socket connection only", \ - "reconnection_count" : reconnection_count, - "connection_status" : "Pending"} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "remove": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict)) - ActivePeerDict[str(ip)] = {"disconnect_timestamp" : str(timestamp), \ - "DiameterHostname" : str(ActivePeerDict[str(ip)]['DiameterHostname']), \ - "reconnection_count" : ActivePeerDict[str(ip)]['reconnection_count'], - "connection_status" : "Disconnected"} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "update": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - ActivePeerDict[str(ip)]['DiameterHostname'] = str(peername) - ActivePeerDict[str(ip)]['last_dwr_timestamp'] = str(timestamp) - ActivePeerDict[str(ip)]['connection_status'] = "Connected" - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - except Exception as E: - logging.error("failed to add/update/remove Diameter peer from Redis") - logging.error(E) - - def setup_logger(self, logger_name, log_file, level=logging.DEBUG): - l = logging.getLogger(logger_name) - formatter = logging.Formatter('%(asctime)s \t %(levelname)s \t {%(pathname)s:%(lineno)d} \t %(message)s') - fileHandler = logging.FileHandler(log_file, mode='a+') - fileHandler.setFormatter(formatter) + def setupLogger(self, loggerName: str, config: dict): + logFile = config.get('logging', {}).get('logfiles', {}).get(f'{loggerName.lower()}_logging_file', '/var/log/pyhss_diameter.log') + logLevel = config.get('logging', {}).get('level', 'INFO') + logger = logging.getLogger(loggerName) + formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s {%(pathname)s:%(lineno)d} %(message)s", datefmt="%m/%d/%Y %H:%M:%S %Z") + try: + rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) + except PermissionError: + logFileName = logFile.split('/')[-1] + pyhssRootDir = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + print(f"[LogTool] Warning - Unable to write to {logFile}, using {pyhssRootDir}/log/{logFileName} instead.") + logFile = f"{pyhssRootDir}/log/{logFileName}" + rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) + pass streamHandler = logging.StreamHandler() streamHandler.setFormatter(formatter) - rolloverHandler = handlers.RotatingFileHandler(log_file, maxBytes=50000000, backupCount=5) - l.setLevel(level) - l.addHandler(fileHandler) - l.addHandler(streamHandler) - l.addHandler(rolloverHandler) + rolloverHandler.setFormatter(formatter) + logger.setLevel(logLevel) + logger.addHandler(streamHandler) + logger.addHandler(rolloverHandler) + return logger \ No 
newline at end of file diff --git a/lib/messaging.py b/lib/messaging.py index e6e42e4..fa2e218 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -1,6 +1,6 @@ from redis import Redis -class RedisMessaging(): +class RedisMessaging: """ PyHSS Redis Message Service A class for sending and receiving redis messages. @@ -10,16 +10,66 @@ def __init__(self, host: str='localhost', port: int=6379): self.redisClient = Redis(host=host, port=port) pass - def sendMessage(self, queue: str, message: str) -> str: - self.redisClient.rpush(queue, message) + def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + """ + Stores a message in a given Queue (Key). + """ + try: + self.redisClient.rpush(queue, message) + if queueExpiry is not None: + self.redisClient.expire(queue, queueExpiry) + return f'{message} stored in {queue} successfully.' + except Exception as e: + return '' def getMessage(self, queue: str) -> str: - message = self.redisClient.lpop(queue) - if message is None: - message = '' - else: - try: - message = message.decode() - except (UnicodeDecodeError, AttributeError): - pass - return message \ No newline at end of file + """ + Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed. + """ + try: + message = self.redisClient.lpop(queue) + if message is None: + message = '' + else: + try: + message = message.decode() + except (UnicodeDecodeError, AttributeError): + pass + return message + except Exception as e: + return '' + + def getQueues(self, pattern: str='*') -> list: + """ + Returns all Queues (Keys) in the database. + """ + try: + allQueues = self.redisClient.keys(pattern) + return [x.decode() for x in allQueues] + except Exception as e: + return [] + + def getNextQueue(self, pattern: str='*') -> dict: + """ + Returns the next Queue (Key) in the list. + """ + try: + for nextQueue in self.redisClient.scan_iter(match=pattern): + return nextQueue.decode() + except Exception as e: + return {} + + def deleteQueue(self, queue: str) -> bool: + """ + Deletes the given Queue (Key) + """ + try: + self.redisClient.delete(queue) + return True + except Exception as e: + return False + + +if __name__ == '__main__': + redisMessaging = RedisMessaging() + print(redisMessaging.getNextQueue()) \ No newline at end of file diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py new file mode 100644 index 0000000..34fb25a --- /dev/null +++ b/lib/messagingAsync.py @@ -0,0 +1,84 @@ +import asyncio +import redis.asyncio as redis + +class RedisMessagingAsync: + """ + PyHSS Redis Asynchronous Message Service + A class for sending and receiving redis messages asynchronously. + """ + + def __init__(self, host: str='localhost', port: int=6379): + self.redisClient = redis.Redis(host=host, port=port) + + async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + """ + Stores a message in a given Queue (Key), and sets an expiry (in seconds) if provided. + """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + sendMessageResult = await(redisPipe.rpush(queue, message).execute()) + if queueExpiry is not None: + expireKeyResult = await(redisPipe.expire(queue, queueExpiry).execute()) + return f'{message} stored in {queue} successfully.' + except Exception as e: + return '' + + async def getMessage(self, queue: str) -> str: + """ + Gets the oldest message from a given Queue (Key), while removing it from the key as well. 
Deletes the key if the last message is being removed. + """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + message = (await redisPipe.lpop(queue).execute())[0] + if message is None: + message = '' + else: + try: + message = message.decode() + except (UnicodeDecodeError, AttributeError): + pass + return message + except Exception as e: + print(e) + return '' + + async def getQueues(self, pattern: str='*') -> list: + """ + Returns all Queues (Keys) in the database. + """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + allQueues = await(redisPipe.keys(pattern).execute()) + return [x.decode() for x in allQueues[0]] + except Exception as e: + return [] + + async def getNextQueue(self, pattern: str='*') -> dict: + """ + Returns the next Queue (Key) in the list. + """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + return (await redisPipe.keys(pattern).execute())[0][0].decode() + except Exception as e: + return {} + + async def deleteQueue(self, queue: str) -> bool: + """ + Deletes the given Queue (Key) + """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + await(redisPipe.delete(queue).execute()) + return True + except Exception as e: + return False + + async def closeConnection(self) -> bool: + await self.redisClient.close() + return True + + +if __name__ == '__main__': + redisMessaging = RedisMessagingAsync() + print(asyncio.run(redisMessaging.getNextQueue())) \ No newline at end of file diff --git a/lib/milenage.py b/lib/milenage.py index 6175920..ef31daa 100644 --- a/lib/milenage.py +++ b/lib/milenage.py @@ -14,8 +14,6 @@ from lte import BaseLTEAuthAlgo import logging -import logtool -logtool = logtool.LogTool() import os import sys sys.path.append(os.path.realpath('../')) diff --git a/lib/old.logtool.py b/lib/old.logtool.py new file mode 100644 index 0000000..ac341a2 --- /dev/null +++ b/lib/old.logtool.py @@ -0,0 +1,243 @@ +import logging +import logging.handlers as handlers +import os +import sys +import inspect +sys.path.append(os.path.realpath('../')) +import yaml +from datetime import datetime as log_dt +with open("config.yaml", 'r') as stream: + yaml_config = (yaml.safe_load(stream)) + +import json +import pickle + +from prometheus_client import Counter, Gauge, Histogram, Summary + +from prometheus_client import start_http_server + +if yaml_config['prometheus']['enabled'] == True: + #Check if this is the HSS service, and if it's not increment the port before starting + print(sys.argv[0]) + if 'hss.py' in str(sys.argv[0]): + print("Starting Prometheus on port from config " + str(yaml_config['prometheus']['port'])) + else: + print("This is not the HSS stack so offsetting Prometheus port") + yaml_config['prometheus']['port'] += 1 + try: + start_http_server(yaml_config['prometheus']['port']) + print("Started Prometheus on port " + str(yaml_config['prometheus']['port'])) + except Exception as E: + print("Error loading Prometheus") + print(E) + + +tags = ['diameter_application_id', 'diameter_cmd_code', 'endpoint', 'type'] +prom_diam_request_count = Counter('prom_diam_request_count', 'Number of Diameter Requests', tags) +prom_diam_response_count_successful = Counter('prom_diam_response_count_successful', 'Number of Successful Diameter Responses', tags) +prom_diam_response_count_fail = Counter('prom_diam_response_count_fail', 'Number of Failed Diameter Responses', tags) +prom_diam_connected_peers = Gauge('prom_diam_connected_peers', 'Connected Diameter Peer Count', ['endpoint']) 
+prom_diam_connected_peers._metrics.clear() +prom_diam_response_time_diam = Histogram('prom_diam_response_time_diam', 'Diameter Response Times') +prom_diam_response_time_method = Histogram('prom_diam_response_time_method', 'Diameter Response Times', tags) +prom_diam_response_time_db = Summary('prom_diam_response_time_db', 'Diameter Response Times from Database') +prom_diam_response_time_h = Histogram('request_latency_seconds', 'Diameter Response Time Histogram') +prom_diam_auth_event_count = Counter('prom_diam_auth_event_count', 'Diameter Authentication related Counters', ['diameter_application_id', 'diameter_cmd_code', 'event', 'imsi_prefix']) +prom_diam_eir_event_count = Counter('prom_diam_eir_event_count', 'Diameter EIR event related Counters', ['response']) + +prom_eir_devices = Counter('prom_eir_devices', 'Profile of attached devices', ['imei_prefix', 'device_type', 'device_name']) + +prom_http_geored = Counter('prom_http_geored', 'Number of Geored Pushes', ['geored_host', 'endpoint', 'http_response_code', 'error']) +prom_flask_http_geored_endpoints = Counter('prom_flask_http_geored_endpoints', 'Number of Geored Pushes Received', ['geored_host', 'endpoint']) + + +prom_pcrf_subs = Gauge('prom_pcrf_subs', 'Number of attached PCRF Subscribers') +prom_mme_subs = Gauge('prom_mme_subs', 'Number of attached MME Subscribers') +prom_ims_subs = Gauge('prom_ims_subs', 'Number of attached IMS Subscribers') + +class LogTool: + def __init__(self, **kwargs): + print("Instantiating LogTool with Kwargs " + str(kwargs.items())) + if yaml_config['redis']['enabled'] == True: + print("Redis support enabled") + import redis + redis_store = redis.Redis(host=str(yaml_config['redis']['host']), port=str(yaml_config['redis']['port']), db=0) + self.redis_store = redis_store + try: + if "HSS_Init" in kwargs: + print("Called Init for HSS_Init") + redis_store.incr('restart_count') + if yaml_config['redis']['clear_stats_on_boot'] == True: + logging.debug("Clearing ActivePeerDict") + redis_store.delete('ActivePeerDict') + else: + logging.debug("Leaving prexisting Redis keys") + #Clear ActivePeerDict + redis_store.delete('ActivePeerDict') + + #Clear Async Keys + for key in redis_store.scan_iter("*_request_queue"): + print("Deleting Key: " + str(key)) + redis_store.delete(key) + logging.info("Connected to Redis server") + else: + logging.info("Init of Logtool but not from HSS_Init") + except: + logging.error("Failed to connect to Redis server - Disabling") + yaml_config['redis']['enabled'] == False + + #function for handling incrimenting Redis counters with error handling + def RedisIncrimenter(self, name): + if yaml_config['redis']['enabled'] == True: + try: + self.redis_store.incr(name) + except: + logging.error("failed to incriment " + str(name)) + + def RedisStore(self, key, value): + if yaml_config['redis']['enabled'] == True: + try: + self.redis_store.set(key, value) + except: + logging.error("failed to set Redis key " + str(key) + " to value " + str(value)) + + def RedisGet(self, key): + if yaml_config['redis']['enabled'] == True: + try: + return self.redis_store.get(key) + except: + logging.error("failed to set Redis key " + str(key)) + + def RedisHMSET(self, key, value_dict): + if yaml_config['redis']['enabled'] == True: + try: + self.redis_store.hmset(key, value_dict) + except: + logging.error("failed to set hm Redis key " + str(key) + " to value " + str(value_dict)) + + def Async_SendRequest(self, request, DiameterHostname): + if yaml_config['redis']['enabled'] == True: + try: + import time + print("Writing 
request to Queue '" + str(DiameterHostname) + "_request_queue'") + self.redis_store.hset(str(DiameterHostname) + "_request_queue", "hss_Async_client_" + str(int(time.time())), request) + print("Written to Queue to send.") + except Exception as E: + logging.error("failed to run Async_SendRequest to " + str(DiameterHostname)) + + def RedisHMGET(self, key): + if yaml_config['redis']['enabled'] == True: + try: + logging.debug("Getting HM Get from " + str(key)) + data = self.redis_store.hgetall(key) + logging.debug("Result: " + str(data)) + return data + except: + logging.error("failed to get hm Redis key " + str(key)) + + def RedisHDEL(self, key, item): + if yaml_config['redis']['enabled'] == True: + try: + logging.debug("Removing item " + str(item) + " from key " + str(key)) + self.redis_store.hdel(key, item) + except: + logging.error("failed to hdel Redis key " + str(key) + " item " + str(item)) + + def RedisStoreDict(self, key, value): + if yaml_config['redis']['enabled'] == True: + try: + self.redis_store.set(str(key), pickle.dumps(value)) + except: + logging.error("failed to set Redis dict " + str(key) + " to value " + str(value)) + + def RedisGetDict(self, key): + if yaml_config['redis']['enabled'] == True: + try: + read_dict = self.redis_store.get(key) + return pickle.loads(read_dict) + except: + logging.error("failed to hmget Redis key " + str(key)) + + def GetDiameterPeers(self): + if yaml_config['redis']['enabled'] == True: + try: + data = self.RedisGet('ActivePeerDict') + ActivePeerDict = json.loads(data) + return ActivePeerDict + except: + logging.error("Failed to get ActivePeerDict") + + + def Manage_Diameter_Peer(self, peername, ip, action): + if yaml_config['redis']['enabled'] == True: + try: + logging.debug("Managing Diameter peer to Redis with hostname" + str(peername) + " and IP " + str(ip)) + now = log_dt.now() + timestamp = str(now.strftime("%Y-%m-%d %H:%M:%S")) + + #Try and get IP and Port seperately + try: + ip = ip[0] + port = ip[1] + except: + pass + + if self.redis_store.exists('ActivePeerDict') == False: + #Initialise empty active peer dict in Redis + logging.debug("Populated new empty ActivePeerDict Redis key") + ActivePeerDict = {} + ActivePeerDict['internal_connection'] = {"connect_timestamp" : timestamp} + self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) + + if action == "add": + data = self.RedisGet('ActivePeerDict') + ActivePeerDict = json.loads(data) + logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict) + " to add peer " + str(peername) + " with ip " + str(ip)) + + + #If key has already existed in dict due to disconnect / reconnect, get reconnection count + try: + reconnection_count = ActivePeerDict[str(ip)]['reconnection_count'] + 1 + except: + reconnection_count = 0 + + ActivePeerDict[str(ip)] = {"connect_timestamp" : timestamp, \ + "recv_ip_address" : str(ip), "DiameterHostname" : "Unknown - Socket connection only", \ + "reconnection_count" : reconnection_count, + "connection_status" : "Pending"} + self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) + + if action == "remove": + data = self.RedisGet('ActivePeerDict') + ActivePeerDict = json.loads(data) + logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict)) + ActivePeerDict[str(ip)] = {"disconnect_timestamp" : str(timestamp), \ + "DiameterHostname" : str(ActivePeerDict[str(ip)]['DiameterHostname']), \ + "reconnection_count" : ActivePeerDict[str(ip)]['reconnection_count'], + "connection_status" : "Disconnected"} + self.RedisStore('ActivePeerDict', 
json.dumps(ActivePeerDict)) + + if action == "update": + data = self.RedisGet('ActivePeerDict') + ActivePeerDict = json.loads(data) + ActivePeerDict[str(ip)]['DiameterHostname'] = str(peername) + ActivePeerDict[str(ip)]['last_dwr_timestamp'] = str(timestamp) + ActivePeerDict[str(ip)]['connection_status'] = "Connected" + self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) + except Exception as E: + logging.error("failed to add/update/remove Diameter peer from Redis") + logging.error(E) + + + def setup_logger(self, logger_name, log_file, level=logging.DEBUG): + l = logging.getLogger(logger_name) + formatter = logging.Formatter('%(asctime)s \t %(levelname)s \t {%(pathname)s:%(lineno)d} \t %(message)s') + fileHandler = logging.FileHandler(log_file, mode='a+') + fileHandler.setFormatter(formatter) + streamHandler = logging.StreamHandler() + streamHandler.setFormatter(formatter) + rolloverHandler = handlers.RotatingFileHandler(log_file, maxBytes=50000000, backupCount=5) + l.setLevel(level) + l.addHandler(fileHandler) + l.addHandler(streamHandler) + l.addHandler(rolloverHandler) diff --git a/log/.gitkeep b/log/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/old.hss.py b/old.hss.py new file mode 100644 index 0000000..f32c53e --- /dev/null +++ b/old.hss.py @@ -0,0 +1,1012 @@ +# PyHSS +# This serves as a basic 3GPP Home Subscriber Server implimenting a EIR & IMS HSS functionality +import logging +import yaml +import os +import sys +import socket +import socketserver +import binascii +import time +import _thread +import threading +import sctp +import traceback +import pprint +import diameter as DiameterLib +import systemd.daemon +from threading import Thread, Lock +from logtool import * +import contextlib +import queue + + +class ThreadJoiner: + def __init__(self, threads, thread_event): + self.threads = threads + self.thread_event = thread_event + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is not None: + self.thread_event.set() + for thread in self.threads: + while thread.is_alive(): + try: + thread.join(timeout=1) + except Exception as e: + print( + f"ThreadJoiner Exception: failed to join thread {thread}: {e}" + ) + break + + +class PyHSS: + def __init__(self): + # Load config from yaml file + try: + with open("config.yaml", "r") as config_stream: + self.yaml_config = yaml.safe_load(config_stream) + except: + print(f"config.yaml not found, exiting PyHSS.") + quit() + + # Setup logging + self.logtool = LogTool(HSS_Init=True) + self.logtool.setup_logger( + "HSS_Logger", + self.yaml_config["logging"]["logfiles"]["hss_logging_file"], + level=self.yaml_config["logging"]["level"], + ) + self.logger = logging.getLogger("HSS_Logger") + if self.yaml_config["logging"]["log_to_terminal"]: + logging.getLogger().addHandler(logging.StreamHandler()) + + # Setup Diameter + self.diameter_instance = DiameterLib.Diameter( + str(self.yaml_config["hss"].get("OriginHost", "")), + str(self.yaml_config["hss"].get("OriginRealm", "")), + str(self.yaml_config["hss"].get("ProductName", "")), + str(self.yaml_config["hss"].get("MNC", "")), + str(self.yaml_config["hss"].get("MCC", "")), + ) + + self.max_diameter_retries = int( + self.yaml_config["hss"].get("diameter_max_retries", 1) + ) + + + + try: + assert(self.yaml_config['prometheus']['enabled'] == True) + assert(self.yaml_config['prometheus']['async_subscriber_count'] == True) + + self.logger.info("Enabling Prometheus Async Sub thread") + #Add Prometheus Async Calls + 
prom_async_thread = threading.Thread( + target=self.prom_async_function, + name=f"prom_async_function", + args=(), + ) + prom_async_thread.start() + except: + self.logger.info("Prometheus Async Sub Count thread disabled") + + + + def terminate_connection(self, clientsocket, client_address, thread_event): + thread_event.set() + clientsocket.close() + self.logtool.Manage_Diameter_Peer(client_address, client_address, "remove") + + def handle_new_connection(self, clientsocket, client_address): + # Create our threading event, accessible by sibling threads in this connection. + socket_close_event = threading.Event() + try: + send_queue = queue.Queue() + self.logger.debug(f"New connection from {client_address}") + if ( + "client_socket_timeout" not in self.yaml_config["hss"] + or self.yaml_config["hss"]["client_socket_timeout"] == 0 + ): + self.yaml_config["hss"]["client_socket_timeout"] = 120 + clientsocket.settimeout( + self.yaml_config["hss"].get("client_socket_timeout", 120) + ) + + send_data_thread = threading.Thread( + target=self.send_data, + name=f"send_data_thread", + args=(clientsocket, send_queue, socket_close_event), + ) + self.logger.debug("handle_new_connection: Starting send_data thread") + send_data_thread.start() + + self.logtool.Manage_Diameter_Peer(client_address, client_address, "add") + manage_client_thread = threading.Thread( + target=self.manage_client, + name=f"manage_client_thread: client_address: {client_address}", + args=( + clientsocket, + client_address, + self.diameter_instance, + socket_close_event, + send_queue, + ), + ) + self.logger.debug("handle_new_connection: Starting manage_client thread") + manage_client_thread.start() + + threads_to_join = [manage_client_thread] + threads_to_join.append(send_data_thread) + + # If Redis is enabled, start manage_client_async and manage_client_dwr threads. 
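+        # Each connection is served by a small group of sibling threads that
+        # share one send_queue (all socket writes are serialised through
+        # send_data) and one threading.Event used as the shutdown signal:
+        # manage_client reads and processes requests, and - when Redis is
+        # enabled - manage_client_async relays queued outbound requests while
+        # manage_client_dwr sends periodic watchdogs. Any thread that sets the
+        # event causes the ThreadJoiner below to reap the whole group. A
+        # minimal sketch of the pattern (readerFn/writerFn are hypothetical
+        # stand-ins, not functions defined in this file):
+        #
+        #     stopEvent = threading.Event()
+        #     sendQueue = queue.Queue()
+        #     workers = [threading.Thread(target=fn, args=(sendQueue, stopEvent))
+        #                for fn in (readerFn, writerFn)]
+        #     for worker in workers:
+        #         worker.start()
+        #     with ThreadJoiner(workers, stopEvent):
+        #         stopEvent.wait()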
+ if self.yaml_config["redis"]["enabled"]: + if ( + "async_check_interval" not in self.yaml_config["hss"] + or self.yaml_config["hss"]["async_check_interval"] == 0 + ): + self.yaml_config["hss"]["async_check_interval"] = 10 + manage_client_async_thread = threading.Thread( + target=self.manage_client_async, + name=f"manage_client_async_thread: client_address: {client_address}", + args=( + clientsocket, + client_address, + self.diameter_instance, + socket_close_event, + send_queue, + ), + ) + self.logger.debug( + "handle_new_connection: Starting manage_client_async thread" + ) + manage_client_async_thread.start() + + manage_client_dwr_thread = threading.Thread( + target=self.manage_client_dwr, + name=f"manage_client_dwr_thread: client_address: {client_address}", + args=( + clientsocket, + client_address, + self.diameter_instance, + socket_close_event, + send_queue, + ), + ) + self.logger.debug( + "handle_new_connection: Starting manage_client_dwr thread" + ) + manage_client_dwr_thread.start() + + threads_to_join.append(manage_client_async_thread) + threads_to_join.append(manage_client_dwr_thread) + + self.logger.debug( + f"handle_new_connection: Total PyHSS Active Threads: {threading.active_count()}" + ) + for thread in threading.enumerate(): + if "dummy" not in thread.name.lower(): + self.logger.debug(f"Active Thread name: {thread.name}") + + with ThreadJoiner(threads_to_join, socket_close_event): + socket_close_event.wait() + self.terminate_connection( + clientsocket, client_address, socket_close_event + ) + self.logger.debug(f"Closing thread for client; {client_address}") + return + + except Exception as e: + self.logger.error(f"Exception for client {client_address}: {e}") + self.logger.error(f"Closing connection for {client_address}") + self.terminate_connection(clientsocket, client_address, socket_close_event) + return + + @prom_diam_response_time_diam.time() + def process_Diameter_request( + self, clientsocket, client_address, diameter, data, thread_event, send_queue + ): + packet_length = diameter.decode_diameter_packet_length( + data + ) # Calculate length of packet from start of packet + if packet_length <= 32: + self.logger.error("Received an invalid packet with length <= 32") + self.terminate_connection(clientsocket, client_address, thread_event) + return + + data_sum = data + clientsocket.recv( + packet_length - 32 + ) # Recieve remainder of packet from buffer + packet_vars, avps = diameter.decode_diameter_packet( + data_sum + ) # Decode packet into array of AVPs and Dict of Packet Variables (packet_vars) + try: + packet_vars["Source_IP"] = client_address[0] + except: + self.logger.debug("Failed to add Source_IP to packet_vars") + + start_time = time.time() + origin_host = diameter.get_avp_data(avps, 264)[0] # Get OriginHost from AVP + origin_host = binascii.unhexlify(origin_host).decode("utf-8") # Format it + + # label_values = str(packet_vars['ApplicationId']), str(packet_vars['command_code']), origin_host, 'request' + prom_diam_request_count.labels( + str(packet_vars["ApplicationId"]), + str(packet_vars["command_code"]), + origin_host, + "request", + ).inc() + + + self.logger.info( + "\n\nNew request with Command Code: " + + str(packet_vars["command_code"]) + + ", ApplicationID: " + + str(packet_vars["ApplicationId"]) + + ", flags " + + str(packet_vars["flags"]) + + ", e2e ID: " + + str(packet_vars["end-to-end-identifier"]) + ) + + # Gobble up any Response traffic that is sent to us: + if packet_vars["flags_bin"][0:1] == "0": + self.logger.info("Got a Response, not a request - 
dropping it.") + self.logger.info(packet_vars) + return + + # Send Capabilities Exchange Answer (CEA) response to Capabilites Exchange Request (CER) + elif ( + packet_vars["command_code"] == 257 + and packet_vars["ApplicationId"] == 0 + and packet_vars["flags"] == "80" + ): + self.logger.info( + f"Received Request with command code 257 (CER) from {origin_host}" + + "\n\tSending response (CEA)" + ) + try: + response = diameter.Answer_257( + packet_vars, avps, str(self.yaml_config["hss"]["bind_ip"][0]) + ) # Generate Diameter packet + # prom_diam_response_count_successful.inc() + except: + response = diameter.Respond_ResultCode( + packet_vars, avps, 5012 + ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) + # prom_diam_response_count_fail.inc() + self.logger.info("Generated CEA") + self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") + prom_diam_connected_peers.labels(origin_host).set(1) + + # Send Credit Control Answer (CCA) response to Credit Control Request (CCR) + elif ( + packet_vars["command_code"] == 272 + and packet_vars["ApplicationId"] == 16777238 + ): + self.logger.info( + f"Received 3GPP Credit-Control-Request from {origin_host}" + + "\n\tGenerating (CCA)" + ) + try: + response = diameter.Answer_16777238_272( + packet_vars, avps + ) # Generate Diameter packet + except Exception as E: + response = diameter.Respond_ResultCode( + packet_vars, avps, 5012 + ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) + self.logger.error(f"Failed to generate response {str(E)}") + self.logger.info("Generated CCA") + + # Send Device Watchdog Answer (DWA) response to Device Watchdog Requests (DWR) + elif ( + packet_vars["command_code"] == 280 + and packet_vars["ApplicationId"] == 0 + and packet_vars["flags"] == "80" + ): + self.logger.info( + f"Received Request with command code 280 (DWR) from {origin_host}" + + "\n\tSending response (DWA)" + ) + self.logger.debug(f"Total PyHSS Active Threads: {threading.active_count()}") + try: + response = diameter.Answer_280( + packet_vars, avps + ) # Generate Diameter packet + except: + response = diameter.Respond_ResultCode( + packet_vars, avps, 5012 + ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) + self.logger.info("Generated DWA") + self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") + + # Send Disconnect Peer Answer (DPA) response to Disconnect Peer Request (DPR) + elif ( + packet_vars["command_code"] == 282 + and packet_vars["ApplicationId"] == 0 + and packet_vars["flags"] == "80" + ): + self.logger.info( + f"Received Request with command code 282 (DPR) from {origin_host}" + + "\n\tForwarding request..." 
+            )
+            response = diameter.Answer_282(
+                packet_vars, avps
+            ) # Generate Diameter packet
+            self.logger.info("Generated DPA")
+            self.logtool.Manage_Diameter_Peer(origin_host, client_address, "remove")
+            prom_diam_connected_peers.labels(origin_host).set(0)
+
+        # S6a Authentication Information Answer (AIA) response to Authentication Information Request (AIR)
+        elif (
+            packet_vars["command_code"] == 318
+            and packet_vars["ApplicationId"] == 16777251
+            and packet_vars["flags"] == "c0"
+        ):
+            self.logger.info(
+                f"Received Request with command code 318 (3GPP Authentication-Information-Request) from {origin_host}"
+                + "\n\tGenerating (AIA)"
+            )
+            try:
+                response = diameter.Answer_16777251_318(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+                self.logger.info("Generated AIA")
+            except Exception as e:
+                self.logger.info("Failed to generate Diameter Response for AIR")
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+                self.logger.info("Generated DIAMETER_USER_DATA_NOT_AVAILABLE AIA")
+
+        # S6a Update Location Answer (ULA) response to Update Location Request (ULR)
+        elif (
+            packet_vars["command_code"] == 316
+            and packet_vars["ApplicationId"] == 16777251
+        ):
+            self.logger.info(
+                f"Received Request with command code 316 (3GPP Update Location-Request) from {origin_host}"
+                + "\n\tGenerating (ULA)"
+            )
+            try:
+                response = diameter.Answer_16777251_316(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+                self.logger.info("Generated ULA")
+            except Exception as e:
+                self.logger.info("Failed to generate Diameter Response for ULR")
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+                self.logger.info("Generated error DIAMETER_USER_DATA_NOT_AVAILABLE ULA")
+
+            # Send ULA data & clear tx buffer
+            clientsocket.sendall(bytes.fromhex(response))
+            response = ""
+            if "Insert_Subscriber_Data_Force" in self.yaml_config["hss"]:
+                if self.yaml_config["hss"]["Insert_Subscriber_Data_Force"] == True:
+                    self.logger.debug("ISD triggered after ULA")
+                    # Generate Insert Subscriber Data Request
+                    response = diameter.Request_16777251_319(
+                        packet_vars, avps
+                    ) # Generate Diameter packet
+                    self.logger.info("Generated IDR")
+                    # Send ISD data
+                    send_queue.put(bytes.fromhex(response))
+                    # clientsocket.sendall(bytes.fromhex(response))
+                    self.logger.info("Sent IDR")
+            return
+        # S6a inbound Insert-Subscriber-Data-Answer in response to our IDR
+        elif (
+            packet_vars["command_code"] == 319
+            and packet_vars["ApplicationId"] == 16777251
+        ):
+            self.logger.info(
+                f"Received response with command code 319 (3GPP Insert-Subscriber-Data-Answer) from {origin_host}"
+            )
+            return
+        # S6a Purge UE Answer (PUA) response to Purge UE Request (PUR)
+        elif (
+            packet_vars["command_code"] == 321
+            and packet_vars["ApplicationId"] == 16777251
+        ):
+            self.logger.info(
+                f"Received Request with command code 321 (3GPP Purge UE Request) from {origin_host}"
+                + "\n\tGenerating (PUA)"
+            )
+            try:
+                response = diameter.Answer_16777251_321(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except:
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 5012
+                ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012)
+                self.logger.error("Failed to generate PUA")
+            self.logger.info("Generated PUA")
+        # S6a Notify Answer (NOA) response to Notify Request (NOR)
+        elif (
+            packet_vars["command_code"] == 323
+            and packet_vars["ApplicationId"] == 16777251
+        ):
+            self.logger.info(
+                f"Received Request with command code 323 (3GPP Notify Request) from {origin_host}"
+                + "\n\tGenerating (NOA)"
+            )
+            try:
+                response = diameter.Answer_16777251_323(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except:
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 5012
+                ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012)
+                self.logger.error("Failed to generate NOA")
+            self.logger.info("Generated NOA")
+        # S6a Cancel Location Answer eater
+        elif (
+            packet_vars["command_code"] == 317
+            and packet_vars["ApplicationId"] == 16777251
+        ):
+            self.logger.info("Received Response with command code 317 (3GPP Cancel Location Request) from " + str(origin_host))
+
+        # Cx User-Authorization Answer
+        elif (
+            packet_vars["command_code"] == 300
+            and packet_vars["ApplicationId"] == 16777216
+        ):
+            self.logger.info(
+                f"Received Request with command code 300 (3GPP Cx User-Authorization-Request) from {origin_host}"
+                + "\n\tGenerating (UAA)"
+            )
+            try:
+                response = diameter.Answer_16777216_300(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for Cx User-Authorization Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                self.logger.info(
+                    f"{type(e).__name__} in {__file__} at line {e.__traceback__.tb_lineno}"
+                )
+
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+            self.logger.info("Generated Cx User-Authorization Answer")
+
+        # Cx Server Assignment Answer
+        elif (
+            packet_vars["command_code"] == 301
+            and packet_vars["ApplicationId"] == 16777216
+        ):
+            self.logger.info(
+                f"Received Request with command code 301 (3GPP Cx Server-Assignment-Request) from {origin_host}"
+                + "\n\tGenerating (SAA)"
+            )
+            try:
+                response = diameter.Answer_16777216_301(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for Cx Server Assignment Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+            self.logger.info("Generated Cx Server Assignment Answer")
+
+        # Cx Location Information Answer
+        elif (
+            packet_vars["command_code"] == 302
+            and packet_vars["ApplicationId"] == 16777216
+        ):
+            self.logger.info(
+                f"Received Request with command code 302 (3GPP Cx Location Information Request) from {origin_host}"
+                + "\n\tGenerating (LIA)"
+            )
+            try:
+                response = diameter.Answer_16777216_302(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for Cx Location Information Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+            self.logger.info("Generated Cx Location Information Answer")
+
+        # Cx Multimedia Authentication Answer
+        elif (
+            packet_vars["command_code"] == 303
+            and packet_vars["ApplicationId"] == 16777216
+        ):
+            self.logger.info(
+                f"Received Request with command code 303 (3GPP Cx Multimedia Authentication Request) from {origin_host}"
+                + "\n\tGenerating (MAA)"
+            )
+            try:
+                response = diameter.Answer_16777216_303(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for Cx Multimedia Authentication Answer"
+                )
+                self.logger.info(e)
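+                # Dump the full stack trace for debugging, then fall back to
+                # answering with result code 4100
+                # (DIAMETER_USER_DATA_NOT_AVAILABLE) so the peer is not left
+                # without a response.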
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+            self.logger.info("Generated Cx Multimedia Authentication Answer")
+
+        # Sh User-Data-Answer
+        elif (
+            packet_vars["command_code"] == 306
+            and packet_vars["ApplicationId"] == 16777217
+        ):
+            self.logger.info(
+                f"Received Request with command code 306 (3GPP Sh User-Data Request) from {origin_host}"
+            )
+            try:
+                response = diameter.Answer_16777217_306(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for Sh User-Data Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 5001
+                ) # DIAMETER_ERROR_USER_UNKNOWN
+                send_queue.put(bytes.fromhex(response))
+                # clientsocket.sendall(bytes.fromhex(response))
+                self.logger.info("Sent negative response")
+                return
+            self.logger.info("Generated Sh User-Data Answer")
+
+        # Sh Profile-Update-Answer
+        elif (
+            packet_vars["command_code"] == 307
+            and packet_vars["ApplicationId"] == 16777217
+        ):
+            self.logger.info(
+                f"Received Request with command code 307 (3GPP Sh Profile-Update Request) from {origin_host}"
+            )
+            try:
+                response = diameter.Answer_16777217_307(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for Sh Profile-Update Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 5001
+                ) # DIAMETER_ERROR_USER_UNKNOWN
+                send_queue.put(bytes.fromhex(response))
+                # clientsocket.sendall(bytes.fromhex(response))
+                self.logger.info("Sent negative response")
+                return
+            self.logger.info("Generated Sh Profile-Update Answer")
+
+        # S13 ME-Identity-Check Answer
+        elif (
+            packet_vars["command_code"] == 324
+            and packet_vars["ApplicationId"] == 16777252
+        ):
+            self.logger.info(
+                f"Received Request with command code 324 (3GPP S13 ME-Identity-Check Request) from {origin_host}"
+                + "\n\tGenerating (MICA)"
+            )
+            try:
+                response = diameter.Answer_16777252_324(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for S13 ME-Identity Check Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+            self.logger.info("Generated S13 ME-Identity Check Answer")
+
+        # SLh LCS-Routing-Info-Answer
+        elif (
+            packet_vars["command_code"] == 8388622
+            and packet_vars["ApplicationId"] == 16777291
+        ):
+            self.logger.info(
+                f"Received Request with command code 8388622 (3GPP SLh LCS-Routing-Info Request) from {origin_host}"
+                + "\n\tGenerating (LCS-Routing-Info-Answer)"
+            )
+            try:
+                response = diameter.Answer_16777291_8388622(
+                    packet_vars, avps
+                ) # Generate Diameter packet
+            except Exception as e:
+                self.logger.info(
+                    "Failed to generate Diameter Response for SLh LCS-Routing-Info-Answer"
+                )
+                self.logger.info(e)
+                traceback.print_exc()
+                response = diameter.Respond_ResultCode(
+                    packet_vars, avps, 4100
+                ) # DIAMETER_USER_DATA_NOT_AVAILABLE
+            self.logger.info("Generated SLh LCS-Routing-Info-Answer")
+
+        # Handle Responses generated by the Async functions
+        elif packet_vars["flags"] == "00":
+            self.logger.info(
+                "Got response back with command code "
+                + str(packet_vars["command_code"])
+            )
+            self.logger.info("response packet_vars: " + str(packet_vars))
+            
self.logger.info("response avps: " + str(avps)) + response = "" + else: + self.logger.error( + "\n\nRecieved unrecognised request with Command Code: " + + str(packet_vars["command_code"]) + + ", ApplicationID: " + + str(packet_vars["ApplicationId"]) + + " and flags " + + str(packet_vars["flags"]) + ) + for keys in packet_vars: + self.logger.error(keys) + self.logger.error("\t" + str(packet_vars[keys])) + self.logger.error(avps) + self.logger.error("Sending negative response") + response = diameter.Respond_ResultCode( + packet_vars, avps, 3001 + ) # Generate Diameter response with "Command Unsupported" (3001) + send_queue.put(bytes.fromhex(response)) + # clientsocket.sendall(bytes.fromhex(response)) # Send it + + prom_diam_response_time_method.labels( + str(packet_vars["ApplicationId"]), + str(packet_vars["command_code"]), + origin_host, + "request", + ).observe(time.time() - start_time) + + # Diameter Transmission + retries = 0 + while retries < self.max_diameter_retries: + try: + send_queue.put(bytes.fromhex(response)) + break + except socket.error as e: + self.logger.error(f"Socket error for client {client_address}: {e}") + retries += 1 + if retries > self.max_diameter_retries: + self.logger.error( + f"Max retries reached for client {client_address}. Closing connection." + ) + self.terminate_connection( + clientsocket, client_address, thread_event + ) + break + time.sleep(1) # Wait for 1 second before retrying + except Exception as e: + self.logger.info("Failed to send Diameter Response") + self.logger.debug(f"Diameter Response Body: {str(response)}") + self.logger.info(e) + traceback.print_exc() + self.terminate_connection(clientsocket, client_address, thread_event) + self.logger.info("Thread terminated to " + str(client_address)) + break + + def manage_client( + self, clientsocket, client_address, diameter, thread_event, send_queue + ): + while True: + try: + data = clientsocket.recv(32) + if not data: + self.logger.info( + f"manage_client: Connection closed by {str(client_address)}" + ) + self.terminate_connection( + clientsocket, client_address, thread_event + ) + return + self.process_Diameter_request( + clientsocket, + client_address, + diameter, + data, + thread_event, + send_queue, + ) + + except socket.timeout: + self.logger.warning( + f"manage_client: Socket timeout for client: {client_address}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + + except socket.error as e: + self.logger.error( + f"manage_client: Socket error for client {client_address}: {e}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + + except KeyboardInterrupt: + # Clean up the connection on keyboard interrupt + response = ( + diameter.Request_282() + ) # Generate Disconnect Peer Request Diameter packet + send_queue.put(bytes.fromhex(response)) + # clientsocket.sendall(bytes.fromhex(response)) # Send it + self.terminate_connection(clientsocket, client_address, thread_event) + self.logger.info( + "manage_client: Connection closed nicely due to keyboard interrupt" + ) + sys.exit() + + except Exception as manage_client_exception: + self.logger.error( + f"manage_client: Exception in manage_client: {manage_client_exception}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + + def manage_client_async( + self, clientsocket, client_address, diameter, thread_event, send_queue + ): + # # Sleep for 10 seconds to wait for the connection to come up + time.sleep(10) + self.logger.debug("manage_client_async: 
Getting ActivePeerDict") + self.logger.debug( + f"manage_client_async: Total PyHSS Active Threads: {threading.active_count()}" + ) + ActivePeerDict = self.logtool.GetDiameterPeers() + self.logger.debug( + f"manage_client_async: Got Active Peer dict in Async Thread: {str(ActivePeerDict)}" + ) + if client_address[0] in ActivePeerDict: + self.logger.debug( + "manage_client_async: This is host: " + + str(ActivePeerDict[str(client_address[0])]["DiameterHostname"]) + ) + DiameterHostname = str( + ActivePeerDict[str(client_address[0])]["DiameterHostname"] + ) + else: + self.logger.debug("manage_client_async: No matching Diameter Host found.") + return + + while True: + try: + if thread_event.is_set(): + self.logger.debug( + f"manage_client_async: Closing manage_client_async thread for client: {client_address}" + ) + self.terminate_connection( + clientsocket, client_address, thread_event + ) + return + time.sleep(self.yaml_config["hss"]["async_check_interval"]) + self.logger.debug( + f"manage_client_async: Sleep interval expired for Diameter Peer {str(DiameterHostname)}" + ) + if int(self.yaml_config["hss"]["async_check_interval"]) == 0: + self.logger.error( + f"manage_client_async: No async_check_interval Timer set - Not checking Async Queue for host connection {str(DiameterHostname)}" + ) + return + try: + self.logger.debug( + "manage_client_async: Reading from request queue '" + + str(DiameterHostname) + + "_request_queue'" + ) + data_to_send = self.logtool.RedisHMGET( + str(DiameterHostname) + "_request_queue" + ) + for key in data_to_send: + data = data_to_send[key].decode("utf-8") + send_queue.put(bytes.fromhex(data)) + self.logtool.RedisHDEL( + str(DiameterHostname) + "_request_queue", key + ) + except Exception as redis_exception: + self.logger.error( + f"manage_client_async: Redis exception in manage_client_async: {redis_exception}" + ) + self.terminate_connection( + clientsocket, client_address, thread_event + ) + return + + except socket.timeout: + self.logger.warning( + f"manage_client_async: Socket timeout for client: {client_address}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + + except socket.error as e: + self.logger.error( + f"manage_client_async: Socket error for client {client_address}: {e}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + except Exception: + self.logger.error( + f"manage_client_async: Terminating for host connection {str(DiameterHostname)}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + + def manage_client_dwr( + self, clientsocket, client_address, diameter, thread_event, send_queue + ): + while True: + try: + if thread_event.is_set(): + self.logger.debug( + f"Closing manage_client_dwr thread for client: {client_address}" + ) + self.terminate_connection( + clientsocket, client_address, thread_event + ) + return + if ( + int(self.yaml_config["hss"]["device_watchdog_request_interval"]) + != 0 + ): + time.sleep( + self.yaml_config["hss"]["device_watchdog_request_interval"] + ) + else: + self.logger.info("DWR Timer to set to 0 - Not sending DWRs") + return + + except: + self.logger.error( + "No DWR Timer set - Not sending Device Watchdog Requests" + ) + return + try: + self.logger.debug("Sending Keepalive to " + str(client_address) + "...") + request = diameter.Request_280() + send_queue.put(bytes.fromhex(request)) + # clientsocket.sendall(bytes.fromhex(request)) # Send it + self.logger.debug("Sent Keepalive to " + str(client_address) + "...") + 
except socket.error as e: + self.logger.error( + f"manage_client_dwr: Socket error for client {client_address}: {e}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + return + except Exception as e: + self.logger.error( + f"manage_client_dwr: General exception for client {client_address}: {e}" + ) + self.terminate_connection(clientsocket, client_address, thread_event) + + def get_socket_family(self): + if ":" in self.yaml_config["hss"]["bind_ip"][0]: + self.logger.info("IPv6 Address Specified") + return socket.AF_INET6 + else: + self.logger.info("IPv4 Address Specified") + return socket.AF_INET + + def send_data(self, clientsocket, send_queue, thread_event): + while not thread_event.is_set(): + try: + data = send_queue.get(timeout=1) + # Check if data is bytes, otherwise convert it using bytes.fromhex() + if not isinstance(data, bytes): + data = bytes.fromhex(data) + + clientsocket.sendall(data) + except ( + queue.Empty + ): # Catch the Empty exception when the queue is empty and the timeout has expired + continue + except Exception as e: + self.logger.error(f"send_data_thread: Exception: {e}") + return + + def start_server(self): + if self.yaml_config["hss"]["transport"] == "SCTP": + self.logger.debug("Using SCTP for Transport") + # Create a SCTP socket + sock = sctp.sctpsocket_tcp(self.get_socket_family()) + sock.initparams.num_ostreams = 64 + # Loop through the possible Binding IPs from the config and bind to each for Multihoming + server_addresses = [] + + # Prepend each entry into list, so the primary IP is bound first + for host in self.yaml_config["hss"]["bind_ip"]: + self.logger.info("Seting up SCTP binding on IP address " + str(host)) + this_IP_binding = [ + (str(host), int(self.yaml_config["hss"]["bind_port"])) + ] + server_addresses = this_IP_binding + server_addresses + + print("server_addresses are: " + str(server_addresses)) + sock.bindx(server_addresses) + self.logger.info("PyHSS listening on SCTP port " + str(server_addresses)) + systemd.daemon.notify("READY=1") + # Listen for up to 20 incoming SCTP connections + sock.listen(20) + elif self.yaml_config["hss"]["transport"] == "TCP": + self.logger.debug("Using TCP socket") + # Create a TCP/IP socket + sock = socket.socket(self.get_socket_family(), socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # Bind the socket to the port + server_address = ( + str(self.yaml_config["hss"]["bind_ip"][0]), + int(self.yaml_config["hss"]["bind_port"]), + ) + sock.bind(server_address) + self.logger.debug( + "PyHSS listening on TCP port " + + str(self.yaml_config["hss"]["bind_ip"][0]) + ) + systemd.daemon.notify("READY=1") + # Listen for up to 20 incoming TCP connections + sock.listen(20) + else: + self.logger.error("No valid transports found (No SCTP or TCP) - Exiting") + quit() + + while True: + # Wait for a connection + self.logger.info("Waiting for a connection...") + connection, client_address = sock.accept() + _thread.start_new_thread( + self.handle_new_connection, + ( + connection, + client_address, + ), + ) + + + def prom_async_function(self): + while True: + self.logger.debug("Running prom_async_function") + self.diameter_instance.Generate_Prom_Stats() + time.sleep(120) + + +if __name__ == "__main__": + pyHss = PyHSS() + pyHss.start_server() diff --git a/services/diameterService.py b/services/diameterService.py index 32db996..d9c45e4 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -1,71 +1,103 @@ import asyncio import sctp, socket -import sys, 
os, binascii -import time +import sys, os, json +import time, yaml sys.path.append(os.path.realpath('../lib')) -from messaging import RedisMessaging +from messagingAsync import RedisMessagingAsync from diameter import Diameter +from banners import Banners +from logtool import LogTool - -class DiameterService(): +class DiameterService: """ PyHSS Diameter Service A class for handling diameter requests and replies on Port 3868, via TCP or SCTP. """ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): - self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[Diameter] Fatal Error - config.yaml not found, exiting.") + quit() + + self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) self.diameterLibrary = Diameter() - pass + self.banners = Banners() + self.logTool = LogTool() + self.diameterLogger = self.logTool.setupLogger(loggerName='Diameter', config=self.config) + self.socketTimeout = int(self.config.get('hss', {}).get('client_socket_timeout', 300)) def validateDiameterRequest(self, requestData) -> bool: try: packetVars, avps = self.diameterLibrary.decode_diameter_packet(requestData) originHost = self.diameterLibrary.get_avp_data(avps, 264)[0] - originHost = binascii.unhexlify(originHost).decode("utf-8") + originHost = bytes.fromhex(originHost).decode("utf-8") except Exception as e: return False return True - async def readRequestData(self, reader, clientAddress: str, clientPort: str) -> bool: - requestQueueName = f"{clientAddress}-{clientPort}-requests" - print("In readRequestData") + async def readRequestData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int) -> bool: + self.diameterLogger.info(f"[Diameter] New connection from {clientAddress} on port {clientPort}") while True: - requestData = await reader.read(1024) - if len(requestData) > 0: - print(f"Received data from {clientAddress} on port {clientPort}") - print(f"Data: {binascii.hexlify(requestData)}") - - if not self.validateDiameterRequest(requestData): - print(f"Invalid Diameter Request.") - break + try: + requestData = await asyncio.wait_for(reader.read(1024), timeout=socketTimeout) + if len(requestData) > 0: + self.diameterLogger.debug(f"[Diameter] Received data from {clientAddress} on port {clientPort}") + + if not self.validateDiameterRequest(requestData): + self.diameterLogger.debug(f"[Diameter] Invalid Diameter Request, terminating connection.") + return False - requestHexString = binascii.hexlify(requestData) - print(requestHexString) - self.redisMessaging.sendMessage(queue=requestQueueName, message=requestHexString) + requestQueueName = f"diameter-request-{clientAddress}-{clientPort}-{time.time_ns()}" + requestHexString = json.dumps({f"diameter-request": requestData.hex()}) + self.diameterLogger.debug(f"[Diameter] Queueing {requestHexString}") + await(self.redisMessaging.sendMessage(queue=requestQueueName, message=requestHexString)) + except asyncio.TimeoutError: + self.diameterLogger.info(f"[Diameter] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.") + return False async def writeResponseData(self, writer, clientAddress: str, clientPort: str) -> bool: - responseQueueName = f"{clientAddress}-{clientPort}-responses" - print("In writeResponseData") - + self.diameterLogger.debug(f"[Diameter] writeResponseData with host {clientAddress} on port {clientPort}") while True: - responseHexString = 
self.redisMessaging.getMessage(queue=responseQueueName) - if not len(responseHexString) > 0: + try: + pendingResponseQueues = await(self.redisMessaging.getQueues()) + if not len(pendingResponseQueues) > 0: + assert() + for responseQueue in pendingResponseQueues: + queuedMessageType = str(responseQueue).split('-')[1] + diameterResponseHost = str(responseQueue).split('-')[2] + diameterResponsePort = str(responseQueue).split('-')[3] + if str(diameterResponseHost) == str(clientAddress) and str(diameterResponsePort) == str(clientPort) and queuedMessageType == 'response': + self.diameterLogger.debug(f"[Diameter] Matched {responseQueue} to host {clientAddress} on port {clientPort}") + try: + diameterResponse = json.loads(await(self.redisMessaging.getMessage(queue=responseQueue))) + self.diameterLogger.debug(f"[Diameter] Attempting to send outbound response to {clientAddress} on {clientPort}.") + diameterResponseBinary = bytes.fromhex(next(iter(diameterResponse.values()))) + self.diameterLogger.debug(f"[Diameter] Sending: {diameterResponseBinary.hex()} to to {clientAddress} on {clientPort}.") + writer.write(diameterResponseBinary) + await writer.drain() + except Exception as e: + print(e) + except ConnectionError: + self.diameterLogger.info(f"[Diameter] Connection closed for {clientAddress} on port {clientPort}, closing writer.") + return False + except Exception as e: await asyncio.sleep(0.005) continue - diameterResponse = f'Received diameter request successfully.' - print(f"Sending: {diameterResponse}") - writer.write(diameterResponse) - await writer.drain() - async def handleConnection(self, reader, writer): (clientAddress, clientPort) = writer.get_extra_info('peername') - if not await asyncio.gather(self.readRequestData(reader=reader, clientAddress=clientAddress, clientPort=clientPort), - self.writeResponseData(writer=writer, clientAddress=clientAddress, clientPort=clientPort)): - print("Closing Connection") + self.diameterLogger.debug(f"[Diameter] Initial Connection from: {clientAddress} on port {clientPort}") + + if False in await asyncio.gather(self.readRequestData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout), + self.writeResponseData(writer=writer, clientAddress=clientAddress, clientPort=clientPort)): + self.diameterLogger.debug(f"[Diameter] Closing Writer for {clientAddress} on port {clientPort}.") writer.close() + await writer.wait_closed() + self.diameterLogger.debug(f"[Diameter] Closed Writer for {clientAddress} on port {clientPort}.") return async def startServer(self, host: str='0.0.0.0', port: int=3868, type: str='TCP'): @@ -77,7 +109,8 @@ async def startServer(self, host: str='0.0.0.0', port: int=3868, type: str='TCP' else: return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) - print(f'Serving on {servingAddresses}') + self.diameterLogger.info(self.banners.diameterService()) + self.diameterLogger.info(f'[Diameter] Serving on {servingAddresses}') async with server: await server.serve_forever() diff --git a/services/hssService.py b/services/hssService.py new file mode 100644 index 0000000..7fdcfc9 --- /dev/null +++ b/services/hssService.py @@ -0,0 +1,71 @@ +import os, sys, json, yaml +import time, logging +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from diameter import Diameter +from banners import Banners +from logtool import LogTool + +class HssService: + + def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + + try: + with 
open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[HSS] Fatal Error - config.yaml not found, exiting.") + quit() + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.logTool = LogTool() + self.banners = Banners() + self.hssLogger = self.logTool.setupLogger(loggerName='HSS', config=self.config) + self.mnc = self.config.get('hss', {}).get('MNC', '999') + self.mcc = self.config.get('hss', {}).get('MCC', '999') + self.originRealm = self.config.get('hss', {}).get('OriginRealm', f'mnc{self.mnc}.mcc{self.mcc}.3gppnetwork.org') + self.originHost = self.config.get('hss', {}).get('OriginHost', f'hss01') + self.productName = self.config.get('hss', {}).get('ProductName', f'PyHSS') + self.diameterLibrary = Diameter(originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) + self.hssLogger.info(self.banners.hssService()) + + + + def handleOutboundResponse(self, queue: str, diameterResponse: str): + self.redisMessaging.sendMessage(queue=queue, message=diameterResponse, queueExpiry=60) + + def handleRequestQueue(self): + try: + requestQueue = self.redisMessaging.getNextQueue(pattern='diameter-request*') + requestMessage = self.redisMessaging.getMessage(queue=requestQueue) + assert(len(requestMessage)) + self.hssLogger.debug(f"[HSS] Inbound Diameter Request Queue: {requestQueue}") + self.hssLogger.debug(f"[HSS] Inbound Diameter Request: {requestMessage}") + + requestDict = json.loads(requestMessage) + requestBinary = bytes.fromhex(next(iter(requestDict.values()))) + requestHost = str(requestQueue).split('-')[2] + requestPort = str(requestQueue).split('-')[3] + requestTimestamp = str(requestQueue).split('-')[4] + + diameterResponse = self.diameterLibrary.generateDiameterResponse(requestBinaryData=requestBinary) + self.hssLogger.debug(f"[HSS] Generated Diameter Response: {diameterResponse}") + if not len(diameterResponse) > 0: + return False + + outboundResponseQueue = f"diameter-response-{requestHost}-{requestPort}-{requestTimestamp}" + outboundResponse = json.dumps({"diameter-response": diameterResponse}) + + self.hssLogger.debug(f"[HSS] Outbound Diameter Response Queue: {outboundResponseQueue}") + self.hssLogger.debug(f"[HSS] Outbound Diameter Response: {outboundResponse}") + + self.handleOutboundResponse(queue=outboundResponseQueue, diameterResponse=outboundResponse) + time.sleep(0.005) + + except Exception as e: + return False + + +if __name__ == '__main__': + hssService = HssService() + while True: + hssService.handleRequestQueue() \ No newline at end of file From bfbbadcaa5d9b4eac36b3a80d2c8aed7509bf6f5 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 23 Aug 2023 15:26:35 +1000 Subject: [PATCH 03/43] Performance tuned diameterService.py, add GeoredService --- lib/database.py | 638 ++++++++++++++-------------------- lib/logtool.py | 2 +- lib/messaging.py | 28 ++ lib/messagingAsync.py | 51 ++- services/diameterService.py | 64 ++-- services/georedService.py | 139 ++++++++ services/hssService.py | 71 ++-- services/metricService.py | 100 ++++++ services/prometheusService.py | 0 9 files changed, 650 insertions(+), 443 deletions(-) create mode 100644 services/metricService.py delete mode 100644 services/prometheusService.py diff --git a/lib/database.py b/lib/database.py index e27583d..aa400d6 100755 --- a/lib/database.py +++ b/lib/database.py @@ -16,27 +16,24 @@ import socket import traceback from contextlib import contextmanager -import logging import 
pprint from construct import Default import S6a_crypt import requests -from requests.exceptions import ConnectionError, Timeout -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry import threading +from logtool import LogTool +from messaging import RedisMessaging import yaml with open("../config.yaml", 'r') as stream: yaml_config = (yaml.safe_load(stream)) -# logtool = logtool.LogTool() -# logtool.setup_logger('DBLogger', yaml_config['logging']['logfiles']['database_logging_file'], level=yaml_config['logging']['level']) -DBLogger = logging.getLogger('DBLogger') -DBLogger.info("DB Log Initialised.") +logTool = LogTool() +dbLogger = logTool.setupLogger(loggerName='Database', config=yaml_config) +dbLogger.info("DB Log Initialised.") +redisMessaging = RedisMessaging() db_string = 'mysql://' + str(yaml_config['database']['username']) + ':' + str(yaml_config['database']['password']) + '@' + str(yaml_config['database']['server']) + '/' + str(yaml_config['database']['database'] + "?autocommit=true") -# print(db_string) engine = create_engine( db_string, echo = yaml_config['logging'].get('sqlalchemy_sql_echo', True), @@ -282,54 +279,51 @@ class SUBSCRIBER_ATTRIBUTES(Base): # Create database if it does not exist. if not database_exists(engine.url): - DBLogger.debug("Creating database") + dbLogger.debug("Creating database") create_database(engine.url) Base.metadata.create_all(engine) else: - DBLogger.debug("Database already created") + dbLogger.debug("Database already created") def load_IMEI_database_into_Redis(): - return - #@@Fixme - # try: - # DBLogger.info("Reading IMEI TAC database CSV from " + str(yaml_config['eir']['tac_database_csv'])) - # csvfile = open(str(yaml_config['eir']['tac_database_csv'])) - # DBLogger.info("This may take a few seconds to buffer into Redis...") - # except: - # DBLogger.error("Failed to read CSV file of IMEI TAC database") - # return - # try: - # count = 0 - # for line in csvfile: - # line = line.replace('"', '') #Strip excess invered commas - # line = line.replace("'", '') #Strip excess invered commas - # line = line.rstrip() #Strip newlines - # result = line.split(',') - # tac_prefix = result[0] - # name = result[1].lstrip() - # model = result[2].lstrip() - # if count == 0: - # DBLogger.info("Checking to see if entries are already present...") - # #DBLogger.info("Searching Redis for key " + str(tac_prefix) + " to see if data already provisioned") - # redis_imei_result = logtool.RedisHMGET(key=str(tac_prefix)) - # if len(redis_imei_result) != 0: - # DBLogger.info("IMEI TAC Database already loaded into Redis - Skipping reading from file...") - # break - # else: - # DBLogger.info("No data loaded into Redis, proceeding to load...") - # imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} - # logtool.RedisHMSET(key=str(tac_prefix), value_dict=imei_result) - # count = count +1 - # DBLogger.info("Loaded " + str(count) + " IMEI TAC entries into Redis") - # except Exception as E: - # DBLogger.error("Failed to load IMEI Database into Redis due to error: " + (str(E))) - # return + try: + dbLogger.info("Reading IMEI TAC database CSV from " + str(yaml_config['eir']['tac_database_csv'])) + csvfile = open(str(yaml_config['eir']['tac_database_csv'])) + dbLogger.info("This may take a few seconds to buffer into Redis...") + except: + dbLogger.error("Failed to read CSV file of IMEI TAC database") + return + try: + count = 0 + for line in csvfile: + line = line.replace('"', '') #Strip excess invered commas + line = line.replace("'", '') 
#Strip excess invered commas + line = line.rstrip() #Strip newlines + result = line.split(',') + tac_prefix = result[0] + name = result[1].lstrip() + model = result[2].lstrip() + if count == 0: + dbLogger.info("Checking to see if entries are already present...") + redis_imei_result = redisMessaging.getMessage(key=str(tac_prefix)) + if len(redis_imei_result) != 0: + dbLogger.info("IMEI TAC Database already loaded into Redis - Skipping reading from file...") + break + else: + dbLogger.info("No data loaded into Redis, proceeding to load...") + imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} + redisMessaging.sendMessage(key=str(tac_prefix), value_dict=imei_result) + count = count +1 + dbLogger.info("Loaded " + str(count) + " IMEI TAC entries into Redis") + except Exception as E: + dbLogger.error("Failed to load IMEI Database into Redis due to error: " + (str(E))) + return #Load IMEI TAC database into Redis if enabled if ('tac_database_csv' in yaml_config['eir']) and (yaml_config['redis']['enabled'] == True): load_IMEI_database_into_Redis() else: - DBLogger.info("Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config") + dbLogger.info("Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config") def safe_rollback(session): @@ -337,14 +331,14 @@ def safe_rollback(session): if session.is_active: session.rollback() except Exception as E: - DBLogger.error(f"Failed to rollback session, error: {E}") + dbLogger.error(f"Failed to rollback session, error: {E}") def safe_close(session): try: if session.is_active: session.close() except Exception as E: - DBLogger.error(f"Failed to run safe_close on session, error: {E}") + dbLogger.error(f"Failed to run safe_close on session, error: {E}") def sqlalchemy_type_to_json_schema_type(sqlalchemy_type): """ @@ -388,10 +382,10 @@ def generate_json_schema(model_class, required=None): inspector = Inspector.from_engine(engine) for table_name in Base.metadata.tables.keys(): if table_name not in inspector.get_table_names(): - DBLogger.debug(f"Creating table {table_name}") + dbLogger.debug(f"Creating table {table_name}") Base.metadata.tables[table_name].create(bind=engine) else: - DBLogger.debug(f"Table {table_name} already exists") + dbLogger.debug(f"Table {table_name} already exists") def update_old_record(session, operation_log): oldest_log = session.query(OPERATION_LOG_BASE).order_by(OPERATION_LOG_BASE.timestamp.asc()).first() @@ -404,41 +398,6 @@ def update_old_record(session, operation_log): else: raise ValueError("Unable to find record to update") -def notify_webhook(operation, external_webhook_notification_url, externalNotification, externalNotificationHeaders): - try: - if operation == 'UPDATE': - webhookResponse = requests.patch(external_webhook_notification_url, json=externalNotification, headers=externalNotificationHeaders, timeout=5) - elif operation == 'DELETE': - webhookResponse = requests.delete(external_webhook_notification_url, json=externalNotification, headers=externalNotificationHeaders, timeout=5) - elif operation == 'CREATE': - webhookResponse = requests.put(external_webhook_notification_url, json=externalNotification, headers=externalNotificationHeaders, timeout=5) - except requests.exceptions.Timeout: - DBLogger.error(f"Timeout occurred when sending webhook to {external_webhook_notification_url}") - return False - except requests.exceptions.RequestException as e: - DBLogger.error(f"Request exception when sending webhook to 
{external_webhook_notification_url}") - return False - - if webhookResponse.status_code != 200: - DBLogger.error(f"Response code from external webhook at {external_webhook_notification_url} is != 200.\nResponse Code is: {webhookResponse.status_code}\nResponse Body is: {webhookResponse.content}") - return False - return True - -def handle_external_webhook(objectData, operation): - external_webhook_notification_enabled = yaml_config.get('external', {}).get('external_webhook_notification_enabled', False) - external_webhook_notification_url = yaml_config.get('external', {}).get('external_webhook_notification_url', '') - if not external_webhook_notification_enabled: - return False - if not external_webhook_notification_url: - DBLogger.error("External webhook notification enabled, but external_webhook_notification_url is not defined.") - - externalNotification = Sanitize_Datetime(objectData) - externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} - - # Using separate thread to process webhook - threading.Thread(target=notify_webhook, args=(operation, external_webhook_notification_url, externalNotification, externalNotificationHeaders), daemon=True).start() - return True - def log_change(session, item_id, operation, changes, table_name, operation_id, generated_id=None): # We don't want to log rollback operations if session.info.get("operation") == 'ROLLBACK': @@ -465,7 +424,7 @@ def log_change(session, item_id, operation, changes, table_name, operation_id, g session.add(change) session.flush() except Exception as E: - DBLogger.error("Failed to commit changelog, error: " + str(E)) + dbLogger.error("Failed to commit changelog, error: " + str(E)) safe_rollback(session) safe_close(session) raise ValueError(E) @@ -506,11 +465,11 @@ def log_changes_before_commit(session): changes = [] for attr in class_mapper(obj.__class__).column_attrs: hist = get_history(obj, attr.key) - DBLogger.info(f"History {hist}") + dbLogger.info(f"History {hist}") if hist.has_changes() and hist.added and hist.deleted: old_value, new_value = hist.deleted[0], hist.added[0] - DBLogger.info(f"Old Value {old_value}") - DBLogger.info(f"New Value {new_value}") + dbLogger.info(f"Old Value {old_value}") + dbLogger.info(f"New Value {new_value}") changes.append((attr.key, old_value, new_value)) continue @@ -628,11 +587,11 @@ def rollback_last_change(existingSession=None): # Extract type and value old_type_str, old_value_repr = old_value_str[1:].split("] ", 1) - DBLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") + dbLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") old_value = str_to_type(old_type_str, old_value_repr) old_values_dict[column_name] = old_value - DBLogger.error("old_value_dict: " + str(old_values_dict)) + dbLogger.error("old_value_dict: " + str(old_values_dict)) if not target_item: try: @@ -653,7 +612,7 @@ def rollback_last_change(existingSession=None): session.commit() safe_close(session) except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) + dbLogger.error("rollback_last_change error: " + str(E)) safe_rollback(session) safe_close(session) raise ValueError(E) @@ -661,7 +620,7 @@ def rollback_last_change(existingSession=None): return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) + dbLogger.error("rollback_last_change error: " + str(E)) safe_rollback(session) 
safe_close(session) raise ValueError(E) @@ -732,11 +691,11 @@ def rollback_change_by_operation_id(operation_id, existingSession=None): # Extract type and value old_type_str, old_value_repr = old_value_str[1:].split("] ", 1) - DBLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") + dbLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") old_value = str_to_type(old_type_str, old_value_repr) old_values_dict[column_name] = old_value - DBLogger.error("old_value_dict: " + str(old_values_dict)) + dbLogger.error("old_value_dict: " + str(old_values_dict)) if not target_item: try: @@ -757,7 +716,7 @@ def rollback_change_by_operation_id(operation_id, existingSession=None): session.commit() safe_close(session) except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) + dbLogger.error("rollback_last_change error: " + str(E)) safe_rollback(session) safe_close(session) raise ValueError(E) @@ -765,7 +724,7 @@ def rollback_change_by_operation_id(operation_id, existingSession=None): return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) except Exception as E: - DBLogger.error("rollback_last_change error: " + str(E)) + dbLogger.error("rollback_last_change error: " + str(E)) safe_rollback(session) safe_close(session) raise ValueError(E) @@ -801,8 +760,8 @@ def get_all_operation_logs(page=0, page_size=yaml_config['api'].get('page_size', safe_close(session) return all_operations except Exception as E: - DBLogger.error(f"get_all_operation_logs error: {E}") - DBLogger.error(E) + dbLogger.error(f"get_all_operation_logs error: {E}") + dbLogger.error(E) safe_rollback(session) safe_close(session) raise ValueError(E) @@ -838,8 +797,8 @@ def get_all_operation_logs_by_table(table_name, page=0, page_size=yaml_config['a safe_close(session) return all_operations except Exception as E: - DBLogger.error(f"get_all_operation_logs_by_table error: {E}") - DBLogger.error(E) + dbLogger.error(f"get_all_operation_logs_by_table error: {E}") + dbLogger.error(E) safe_rollback(session) safe_close(session) raise ValueError(E) @@ -869,103 +828,34 @@ def get_last_operation_log(existingSession=None): safe_close(session) return None except Exception as E: - DBLogger.error(f"get_last_operation_log error: {E}") - DBLogger.error(E) + dbLogger.error(f"get_last_operation_log error: {E}") + dbLogger.error(E) safe_rollback(session) safe_close(session) raise ValueError(E) - - -def GeoRed_Push_Request(remote_hss, json_data, transaction_id, url=None): - headers = {"Content-Type": "application/json", "Transaction-Id": str(transaction_id)} - DBLogger.debug("transaction_id: " + str(transaction_id) + " pushing update to " + str(remote_hss).replace('http://', '')) - #@@Fixme - # try: - # session = requests.Session() - # # Create a Retry object with desired parameters - # retries = Retry(total=3, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504]) - - # # Create an HTTPAdapter and pass the Retry object - # adapter = HTTPAdapter(max_retries=retries) - - # session.mount('http://', adapter) - # if url == None: - # endpoint = 'geored' - # r = session.patch(str(remote_hss) + '/geored/', data=json.dumps(json_data), headers=headers) - # else: - # endpoint = url.split('/', 1)[0] - # r = session.patch(url, data=json.dumps(json_data), headers=headers) - # DBLogger.debug("transaction_id: " + str(transaction_id) + " updated on " + str(remote_hss).replace('http://', '') + " with status code " + str(r.status_code)) - # if 
str(r.status_code).startswith('2'): - # prom_http_geored.labels( - # geored_host=str(remote_hss).replace('http://', ''), - # endpoint=endpoint, - # http_response_code=str(r.status_code), - # error="" - # ).inc() - # else: - # prom_http_geored.labels( - # geored_host=str(remote_hss).replace('http://', ''), - # endpoint=endpoint, - # http_response_code=str(r.status_code), - # error=str(r.reason) - # ).inc() - # except ConnectionError as e: - # error_message = str(e) - # if "Name or service not known" in error_message: - # DBLogger.error("transaction_id: " + str(transaction_id) + " name or service not known") - # prom_http_geored.labels( - # geored_host=str(remote_hss).replace('http://', ''), - # endpoint=endpoint, - # http_response_code='000', - # error="No matching DNS entry found" - # ).inc() - # else: - # print("Other ConnectionError:", error_message) - # DBLogger.error("transaction_id: " + str(transaction_id) + " " + str(error_message)) - # prom_http_geored.labels( - # geored_host=str(remote_hss).replace('http://', ''), - # endpoint=endpoint, - # http_response_code='000', - # error="Connection Refused" - # ).inc() - # except Timeout: - # DBLogger.error("transaction_id: " + str(transaction_id) + " timed out connecting to peer " + str(remote_hss).replace('http://', '')) - # prom_http_geored.labels( - # geored_host=str(remote_hss).replace('http://', ''), - # endpoint=endpoint, - # http_response_code='000', - # error="Timeout" - # ).inc() - # except Exception as e: - # DBLogger.error("transaction_id: " + str(transaction_id) + " unexpected error " + str(e) + " when connecting to peer " + str(remote_hss).replace('http://', '')) - # prom_http_geored.labels( - # geored_host=str(remote_hss).replace('http://', ''), - # endpoint=endpoint, - # http_response_code='000', - # error=str(e) - # ).inc() - return - - - -def GeoRed_Push_Async(json_data): +def handleGeored(jsonData): try: - if yaml_config['geored']['enabled'] == True: - if yaml_config['geored']['sync_endpoints'] is not None and len(yaml_config['geored']['sync_endpoints']) > 0: + if yaml_config.get('geored', {}).get('enabled', False): + if yaml_config.get('geored', {}).get('sync_endpoints', []) is not None and len(yaml_config.get('geored', {}).get('sync_endpoints', [])) > 0: transaction_id = str(uuid.uuid4()) - DBLogger.info("Pushing out data to GeoRed peers with transaction_id " + str(transaction_id) + " and JSON body: " + str(json_data)) - for remote_hss in yaml_config['geored']['sync_endpoints']: - GeoRed_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(remote_hss, json_data, transaction_id)) - GeoRed_Push_thread.start() + redisMessaging.sendMessage(queue=f'geored-{time.time_ns()}', message=jsonData, queueExpiry=120) except Exception as E: - DBLogger.debug("Failed to push Async jobs due to error: " + str(E)) + dbLogger.warning("Failed to send Geored message due to error: " + str(E)) -def Webhook_Push_Async(target, json_data): - transaction_id = str(uuid.uuid4()) - Webook_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(target, json_data, transaction_id)) - Webook_Push_thread.start() +def handleWebhook(objectData, operation): + external_webhook_notification_enabled = yaml_config.get('external', {}).get('external_webhook_notification_enabled', False) + external_webhook_notification_url = yaml_config.get('external', {}).get('external_webhook_notification_url', '') + if not external_webhook_notification_enabled: + return False + if not external_webhook_notification_url: + dbLogger.error("External webhook 
notification enabled, but external_webhook_notification_url is not defined.")
+        return False
+
+    externalNotification = Sanitize_Datetime(objectData)
+    externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()}
+    #@@Fixme
+    # Assumed message shape until the webhook consumer service lands: carry the operation and
+    # headers alongside the sanitized payload so the consumer can build the outbound request.
+    webhookMessage = json.dumps({'operation': operation, 'headers': externalNotificationHeaders, 'body': externalNotification})
+    redisMessaging.sendMessage(queue=f'webhook-{time.time_ns()}', message=webhookMessage, queueExpiry=120)
+    return True
 
 def Sanitize_Datetime(result):
     for keys in result:
@@ -973,7 +863,7 @@ def Sanitize_Datetime(result):
             if result[keys] == None:
                 continue
             else:
-                DBLogger.debug("Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String")
+                dbLogger.debug("Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String")
                 result[keys] = str(result[keys])
     return result
@@ -987,7 +877,7 @@ def Sanitize_Keys(result):
     return result
 
 def GetObj(obj_type, obj_id=None, page=None, page_size=None):
-    DBLogger.debug("Called GetObj for type " + str(obj_type))
+    dbLogger.debug("Called GetObj for type " + str(obj_type))
 
     Base.metadata.create_all(engine)
     Session = sessionmaker(bind=engine)
@@ -1024,7 +914,7 @@ def GetObj(obj_type, obj_id=None, page=None, page_size=None):
             raise ValueError("Provide either obj_id or both page and page_size")
 
     except Exception as E:
-        DBLogger.error("Failed to query, error: " + str(E))
+        dbLogger.error("Failed to query, error: " + str(E))
        safe_rollback(session)
        safe_close(session)
        raise ValueError(E)
@@ -1033,7 +923,7 @@ def GetObj(obj_type, obj_id=None, page=None, page_size=None):
     return result
 
 def GetAll(obj_type):
-    DBLogger.debug("Called GetAll for type " + str(obj_type))
+    dbLogger.debug("Called GetAll for type " + str(obj_type))
 
     Base.metadata.create_all(engine)
     Session = sessionmaker(bind = engine)
@@ -1043,7 +933,7 @@ def GetAll(obj_type):
     try:
         result = session.query(obj_type)
     except Exception as E:
-        DBLogger.error("Failed to query, error: " + str(E))
+        dbLogger.error("Failed to query, error: " + str(E))
         safe_rollback(session)
         safe_close(session)
         raise ValueError(E)
@@ -1059,7 +949,7 @@ def GetAll(obj_type):
     return final_result_list
 
 def getAllPaginated(obj_type, page=0, page_size=0, existingSession=None):
-    DBLogger.debug("Called getAllPaginated for type " + str(obj_type))
+    dbLogger.debug("Called getAllPaginated for type " + str(obj_type))
 
     if not existingSession:
         Base.metadata.create_all(engine)
@@ -1091,14 +981,14 @@ def getAllPaginated(obj_type, page=0, page_size=0, existingSession=None):
         return final_result_list
 
     except Exception as E:
-        DBLogger.error("Failed to query, error: " + str(E))
+        dbLogger.error("Failed to query, error: " + str(E))
         safe_rollback(session)
         safe_close(session)
         raise ValueError(E)
 
 
 def GetAllByTable(obj_type, table):
-    DBLogger.debug(f"Called GetAll for type {str(obj_type)} and table {table}")
+    dbLogger.debug(f"Called GetAllByTable for type {str(obj_type)} and table {table}")
 
     Base.metadata.create_all(engine)
     Session = sessionmaker(bind = engine)
@@ -1108,7 +998,7 @@ def GetAllByTable(obj_type, table):
     try:
         result = session.query(obj_type).filter_by(table_name=str(table))
     except Exception as E:
-        DBLogger.error("Failed to query, error: " + str(E))
+        dbLogger.error("Failed to query, error: " + str(E))
         safe_rollback(session)
         safe_close(session)
         raise ValueError(E)
@@ -1124,11 +1014,11 @@ def GetAllByTable(obj_type, table):
     return final_result_list
 
 def UpdateObj(obj_type, json_data, obj_id, disable_logging=False, operation_id=None):
-    DBLogger.debug(f"Called UpdateObj() for type {obj_type} id {obj_id} with JSON data: {json_data} and 
operation_id: {operation_id}") + dbLogger.debug(f"Called UpdateObj() for type {obj_type} id {obj_id} with JSON data: {json_data} and operation_id: {operation_id}") Session = sessionmaker(bind=engine) session = Session() obj_type_str = str(obj_type.__table__.name).upper() - DBLogger.debug(f"obj_type_str is {obj_type_str}") + dbLogger.debug(f"obj_type_str is {obj_type_str}") filter_input = eval(obj_type_str + "." + obj_type_str.lower() + "_id==obj_id") try: obj = session.query(obj_type).filter(filter_input).one() @@ -1137,7 +1027,7 @@ def UpdateObj(obj_type, json_data, obj_id, disable_logging=False, operation_id=N setattr(obj, key, value) setattr(obj, "last_modified", datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z') except Exception as E: - DBLogger.error(f"Failed to query or update object, error: {E}") + dbLogger.error(f"Failed to query or update object, error: {E}") raise ValueError(E) try: session.info["operation_id"] = operation_id # Pass the operation id @@ -1146,13 +1036,13 @@ def UpdateObj(obj_type, json_data, obj_id, disable_logging=False, operation_id=N log_changes_before_commit(session) objectData = GetObj(obj_type, obj_id) session.commit() - handle_external_webhook(objectData, 'UPDATE') + handleWebhook(objectData, 'UPDATE') except Exception as E: - DBLogger.error(f"Failed to commit session, error: {E}") + dbLogger.error(f"Failed to commit session, error: {E}") safe_rollback(session) raise ValueError(E) except Exception as E: - DBLogger.error(f"Exception in UpdateObj, error: {E}") + dbLogger.error(f"Exception in UpdateObj, error: {E}") raise ValueError(E) finally: safe_close(session) @@ -1160,7 +1050,7 @@ def UpdateObj(obj_type, json_data, obj_id, disable_logging=False, operation_id=N return GetObj(obj_type, obj_id) def DeleteObj(obj_type, obj_id, disable_logging=False, operation_id=None): - DBLogger.debug(f"Called DeleteObj for type {obj_type} with id {obj_id}") + dbLogger.debug(f"Called DeleteObj for type {obj_type} with id {obj_id}") Session = sessionmaker(bind=engine) session = Session() @@ -1176,14 +1066,14 @@ def DeleteObj(obj_type, obj_id, disable_logging=False, operation_id=None): if not disable_logging: log_changes_before_commit(session) session.commit() - handle_external_webhook(objectData, 'DELETE') + handleWebhook(objectData, 'DELETE') except Exception as E: - DBLogger.error(f"Failed to commit session, error: {E}") + dbLogger.error(f"Failed to commit session, error: {E}") safe_rollback(session) raise ValueError(E) except Exception as E: - DBLogger.error(f"Exception in DeleteObj, error: {E}") + dbLogger.error(f"Exception in DeleteObj, error: {E}") raise ValueError(E) finally: safe_close(session) @@ -1192,7 +1082,7 @@ def DeleteObj(obj_type, obj_id, disable_logging=False, operation_id=None): def CreateObj(obj_type, json_data, disable_logging=False, operation_id=None): - DBLogger.debug("Called CreateObj to create " + str(obj_type) + " with value: " + str(json_data)) + dbLogger.debug("Called CreateObj to create " + str(obj_type) + " with value: " + str(json_data)) last_modified_value = datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z' json_data["last_modified"] = last_modified_value # set last_modified value in json_data newObj = obj_type(**json_data) @@ -1207,22 +1097,22 @@ def CreateObj(obj_type, json_data, disable_logging=False, operation_id=None): log_changes_before_commit(session) session.commit() except Exception as E: - DBLogger.error(f"Failed to commit session, error: {E}") + 
dbLogger.error(f"Failed to commit session, error: {E}") safe_rollback(session) raise ValueError(E) session.refresh(newObj) result = newObj.__dict__ result.pop('_sa_instance_state') - handle_external_webhook(result, 'CREATE') + handleWebhook(result, 'CREATE') return result except Exception as E: - DBLogger.error(f"Exception in CreateObj, error: {E}") + dbLogger.error(f"Exception in CreateObj, error: {E}") raise ValueError(E) finally: safe_close(session) def Generate_JSON_Model_for_Flask(obj_type): - DBLogger.debug("Generating JSON model for Flask for object type: " + str(obj_type)) + dbLogger.debug("Generating JSON model for Flask for object type: " + str(obj_type)) dictty = dict(generate_json_schema(obj_type)) pprint.pprint(dictty) @@ -1249,14 +1139,14 @@ def Get_AuC(**kwargs): session = Session() if 'iccid' in kwargs: - DBLogger.debug("Get_AuC for iccid " + str(kwargs['iccid'])) + dbLogger.debug("Get_AuC for iccid " + str(kwargs['iccid'])) try: result = session.query(AUC).filter_by(iccid=str(kwargs['iccid'])).one() except Exception as E: safe_close(session) raise ValueError(E) elif 'imsi' in kwargs: - DBLogger.debug("Get_AuC for imsi " + str(kwargs['imsi'])) + dbLogger.debug("Get_AuC for imsi " + str(kwargs['imsi'])) try: result = session.query(AUC).filter_by(imsi=str(kwargs['imsi'])).one() except Exception as E: @@ -1267,7 +1157,7 @@ def Get_AuC(**kwargs): result = Sanitize_Datetime(result) result.pop('_sa_instance_state') - DBLogger.debug("Got back result: " + str(result)) + dbLogger.debug("Got back result: " + str(result)) safe_close(session) return result @@ -1276,27 +1166,27 @@ def Get_IMS_Subscriber(**kwargs): Session = sessionmaker(bind = engine) session = Session() if 'msisdn' in kwargs: - DBLogger.debug("Get_IMS_Subscriber for msisdn " + str(kwargs['msisdn'])) + dbLogger.debug("Get_IMS_Subscriber for msisdn " + str(kwargs['msisdn'])) try: result = session.query(IMS_SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() except Exception as E: safe_close(session) raise ValueError(E) elif 'imsi' in kwargs: - DBLogger.debug("Get_IMS_Subscriber for imsi " + str(kwargs['imsi'])) + dbLogger.debug("Get_IMS_Subscriber for imsi " + str(kwargs['imsi'])) try: result = session.query(IMS_SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() except Exception as E: safe_close(session) raise ValueError(E) - DBLogger.debug("Converting result to dict") + dbLogger.debug("Converting result to dict") result = result.__dict__ try: result.pop('_sa_instance_state') except: pass result = Sanitize_Datetime(result) - DBLogger.debug("Returning IMS Subscriber Data: " + str(result)) + dbLogger.debug("Returning IMS Subscriber Data: " + str(result)) safe_close(session) return result @@ -1307,14 +1197,14 @@ def Get_Subscriber(**kwargs): session = Session() if 'msisdn' in kwargs: - DBLogger.debug("Get_Subscriber for msisdn " + str(kwargs['msisdn'])) + dbLogger.debug("Get_Subscriber for msisdn " + str(kwargs['msisdn'])) try: result = session.query(SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() except Exception as E: safe_close(session) raise ValueError(E) elif 'imsi' in kwargs: - DBLogger.debug("Get_Subscriber for imsi " + str(kwargs['imsi'])) + dbLogger.debug("Get_Subscriber for imsi " + str(kwargs['imsi'])) try: result = session.query(SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() except Exception as E: @@ -1330,7 +1220,7 @@ def Get_Subscriber(**kwargs): attributes = Get_Subscriber_Attributes(result['subscriber_id']) result['attributes'] = attributes - DBLogger.debug("Got back result: " + 
str(result)) + dbLogger.debug("Got back result: " + str(result)) safe_close(session) return result @@ -1338,7 +1228,7 @@ def Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id): Session = sessionmaker(bind = engine) session = Session() - DBLogger.debug("Get_SUBSCRIBER_ROUTING for subscriber_id " + str(subscriber_id) + " and apn_id " + str(apn_id)) + dbLogger.debug("Get_SUBSCRIBER_ROUTING for subscriber_id " + str(subscriber_id) + " and apn_id " + str(apn_id)) try: result = session.query(SUBSCRIBER_ROUTING).filter_by(subscriber_id=subscriber_id, apn_id=apn_id).one() except Exception as E: @@ -1349,7 +1239,7 @@ def Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id): result = Sanitize_Datetime(result) result.pop('_sa_instance_state') - DBLogger.debug("Got back result: " + str(result)) + dbLogger.debug("Got back result: " + str(result)) safe_close(session) return result @@ -1359,7 +1249,7 @@ def Get_Subscriber_Attributes(subscriber_id): Session = sessionmaker(bind = engine) session = Session() - DBLogger.debug("Get_Subscriber_Attributes for subscriber_id " + str(subscriber_id)) + dbLogger.debug("Get_Subscriber_Attributes for subscriber_id " + str(subscriber_id)) try: result = session.query(SUBSCRIBER_ATTRIBUTES).filter_by(subscriber_id=subscriber_id) except Exception as E: @@ -1371,13 +1261,13 @@ def Get_Subscriber_Attributes(subscriber_id): result = Sanitize_Datetime(result) result.pop('_sa_instance_state') final_res.append(result) - DBLogger.debug("Got back result: " + str(final_res)) + dbLogger.debug("Got back result: " + str(final_res)) safe_close(session) return final_res def Get_Served_Subscribers(get_local_users_only=False): - DBLogger.debug("Getting all subscribers served by this HSS") + dbLogger.debug("Getting all subscribers served by this HSS") Session = sessionmaker(bind = engine) session = Session() @@ -1387,41 +1277,41 @@ def Get_Served_Subscribers(get_local_users_only=False): results = session.query(SUBSCRIBER).filter(SUBSCRIBER.serving_mme.isnot(None)) for result in results: result = result.__dict__ - DBLogger.debug("Result: " + str(result) + " type: " + str(type(result))) + dbLogger.debug("Result: " + str(result) + " type: " + str(type(result))) result = Sanitize_Datetime(result) result.pop('_sa_instance_state') if get_local_users_only == True: - DBLogger.debug("Filtering to locally served IMS Subs only") + dbLogger.debug("Filtering to locally served IMS Subs only") try: serving_hss = result['serving_mme_peer'].split(';')[1] - DBLogger.debug("Serving HSS: " + str(serving_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) + dbLogger.debug("Serving HSS: " + str(serving_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) if serving_hss == yaml_config['hss']['OriginHost']: - DBLogger.debug("Serving HSS matches local HSS") + dbLogger.debug("Serving HSS matches local HSS") Served_Subs[result['imsi']] = {} Served_Subs[result['imsi']] = result - #DBLogger.debug("Processed result") + #dbLogger.debug("Processed result") continue else: - DBLogger.debug("Sub is served by remote HSS: " + str(serving_hss)) + dbLogger.debug("Sub is served by remote HSS: " + str(serving_hss)) except Exception as E: - DBLogger.debug("Error in filtering Get_Served_Subscribers to local peer only: " + str(E)) + dbLogger.debug("Error in filtering Get_Served_Subscribers to local peer only: " + str(E)) continue else: Served_Subs[result['imsi']] = result - DBLogger.debug("Processed result") + dbLogger.debug("Processed result") except Exception as E: safe_close(session) raise ValueError(E) - 
DBLogger.debug("Final Served_Subs: " + str(Served_Subs)) + dbLogger.debug("Final Served_Subs: " + str(Served_Subs)) safe_close(session) return Served_Subs def Get_Served_IMS_Subscribers(get_local_users_only=False): - DBLogger.debug("Getting all subscribers served by this IMS-HSS") + dbLogger.debug("Getting all subscribers served by this IMS-HSS") Session = sessionmaker(bind=engine) session = Session() @@ -1432,40 +1322,40 @@ def Get_Served_IMS_Subscribers(get_local_users_only=False): IMS_SUBSCRIBER.scscf.isnot(None)) for result in results: result = result.__dict__ - DBLogger.debug("Result: " + str(result) + + dbLogger.debug("Result: " + str(result) + " type: " + str(type(result))) result = Sanitize_Datetime(result) result.pop('_sa_instance_state') if get_local_users_only == True: - DBLogger.debug("Filtering Get_Served_IMS_Subscribers to locally served IMS Subs only") + dbLogger.debug("Filtering Get_Served_IMS_Subscribers to locally served IMS Subs only") try: serving_ims_hss = result['scscf_peer'].split(';')[1] - DBLogger.debug("Serving IMS-HSS: " + str(serving_ims_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) + dbLogger.debug("Serving IMS-HSS: " + str(serving_ims_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) if serving_ims_hss == yaml_config['hss']['OriginHost']: - DBLogger.debug("Serving IMS-HSS matches local HSS for " + str(result['imsi'])) + dbLogger.debug("Serving IMS-HSS matches local HSS for " + str(result['imsi'])) Served_Subs[result['imsi']] = {} Served_Subs[result['imsi']] = result - DBLogger.debug("Processed result") + dbLogger.debug("Processed result") continue else: - DBLogger.debug("Sub is served by remote IMS-HSS: " + str(serving_ims_hss)) + dbLogger.debug("Sub is served by remote IMS-HSS: " + str(serving_ims_hss)) except Exception as E: - DBLogger.debug("Error in filtering to local peer only: " + str(E)) + dbLogger.debug("Error in filtering to local peer only: " + str(E)) continue else: Served_Subs[result['imsi']] = result - DBLogger.debug("Processed result") + dbLogger.debug("Processed result") except Exception as E: safe_close(session) raise ValueError(E) - DBLogger.debug("Final Served_Subs: " + str(Served_Subs)) + dbLogger.debug("Final Served_Subs: " + str(Served_Subs)) safe_close(session) return Served_Subs def Get_Served_PCRF_Subscribers(get_local_users_only=False): - DBLogger.debug("Getting all subscribers served by this PCRF") + dbLogger.debug("Getting all subscribers served by this PCRF") Session = sessionmaker(bind=engine) session = Session() Served_Subs = {} @@ -1473,47 +1363,47 @@ def Get_Served_PCRF_Subscribers(get_local_users_only=False): results = session.query(SERVING_APN).all() for result in results: result = result.__dict__ - DBLogger.debug("Result: " + str(result) + " type: " + str(type(result))) + dbLogger.debug("Result: " + str(result) + " type: " + str(type(result))) result = Sanitize_Datetime(result) result.pop('_sa_instance_state') if get_local_users_only == True: - DBLogger.debug("Filtering to locally served IMS Subs only") + dbLogger.debug("Filtering to locally served IMS Subs only") try: serving_pcrf = result['serving_pgw_peer'].split(';')[1] - DBLogger.debug("Serving PCRF: " + str(serving_pcrf) + " and this is: " + str(yaml_config['hss']['OriginHost'])) + dbLogger.debug("Serving PCRF: " + str(serving_pcrf) + " and this is: " + str(yaml_config['hss']['OriginHost'])) if serving_pcrf == yaml_config['hss']['OriginHost']: - DBLogger.debug("Serving PCRF matches local PCRF") - DBLogger.debug("Processed result") + 
dbLogger.debug("Serving PCRF matches local PCRF") + dbLogger.debug("Processed result") else: - DBLogger.debug("Sub is served by remote PCRF: " + str(serving_pcrf)) + dbLogger.debug("Sub is served by remote PCRF: " + str(serving_pcrf)) continue except Exception as E: - DBLogger.debug("Error in filtering Get_Served_PCRF_Subscribers to local peer only: " + str(E)) + dbLogger.debug("Error in filtering Get_Served_PCRF_Subscribers to local peer only: " + str(E)) continue # Get APN Info apn_info = GetObj(APN, result['apn']) - #DBLogger.debug("Got APN Info: " + str(apn_info)) + #dbLogger.debug("Got APN Info: " + str(apn_info)) result['apn_info'] = apn_info # Get Subscriber Info subscriber_info = GetObj(SUBSCRIBER, result['subscriber_id']) result['subscriber_info'] = subscriber_info - #DBLogger.debug("Got Subscriber Info: " + str(subscriber_info)) + #dbLogger.debug("Got Subscriber Info: " + str(subscriber_info)) Served_Subs[subscriber_info['imsi']] = result - DBLogger.debug("Processed result") + dbLogger.debug("Processed result") except Exception as E: raise ValueError(E) - #DBLogger.debug("Final SERVING_APN: " + str(Served_Subs)) + #dbLogger.debug("Final SERVING_APN: " + str(Served_Subs)) safe_close(session) return Served_Subs def Get_Vectors_AuC(auc_id, action, **kwargs): - DBLogger.debug("Getting Vectors for auc_id " + str(auc_id) + " with action " + str(action)) + dbLogger.debug("Getting Vectors for auc_id " + str(auc_id) + " with action " + str(action)) key_data = GetObj(AUC, auc_id) vector_dict = {} @@ -1530,17 +1420,17 @@ def Get_Vectors_AuC(auc_id, action, **kwargs): return vector_dict elif action == "sqn_resync": - DBLogger.debug("Resync SQN") + dbLogger.debug("Resync SQN") rand = kwargs['rand'] sqn, mac_s = S6a_crypt.generate_resync_s6a(key_data['ki'], key_data['opc'], key_data['amf'], kwargs['auts'], rand) - DBLogger.debug("SQN from resync: " + str(sqn) + " SQN in DB is " + str(key_data['sqn']) + "(Difference of " + str(int(sqn) - int(key_data['sqn'])) + ")") + dbLogger.debug("SQN from resync: " + str(sqn) + " SQN in DB is " + str(key_data['sqn']) + "(Difference of " + str(int(sqn) - int(key_data['sqn'])) + ")") Update_AuC(auc_id, sqn=sqn+100) return elif action == "sip_auth": rand, autn, xres, ck, ik = S6a_crypt.generate_maa_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn']) - DBLogger.debug("RAND is: " + str(rand)) - DBLogger.debug("AUTN is: " + str(autn)) + dbLogger.debug("RAND is: " + str(rand)) + dbLogger.debug("AUTN is: " + str(autn)) vector_dict['SIP_Authenticate'] = rand + autn vector_dict['xres'] = xres vector_dict['ck'] = ck @@ -1549,8 +1439,8 @@ def Get_Vectors_AuC(auc_id, action, **kwargs): return vector_dict elif action == "Digest-MD5": - DBLogger.debug("Generating Digest-MD5 Auth vectors") - DBLogger.debug("key_data: " + str(key_data)) + dbLogger.debug("Generating Digest-MD5 Auth vectors") + dbLogger.debug("key_data: " + str(key_data)) nonce = uuid.uuid4().hex #nonce = "beef4d878f2642ed98afe491b943ca60" vector_dict['nonce'] = nonce @@ -1558,7 +1448,7 @@ def Get_Vectors_AuC(auc_id, action, **kwargs): return vector_dict def Get_APN(apn_id): - DBLogger.debug("Getting APN " + str(apn_id)) + dbLogger.debug("Getting APN " + str(apn_id)) Session = sessionmaker(bind = engine) session = Session() @@ -1573,7 +1463,7 @@ def Get_APN(apn_id): return result def Get_APN_by_Name(apn): - DBLogger.debug("Getting APN named " + str(apn_id)) + dbLogger.debug("Getting APN named " + str(apn_id)) Session = sessionmaker(bind = engine) session = Session() try: 
@@ -1587,36 +1477,36 @@ def Get_APN_by_Name(apn): return result def Update_AuC(auc_id, sqn=1): - DBLogger.debug("Updating AuC record for sub " + str(auc_id)) - DBLogger.debug(UpdateObj(AUC, {'sqn': sqn}, auc_id, True)) + dbLogger.debug("Updating AuC record for sub " + str(auc_id)) + dbLogger.debug(UpdateObj(AUC, {'sqn': sqn}, auc_id, True)) return def Update_Serving_MME(imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, propagate=True): - DBLogger.debug("Updating Serving MME for sub " + str(imsi) + " to MME " + str(serving_mme)) + dbLogger.debug("Updating Serving MME for sub " + str(imsi) + " to MME " + str(serving_mme)) Session = sessionmaker(bind = engine) session = Session() try: result = session.query(SUBSCRIBER).filter_by(imsi=imsi).one() if yaml_config['hss']['CancelLocationRequest_Enabled'] == True: - DBLogger.debug("Evaluating if we should trigger sending a CLR.") + dbLogger.debug("Evaluating if we should trigger sending a CLR.") serving_hss = str(result.serving_mme_peer).split(';',1)[1] serving_mme_peer = str(result.serving_mme_peer).split(';',1)[0] - DBLogger.debug("Subscriber is currently served by serving_mme: " + str(result.serving_mme) + " at realm " + str(result.serving_mme_realm) + " through Diameter peer " + str(result.serving_mme_peer)) - DBLogger.debug("Subscriber is now served by serving_mme: " + str(serving_mme) + " at realm " + str(serving_mme_realm) + " through Diameter peer " + str(serving_mme_peer)) + dbLogger.debug("Subscriber is currently served by serving_mme: " + str(result.serving_mme) + " at realm " + str(result.serving_mme_realm) + " through Diameter peer " + str(result.serving_mme_peer)) + dbLogger.debug("Subscriber is now served by serving_mme: " + str(serving_mme) + " at realm " + str(serving_mme_realm) + " through Diameter peer " + str(serving_mme_peer)) #Evaluate if we need to send a CLR to the old MME if result.serving_mme != None: if str(result.serving_mme) == str(serving_mme): - DBLogger.debug("This MME is unchanged (" + str(serving_mme) + ") - so no need to send a CLR") + dbLogger.debug("This MME is unchanged (" + str(serving_mme) + ") - so no need to send a CLR") elif (str(result.serving_mme) != str(serving_mme)): - DBLogger.debug("There is a difference in serving MME, old MME is '" + str(result.serving_mme) + "' new MME is '" + str(serving_mme) + "' - We need to trigger sending a CLR") + dbLogger.debug("There is a difference in serving MME, old MME is '" + str(result.serving_mme) + "' new MME is '" + str(serving_mme) + "' - We need to trigger sending a CLR") if serving_hss != yaml_config['hss']['OriginHost']: - DBLogger.debug("This subscriber is not served by this HSS it is served by HSS at " + serving_hss + " - We need to trigger sending a CLR on " + str(serving_hss)) + dbLogger.debug("This subscriber is not served by this HSS it is served by HSS at " + serving_hss + " - We need to trigger sending a CLR on " + str(serving_hss)) URL = 'http://' + serving_hss + '.' + yaml_config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) else: - DBLogger.debug("This subscriber is served by this HSS we need to send a CLR to old MME from this HSS") + dbLogger.debug("This subscriber is served by this HSS we need to send a CLR to old MME from this HSS") URL = 'http://' + serving_hss + '.' 
+ yaml_config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) - DBLogger.debug("Sending CLR to API at " + str(URL)) + dbLogger.debug("Sending CLR to API at " + str(URL)) json_data = { "DestinationRealm": result.serving_mme_realm, "DestinationHost": result.serving_mme, @@ -1624,23 +1514,23 @@ def Update_Serving_MME(imsi, serving_mme, serving_mme_realm=None, serving_mme_pe "diameterPeer": serving_mme_peer, } - DBLogger.debug("Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data)) + dbLogger.debug("Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data)) transaction_id = str(uuid.uuid4()) GeoRed_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(serving_hss, json_data, transaction_id, URL)) GeoRed_Push_thread.start() else: #No currently serving MME - No action to take - DBLogger.debug("No currently serving MME - No need to send CLR") + dbLogger.debug("No currently serving MME - No need to send CLR") if type(serving_mme) == str: - DBLogger.debug("Updating serving MME & Timestamp") + dbLogger.debug("Updating serving MME & Timestamp") result.serving_mme = serving_mme result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc) result.serving_mme_realm = serving_mme_realm result.serving_mme_peer = serving_mme_peer else: #Clear values - DBLogger.debug("Clearing serving MME") + dbLogger.debug("Clearing serving MME") result.serving_mme = None result.serving_mme_timestamp = None result.serving_mme_realm = None @@ -1648,29 +1538,29 @@ def Update_Serving_MME(imsi, serving_mme, serving_mme_realm=None, serving_mme_pe session.commit() objectData = GetObj(SUBSCRIBER, result.subscriber_id) - handle_external_webhook(objectData, 'UPDATE') + handleWebhook(objectData, 'UPDATE') #Sync state change with geored if propagate == True: if 'HSS' in yaml_config['geored'].get('sync_actions', []) and yaml_config['geored'].get('enabled', False) == True: - DBLogger.debug("Propagate MME changes to Geographic PyHSS instances") - GeoRed_Push_Async({ + dbLogger.debug("Propagate MME changes to Geographic PyHSS instances") + handleGeored({ "imsi": str(imsi), "serving_mme": result.serving_mme, "serving_mme_realm": str(result.serving_mme_realm), "serving_mme_peer": str(result.serving_mme_peer) }) else: - DBLogger.debug("Config does not allow sync of HSS events") + dbLogger.debug("Config does not allow sync of HSS events") except Exception as E: - DBLogger.error("Error occurred, rolling back session: " + str(E)) + dbLogger.error("Error occurred, rolling back session: " + str(E)) raise finally: safe_close(session) def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, propagate=True): - DBLogger.debug("Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer)) + dbLogger.debug("Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer)) Session = sessionmaker(bind = engine) session = Session() @@ -1679,7 +1569,7 @@ def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, p try: assert(type(serving_cscf) == str) assert(len(serving_cscf) > 0) - DBLogger.debug("Setting serving CSCF") + dbLogger.debug("Setting serving CSCF") #Strip duplicate SIP prefix before storing serving_cscf = serving_cscf.replace("sip:sip:", "sip:") result.scscf = serving_cscf @@ -1688,7 +1578,7 @@ def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, p 
result.scscf_peer = str(scscf_peer) except: #Clear values - DBLogger.debug("Clearing serving CSCF") + dbLogger.debug("Clearing serving CSCF") result.scscf = None result.scscf_timestamp = None result.scscf_realm = None @@ -1696,17 +1586,17 @@ def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, p session.commit() objectData = GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) - handle_external_webhook(objectData, 'UPDATE') + handleWebhook(objectData, 'UPDATE') #Sync state change with geored if propagate == True: if 'IMS' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - DBLogger.debug("Propagate IMS changes to Geographic PyHSS instances") - GeoRed_Push_Async({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_peer": str(result.scscf_peer)}) + dbLogger.debug("Propagate IMS changes to Geographic PyHSS instances") + handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_peer": str(result.scscf_peer)}) else: - DBLogger.debug("Config does not allow sync of IMS events") + dbLogger.debug("Config does not allow sync of IMS events") except Exception as E: - DBLogger.error("An error occurred, rolling back session: " + str(E)) + dbLogger.error("An error occurred, rolling back session: " + str(E)) safe_rollback(session) raise finally: @@ -1714,10 +1604,10 @@ def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, p def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, propagate=True): - DBLogger.debug("Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn)) - DBLogger.debug("PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing)) - DBLogger.debug("Serving PGW Realm is: " + str(serving_pgw_realm) + " and peer is: " + str(serving_pgw_peer)) - DBLogger.debug("subscriber_routing: " + str(subscriber_routing)) + dbLogger.debug("Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn)) + dbLogger.debug("PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing)) + dbLogger.debug("Serving PGW Realm is: " + str(serving_pgw_realm) + " and peer is: " + str(serving_pgw_peer)) + dbLogger.debug("subscriber_routing: " + str(subscriber_routing)) #Get Subscriber ID from IMSI subscriber_details = Get_Subscriber(imsi=str(imsi)) @@ -1725,12 +1615,12 @@ def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routi #Split the APN list into a list apn_list = subscriber_details['apn_list'].split(',') - DBLogger.debug("Current APN List: " + str(apn_list)) + dbLogger.debug("Current APN List: " + str(apn_list)) #Remove the default APN from the list try: apn_list.remove(str(subscriber_details['default_apn'])) except: - DBLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") + dbLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") pass #Add default APN in first position apn_list.insert(0, str(subscriber_details['default_apn'])) @@ -1739,11 +1629,11 @@ def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routi for apn_id in apn_list: #Get each APN in List apn_data = Get_APN(apn_id) - DBLogger.debug(apn_data) + dbLogger.debug(apn_data) if 
str(apn_data['apn']).lower() == str(apn).lower():
-            DBLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id))
+            dbLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id))
             break
-    DBLogger.debug("APN ID is " + str(apn_id))
+    dbLogger.debug("APN ID is " + str(apn_id))
 
     json_data = {
         'apn' : apn_id,
@@ -1758,9 +1648,9 @@ def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routi
 
     try:
     #Check if already a serving APN on record
-        DBLogger.debug("Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id))
+        dbLogger.debug("Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id))
         ServingAPN = Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id)
-        DBLogger.debug("Existing Serving APN ID on record, updating")
+        dbLogger.debug("Existing Serving APN ID on record, updating")
         try:
             assert(type(serving_pgw) == str)
             assert(len(serving_pgw) > 0)
@@ -1768,25 +1658,25 @@ def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routi
 
             UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True)
             objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
-            handle_external_webhook(objectData, 'UPDATE')
+            handleWebhook(objectData, 'UPDATE')
         except:
-            DBLogger.debug("Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']))
+            dbLogger.debug("Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']))
             objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
-            handle_external_webhook(objectData, 'DELETE')
+            handleWebhook(objectData, 'DELETE')
             DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True)
     except Exception as E:
-        DBLogger.info("Failed to update existing APN " + str(E))
+        dbLogger.info("Failed to update existing APN " + str(E))
         #Create if does not exist
-        CreateObj(SERVING_APN, json_data, True)
+        # CreateObj returns the created row as a dict; ServingAPN is not set in this branch.
+        objectData = CreateObj(SERVING_APN, json_data, True)
+        handleWebhook(objectData, 'CREATE')
 
     #Sync state change with geored
     if propagate == True:
         try:
             if 'PCRF' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True:
-                DBLogger.debug("Propagate PCRF changes to Geographic PyHSS instances")
-                GeoRed_Push_Async({"imsi": str(imsi),
+                dbLogger.debug("Propagate PCRF changes to Geographic PyHSS instances")
+                handleGeored({"imsi": str(imsi),
                                 'serving_apn' : str(apn),
                                 'pcrf_session_id': str(pcrf_session_id),
                                 'serving_pgw': str(serving_pgw),
@@ -1795,22 +1685,22 @@ def Update_Serving_APN(imsi, apn, pcrf_session_id, serving_pgw, subscriber_routi
                                 'subscriber_routing': str(subscriber_routing)
                                 })
             else:
-                DBLogger.debug("Config does not allow sync of PCRF events")
+                dbLogger.debug("Config does not allow sync of PCRF events")
         except Exception as E:
-            DBLogger.debug("Nothing synced to Geographic PyHSS instances for event PCRF")
+            dbLogger.debug("Nothing synced to Geographic PyHSS instances for event PCRF")
 
     return
 
 def Get_Serving_APN(subscriber_id, apn_id):
-    DBLogger.debug("Getting Serving APN " + str(apn_id) + " with subscriber_id " + str(subscriber_id))
+    dbLogger.debug("Getting Serving APN " + str(apn_id) + " with subscriber_id " + str(subscriber_id))
     Session = sessionmaker(bind = engine)
     session = Session()
 
     try:
         result = session.query(SERVING_APN).filter_by(subscriber_id=subscriber_id, apn=apn_id).first()
     except Exception as E:
-
DBLogger.debug(E) + dbLogger.debug(E) safe_close(session) raise ValueError(E) result = result.__dict__ @@ -1820,7 +1710,7 @@ def Get_Serving_APN(subscriber_id, apn_id): return result def Get_Charging_Rule(charging_rule_id): - DBLogger.debug("Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id)) + dbLogger.debug("Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id)) Session = sessionmaker(bind = engine) session = Session() #Get base Rule @@ -1840,57 +1730,57 @@ def Get_Charging_Rule(charging_rule_id): return ChargingRule def Get_Charging_Rules(imsi, apn): - DBLogger.debug("Called Get_Charging_Rules() for IMSI " + str(imsi) + " and APN " + str(apn)) + dbLogger.debug("Called Get_Charging_Rules() for IMSI " + str(imsi) + " and APN " + str(apn)) #Get Subscriber ID from IMSI subscriber_details = Get_Subscriber(imsi=str(imsi)) #Split the APN list into a list apn_list = subscriber_details['apn_list'].split(',') - DBLogger.debug("Current APN List: " + str(apn_list)) + dbLogger.debug("Current APN List: " + str(apn_list)) #Remove the default APN from the list try: apn_list.remove(str(subscriber_details['default_apn'])) except: - DBLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") + dbLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") pass #Add default APN in first position apn_list.insert(0, str(subscriber_details['default_apn'])) #Get APN ID from APN for apn_id in apn_list: - DBLogger.debug("Getting APN ID " + str(apn_id) + " to see if it matches APN " + str(apn)) + dbLogger.debug("Getting APN ID " + str(apn_id) + " to see if it matches APN " + str(apn)) #Get each APN in List apn_data = Get_APN(apn_id) - DBLogger.debug(apn_data) + dbLogger.debug(apn_data) if str(apn_data['apn']).lower() == str(apn).lower(): - DBLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id)) + dbLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id)) - DBLogger.debug("Getting charging rule list from " + str(apn_data['charging_rule_list'])) + dbLogger.debug("Getting charging rule list from " + str(apn_data['charging_rule_list'])) ChargingRule = {} ChargingRule['charging_rule_list'] = str(apn_data['charging_rule_list']).split(',') ChargingRule['apn_data'] = apn_data #Get Charging Rules list if apn_data['charging_rule_list'] == None: - DBLogger.debug("No Charging Rule associated with this APN") + dbLogger.debug("No Charging Rule associated with this APN") ChargingRule['charging_rules'] = None return ChargingRule - DBLogger.debug("ChargingRule['charging_rule_list'] is: " + str(ChargingRule['charging_rule_list'])) + dbLogger.debug("ChargingRule['charging_rule_list'] is: " + str(ChargingRule['charging_rule_list'])) #Empty dict for the Charging Rules to go into ChargingRule['charging_rules'] = [] #Add each of the Charging Rules for the APN for individual_charging_rule in ChargingRule['charging_rule_list']: - DBLogger.debug("Getting Charging rule " + str(individual_charging_rule)) + dbLogger.debug("Getting Charging rule " + str(individual_charging_rule)) individual_charging_rule_complete = Get_Charging_Rule(individual_charging_rule) - DBLogger.debug("Got individual_charging_rule_complete: " + str(individual_charging_rule_complete)) + dbLogger.debug("Got individual_charging_rule_complete: " + str(individual_charging_rule_complete)) ChargingRule['charging_rules'].append(individual_charging_rule_complete) - 
DBLogger.debug("Completed Get_Charging_Rules()") - DBLogger.debug(ChargingRule) + dbLogger.debug("Completed Get_Charging_Rules()") + dbLogger.debug(ChargingRule) return ChargingRule def Get_UE_by_IP(subscriber_routing): - DBLogger.debug("Called Get_UE_by_IP() for IP " + str(subscriber_routing)) + dbLogger.debug("Called Get_UE_by_IP() for IP " + str(subscriber_routing)) Session = sessionmaker(bind = engine) session = Session() @@ -1911,9 +1801,9 @@ def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): #IMSI 14-15 Digits #IMEI 15 Digits #IMEI-SV 2 Digits - DBLogger.debug("Called Store_IMSI_IMEI_Binding() with IMSI: " + str(imsi) + " IMEI: " + str(imei) + " match_response_code: " + str(match_response_code)) + dbLogger.debug("Called Store_IMSI_IMEI_Binding() with IMSI: " + str(imsi) + " IMEI: " + str(imei) + " match_response_code: " + str(match_response_code)) if yaml_config['eir']['imsi_imei_logging'] != True: - DBLogger.debug("Skipping storing binding") + dbLogger.debug("Skipping storing binding") return #Concat IMEI + IMSI imsi_imei = str(imsi) + "," + str(imei) @@ -1923,7 +1813,7 @@ def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): #Check if exist already & update try: session.query(IMSI_IMEI_HISTORY).filter_by(imsi_imei=imsi_imei).one() - DBLogger.debug("Entry already present for IMSI/IMEI Combo") + dbLogger.debug("Entry already present for IMSI/IMEI Combo") safe_close(session) return except Exception as E: @@ -1932,27 +1822,27 @@ def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): try: session.commit() except Exception as E: - DBLogger.error("Failed to commit session, error: " + str(E)) + dbLogger.error("Failed to commit session, error: " + str(E)) safe_rollback(session) safe_close(session) raise ValueError(E) safe_close(session) - DBLogger.debug("Added new IMSI_IMEI_HISTORY binding") + dbLogger.debug("Added new IMSI_IMEI_HISTORY binding") if 'sim_swap_notify_webhook' in yaml_config['eir']: - DBLogger.debug("Sending SIM Swap notification to Webhook") + dbLogger.debug("Sending SIM Swap notification to Webhook") try: dictToSend = {'imei':imei, 'imsi': imsi, 'match_response_code': match_response_code} - Webhook_Push_Async(str(yaml_config['eir']['sim_swap_notify_webhook']), json_data=dictToSend) + handleWebhook(dictToSend) except Exception as E: - DBLogger.debug("Failed to post to Webhook") - DBLogger.debug(str(E)) + dbLogger.debug("Failed to post to Webhook") + dbLogger.debug(str(E)) #Lookup Device Info if 'tac_database_csv' in yaml_config['eir']: try: device_info = get_device_info_from_TAC(imei=str(imei)) - DBLogger.debug("Got Device Info: " + str(device_info)) + dbLogger.debug("Got Device Info: " + str(device_info)) #@@Fixme # prom_eir_devices.labels( # imei_prefix=device_info['tac_prefix'], @@ -1960,35 +1850,35 @@ def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): # device_name=device_info['model'] # ).inc() except Exception as E: - DBLogger.debug("Failed to get device info from TAC") + dbLogger.debug("Failed to get device info from TAC") # prom_eir_devices.labels( # imei_prefix=str(imei)[0:8], # device_type='Unknown', # device_name='Unknown' # ).inc() else: - DBLogger.debug("No TAC database configured, skipping device info lookup") + dbLogger.debug("No TAC database configured, skipping device info lookup") #Sync state change with geored if propagate == True: try: if 'EIR' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - 
DBLogger.debug("Propagate EIR changes to Geographic PyHSS instances") - GeoRed_Push_Async( + dbLogger.debug("Propagate EIR changes to Geographic PyHSS instances") + handleGeored( {"imsi": str(imsi), "imei": str(imei), "match_response_code": str(match_response_code)} ) else: - DBLogger.debug("Config does not allow sync of EIR events") + dbLogger.debug("Config does not allow sync of EIR events") except Exception as E: - DBLogger.debug("Nothing synced to Geographic PyHSS instances for EIR event") - DBLogger.debug(E) + dbLogger.debug("Nothing synced to Geographic PyHSS instances for EIR event") + dbLogger.debug(E) return def Get_IMEI_IMSI_History(attribute): - DBLogger.debug("Called Get_IMEI_IMSI_History() for entry matching " + str(Get_IMEI_IMSI_History)) + dbLogger.debug("Called Get_IMEI_IMSI_History() for entry matching " + str(Get_IMEI_IMSI_History)) Session = sessionmaker(bind = engine) session = Session() result_array = [] @@ -2015,11 +1905,11 @@ def Get_IMEI_IMSI_History(attribute): def Check_EIR(imsi, imei): eir_response_code_table = {0 : 'Whitelist', 1: 'Blacklist', 2: 'Greylist'} - DBLogger.debug("Called Check_EIR() for imsi " + str(imsi) + " and imei: " + str(imei)) + dbLogger.debug("Called Check_EIR() for imsi " + str(imsi) + " and imei: " + str(imei)) Session = sessionmaker(bind = engine) session = Session() #Check for Exact Matches - DBLogger.debug("Looking for exact matches") + dbLogger.debug("Looking for exact matches") #Check for exact Matches try: results = session.query(EIR).filter_by(imei=str(imei), regex_mode=0) @@ -2027,11 +1917,11 @@ def Check_EIR(imsi, imei): result = result.__dict__ match_response_code = result['match_response_code'] if result['imsi'] == '': - DBLogger.debug("No IMSI specified in DB, so matching only on IMEI") + dbLogger.debug("No IMSI specified in DB, so matching only on IMEI") Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) return match_response_code elif result['imsi'] == str(imsi): - DBLogger.debug("Matched on IMEI and IMSI") + dbLogger.debug("Matched on IMEI and IMSI") Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) return match_response_code except Exception as E: @@ -2039,23 +1929,23 @@ def Check_EIR(imsi, imei): safe_close(session) raise ValueError(E) - DBLogger.debug("Did not match any Exact Matches - Checking Regex") + dbLogger.debug("Did not match any Exact Matches - Checking Regex") try: results = session.query(EIR).filter_by(regex_mode=1) #Get all Regex records from DB for result in results: result = result.__dict__ match_response_code = result['match_response_code'] if re.match(result['imei'], imei): - DBLogger.debug("IMEI matched " + str(result['imei'])) + dbLogger.debug("IMEI matched " + str(result['imei'])) #Check if IMSI also specified if len(result['imsi']) != 0: - DBLogger.debug("With IMEI matched, now checking if IMSI matches regex") + dbLogger.debug("With IMEI matched, now checking if IMSI matches regex") if re.match(result['imsi'], imsi): - DBLogger.debug("IMSI also matched, so match OK!") + dbLogger.debug("IMSI also matched, so match OK!") Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) return match_response_code else: - DBLogger.debug("No IMSI specified, so match OK!") + dbLogger.debug("No IMSI specified, so match OK!") Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) return match_response_code except Exception as E: @@ -2066,17 +1956,17 @@ def Check_EIR(imsi, imei): try: 
session.commit() except Exception as E: - DBLogger.error("Failed to commit session, error: " + str(E)) + dbLogger.error("Failed to commit session, error: " + str(E)) safe_rollback(session) safe_close(session) raise ValueError(E) - DBLogger.debug("No matches at all - Returning default response") + dbLogger.debug("No matches at all - Returning default response") Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=yaml_config['eir']['no_match_response']) safe_close(session) return yaml_config['eir']['no_match_response'] def Get_EIR_Rules(): - DBLogger.debug("Getting all EIR Rules") + dbLogger.debug("Getting all EIR Rules") Session = sessionmaker(bind = engine) session = Session() EIR_Rules = [] @@ -2090,7 +1980,7 @@ def Get_EIR_Rules(): safe_rollback(session) safe_close(session) raise ValueError(E) - DBLogger.debug("Final EIR_Rules: " + str(EIR_Rules)) + dbLogger.debug("Final EIR_Rules: " + str(EIR_Rules)) safe_close(session) return EIR_Rules @@ -2103,33 +1993,33 @@ def dict_bytes_to_dict_string(dict_bytes): def get_device_info_from_TAC(imei): - DBLogger.debug("Getting Device Info from IMEI: " + str(imei)) + dbLogger.debug("Getting Device Info from IMEI: " + str(imei)) #Try 8 digit TAC try: - DBLogger.debug("Trying to match on 8 Digit IMEI") + dbLogger.debug("Trying to match on 8 Digit IMEI") #@@Fixme # imei_result = logtool.RedisHMGET(str(imei[0:8])) # print("Got back: " + str(imei_result)) # imei_result = dict_bytes_to_dict_string(imei_result) # assert(len(imei_result) != 0) - # DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) + # dbLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) # return imei_result return "0" except: - DBLogger.debug("Failed to match on 8 digit IMEI") + dbLogger.debug("Failed to match on 8 digit IMEI") try: - DBLogger.debug("Trying to match on 6 Digit IMEI") + dbLogger.debug("Trying to match on 6 Digit IMEI") #@@Fixme # imei_result = logtool.RedisHMGET(str(imei[0:6])) # print("Got back: " + str(imei_result)) # imei_result = dict_bytes_to_dict_string(imei_result) # assert(len(imei_result) != 0) - # DBLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) + # dbLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) # return imei_result return "0" except: - DBLogger.debug("Failed to match on 6 digit IMEI") + dbLogger.debug("Failed to match on 6 digit IMEI") raise ValueError("No matching TAC in IMEI Database") @@ -2343,7 +2233,7 @@ def get_device_info_from_TAC(imei): GetAPN_Result = Get_APN(GetSubscriber_Result['default_apn']) print(GetAPN_Result) - #GeoRed_Push_Async({"imsi": "001001000000006", "serving_mme": "abc123"}) + #handleGeored({"imsi": "001001000000006", "serving_mme": "abc123"}) if DeleteAfter == True: diff --git a/lib/logtool.py b/lib/logtool.py index b6f0773..20cb414 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -7,7 +7,7 @@ class LogTool: def setupLogger(self, loggerName: str, config: dict): - logFile = config.get('logging', {}).get('logfiles', {}).get(f'{loggerName.lower()}_logging_file', '/var/log/pyhss_diameter.log') + logFile = config.get('logging', {}).get('logfiles', {}).get(f'{loggerName.lower()}_logging_file', '/var/log/pyhss_general.log') logLevel = config.get('logging', {}).get('level', 'INFO') logger = logging.getLogger(loggerName) formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s {%(pathname)s:%(lineno)d} %(message)s", datefmt="%m/%d/%Y %H:%M:%S %Z") diff --git a/lib/messaging.py 
b/lib/messaging.py
index fa2e218..33975c7 100644
--- a/lib/messaging.py
+++ b/lib/messaging.py
@@ -1,4 +1,5 @@
 from redis import Redis
+import time, json
 
 class RedisMessaging:
     """
@@ -22,6 +23,33 @@ def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str:
         except Exception as e:
             return ''
 
+    def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: dict={}, metricTimestamp: int=None, metricExpiry: int=None) -> str:
+        """
+        Stores a Prometheus metric in a format readable by the metric service.
+        """
+        if not isinstance(metricValue, (int, float)):
+            return 'Invalid Argument: metricValue must be a number'
+        metricValue = float(metricValue)
+        prometheusMetricBody = json.dumps([{
+        'NAME': metricName,
+        'TYPE': metricType,
+        'HELP': metricHelp,
+        'LABELS': metricLabels,
+        'ACTION': metricAction,
+        'VALUE': metricValue,
+        }
+        ])
+
+        # Resolve the timestamp per call; a time.time_ns() default argument would be evaluated only once, at import.
+        metricTimestamp = metricTimestamp if metricTimestamp is not None else time.time_ns()
+        metricQueueName = f"metric-{serviceName}-{metricTimestamp}"
+
+        try:
+            self.redisClient.rpush(metricQueueName, prometheusMetricBody)
+            if metricExpiry is not None:
+                self.redisClient.expire(metricQueueName, metricExpiry)
+            return f'Successfully stored metric called: {metricName}, with value of: {metricValue}'
+        except Exception as e:
+            return ''
+
     def getMessage(self, queue: str) -> str:
         """
         Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed.
diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py
index 34fb25a..33314bd 100644
--- a/lib/messagingAsync.py
+++ b/lib/messagingAsync.py
@@ -1,5 +1,6 @@
 import asyncio
 import redis.asyncio as redis
+import time, json
 
 class RedisMessagingAsync:
     """
@@ -12,7 +13,7 @@ def __init__(self, host: str='localhost', port: int=6379):
 
     async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str:
         """
-        Stores a message in a given Queue (Key), and sets an expiry (in seconds) if provided.
+        Stores a message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided.
         """
         try:
             async with self.redisClient.pipeline(transaction=True) as redisPipe:
@@ -23,9 +24,37 @@ async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) ->
         except Exception as e:
             return ''
 
+    async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: dict={}, metricTimestamp: int=None, metricExpiry: int=None) -> str:
+        """
+        Stores a Prometheus metric in a format readable by the metric service, asynchronously.
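+
+        Illustrative call (argument values are examples mirroring the diameter service's request counter, not defaults):
+        await redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_request_count',
+                                        metricType='counter', metricAction='inc', metricValue=1.0, metricExpiry=60)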
+ """ + if not isinstance(metricValue, (int, float)): + return 'Invalid Argument: metricValue must be a digit' + metricValue = float(metricValue) + prometheusMetricBody = json.dumps([{ + 'NAME': metricName, + 'TYPE': metricType, + 'HELP': metricHelp, + 'LABELS': metricLabels, + 'ACTION': metricAction, + 'VALUE': metricValue, + } + ]) + + metricQueueName = f"metric-{serviceName}-{metricTimestamp}" + + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + sendMetricResult = await(redisPipe.rpush(metricQueueName, prometheusMetricBody).execute()) + if metricExpiry is not None: + expireKeyResult = await(redisPipe.expire(metricQueueName, metricExpiry).execute()) + return f'Succesfully stored metric called: {metricName}, with value of: {metricType}' + except Exception as e: + return '' + async def getMessage(self, queue: str) -> str: """ - Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed. + Gets the oldest message from a given Queue (Key) asynchronously, while removing it from the key as well. Deletes the key if the last message is being removed. """ try: async with self.redisClient.pipeline(transaction=True) as redisPipe: @@ -34,7 +63,10 @@ async def getMessage(self, queue: str) -> str: message = '' else: try: - message = message[0].decode() + if message[0] is None: + return '' + else: + message = message[0].decode() except (UnicodeDecodeError, AttributeError): pass return message @@ -44,7 +76,7 @@ async def getMessage(self, queue: str) -> str: async def getQueues(self, pattern: str='*') -> list: """ - Returns all Queues (Keys) in the database. + Returns all Queues (Keys) in the database, asynchronously. """ try: async with self.redisClient.pipeline(transaction=True) as redisPipe: @@ -53,19 +85,20 @@ async def getQueues(self, pattern: str='*') -> list: except Exception as e: return [] - async def getNextQueue(self, pattern: str='*') -> dict: + async def getNextQueue(self, pattern: str='*') -> str: """ - Returns the next Queue (Key) in the list. + Returns the next Queue (Key) in the list, asynchronously. """ try: async with self.redisClient.pipeline(transaction=True) as redisPipe: - return await(redisPipe.keys(pattern).execute())[1][0].decode() + nextQueue = await(redisPipe.keys(pattern).execute()) + return nextQueue[0][0].decode() except Exception as e: - return {} + return '' async def deleteQueue(self, queue: str) -> bool: """ - Deletes the given Queue (Key) + Deletes the given Queue (Key) asynchronously. """ try: async with self.redisClient.pipeline(transaction=True) as redisPipe: diff --git a/services/diameterService.py b/services/diameterService.py index d9c45e4..0ca2e63 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -1,7 +1,7 @@ import asyncio import sctp, socket import sys, os, json -import time, yaml +import time, yaml, uuid sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from diameter import Diameter @@ -12,6 +12,7 @@ class DiameterService: """ PyHSS Diameter Service A class for handling diameter requests and replies on Port 3868, via TCP or SCTP. + Functions in this class are high-performance, please edit with care. Last benchmarked on 23-08-2023. 
""" def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): @@ -19,7 +20,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): with open("../config.yaml", "r") as self.configFile: self.config = yaml.safe_load(self.configFile) except: - print(f"[Diameter] Fatal Error - config.yaml not found, exiting.") + print(f"[Diameter] [__init__] Fatal Error - config.yaml not found, exiting.") quit() self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) @@ -29,71 +30,86 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.diameterLogger = self.logTool.setupLogger(loggerName='Diameter', config=self.config) self.socketTimeout = int(self.config.get('hss', {}).get('client_socket_timeout', 300)) - def validateDiameterRequest(self, requestData) -> bool: + async def validateDiameterRequest(self, requestData) -> bool: + """ + Asynchronously validates a given diameter request, and increments the 'Number of Diameter Requests' metric. + """ try: packetVars, avps = self.diameterLibrary.decode_diameter_packet(requestData) originHost = self.diameterLibrary.get_avp_data(avps, 264)[0] originHost = bytes.fromhex(originHost).decode("utf-8") + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_request_count', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Diameter Requests', + metricLabels={ + "diameter_application_id": str(packetVars["ApplicationId"]), + "diameter_cmd_code": str(packetVars["command_code"]), + "endpoint": originHost, + "type": "request"}, + metricExpiry=60)) except Exception as e: return False return True - async def readRequestData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int) -> bool: - self.diameterLogger.info(f"[Diameter] New connection from {clientAddress} on port {clientPort}") - + async def readRequestData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: + """ + Reads and parses incoming data from a connected client. Terminates the connection if diameter traffic is not received. 
+ """ + self.diameterLogger.info(f"[Diameter] [readRequestData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}") while True: try: requestData = await asyncio.wait_for(reader.read(1024), timeout=socketTimeout) if len(requestData) > 0: - self.diameterLogger.debug(f"[Diameter] Received data from {clientAddress} on port {clientPort}") + self.diameterLogger.debug(f"[Diameter] [readRequestData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}") - if not self.validateDiameterRequest(requestData): - self.diameterLogger.debug(f"[Diameter] Invalid Diameter Request, terminating connection.") + if not await(self.validateDiameterRequest(requestData)): + self.diameterLogger.debug(f"[Diameter] [readRequestData] [{coroutineUuid}] Invalid Diameter Request, terminating connection.") return False requestQueueName = f"diameter-request-{clientAddress}-{clientPort}-{time.time_ns()}" requestHexString = json.dumps({f"diameter-request": requestData.hex()}) - self.diameterLogger.debug(f"[Diameter] Queueing {requestHexString}") - await(self.redisMessaging.sendMessage(queue=requestQueueName, message=requestHexString)) + self.diameterLogger.debug(f"[Diameter] [readRequestData] [{coroutineUuid}] Queueing {requestHexString}") + asyncio.ensure_future(self.redisMessaging.sendMessage(queue=requestQueueName, message=requestHexString, queueExpiry=60)) except asyncio.TimeoutError: - self.diameterLogger.info(f"[Diameter] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.") + self.diameterLogger.info(f"[Diameter] [readRequestData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.") return False - async def writeResponseData(self, writer, clientAddress: str, clientPort: str) -> bool: - self.diameterLogger.debug(f"[Diameter] writeResponseData with host {clientAddress} on port {clientPort}") + async def writeResponseData(self, writer, clientAddress: str, clientPort: str, coroutineUuid: str) -> bool: + self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] writeResponseData with host {clientAddress} on port {clientPort}") while True: try: pendingResponseQueues = await(self.redisMessaging.getQueues()) if not len(pendingResponseQueues) > 0: assert() for responseQueue in pendingResponseQueues: - queuedMessageType = str(responseQueue).split('-')[1] - diameterResponseHost = str(responseQueue).split('-')[2] - diameterResponsePort = str(responseQueue).split('-')[3] + responseQueueSplit = str(responseQueue).split('-') + queuedMessageType = responseQueueSplit[1] + diameterResponseHost = responseQueueSplit[2] + diameterResponsePort = responseQueueSplit[3] if str(diameterResponseHost) == str(clientAddress) and str(diameterResponsePort) == str(clientPort) and queuedMessageType == 'response': - self.diameterLogger.debug(f"[Diameter] Matched {responseQueue} to host {clientAddress} on port {clientPort}") + self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] Matched {responseQueue} to host {clientAddress} on port {clientPort}") try: diameterResponse = json.loads(await(self.redisMessaging.getMessage(queue=responseQueue))) - self.diameterLogger.debug(f"[Diameter] Attempting to send outbound response to {clientAddress} on {clientPort}.") + self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] Attempting to send outbound response to {clientAddress} on {clientPort}.") diameterResponseBinary = bytes.fromhex(next(iter(diameterResponse.values()))) - 
self.diameterLogger.debug(f"[Diameter] Sending: {diameterResponseBinary.hex()} to to {clientAddress} on {clientPort}.") + self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] Sending: {diameterResponseBinary.hex()} to to {clientAddress} on {clientPort}.") writer.write(diameterResponseBinary) await writer.drain() except Exception as e: print(e) except ConnectionError: - self.diameterLogger.info(f"[Diameter] Connection closed for {clientAddress} on port {clientPort}, closing writer.") + self.diameterLogger.info(f"[Diameter] [writeResponseData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.") return False except Exception as e: - await asyncio.sleep(0.005) continue async def handleConnection(self, reader, writer): (clientAddress, clientPort) = writer.get_extra_info('peername') self.diameterLogger.debug(f"[Diameter] Initial Connection from: {clientAddress} on port {clientPort}") + coroutineUuid = uuid.uuid4() - if False in await asyncio.gather(self.readRequestData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout), - self.writeResponseData(writer=writer, clientAddress=clientAddress, clientPort=clientPort)): + if False in await asyncio.gather(self.readRequestData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid), + self.writeResponseData(writer=writer, clientAddress=clientAddress, clientPort=clientPort, coroutineUuid=coroutineUuid)): self.diameterLogger.debug(f"[Diameter] Closing Writer for {clientAddress} on port {clientPort}.") writer.close() await writer.wait_closed() diff --git a/services/georedService.py b/services/georedService.py index e69de29..9217a1f 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -0,0 +1,139 @@ +import os, sys, json, yaml +import requests, uuid +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from banners import Banners +from logtool import LogTool + +class GeoredService: + + def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[Geored] Fatal Error - config.yaml not found, exiting.") + quit() + self.logTool = LogTool() + self.banners = Banners() + self.georedLogger = self.logTool.setupLogger(loggerName='Geored', config=self.config) + self.georedLogger.info(self.banners.georedService()) + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.remotePeers = self.config.get('geored', {}).get('sync_endpoints', []) + if not self.config.get('geored', {}).get('enabled'): + self.logger.error("[Geored] Fatal Error - geored not enabled under geored.enabled, exiting.") + quit() + if not (len(self.remotePeers) > 0): + self.logger.error("[Geored] Fatal Error - no peers defined under geored.sync_endpoints, exiting.") + quit() + + def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + operation = operation.upper() + requestOperations = {'GET': requests.get, 'PUT': requests.put, 'POST': requests.post, 'PATCH':requests.patch, 'DELETE': requests.delete} + + if not url or not operation or not body: + return False + + if operation not in requestOperations: + return False + + headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)} + + for attempt in range(retryCount): + try: 
+ if operation in ['PUT', 'POST', 'PATCH']: + response = requestOperations[operation](url, json=body, headers=headers) + else: + response = requestOperations[operation](url, headers=headers) + if 200 <= response.status_code <= 299: + self.georedLogger.debug(f"[Geored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}") + + self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": str(response.status_code), + "error": ""}, + metricExpiry=60) + break + else: + self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": str(response.status_code), + "error": str(response.reason)}, + metricExpiry=60) + except requests.exceptions.ConnectionError as e: + error_message = str(e) + self.georedLogger.warning(f"[Geored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}") + if "Name or service not known" in error_message: + self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": "No matching DNS entry found"}, + metricExpiry=60) + else: + self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": "Connection Refused"}, + metricExpiry=60) + except requests.exceptions.Timeout: + self.georedLogger.warning(f"[Geored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}") + self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": "Timeout"}, + metricExpiry=60) + except Exception as e: + self.georedLogger.error(f"[Geored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. 
Error Message: {e}") + self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": "000", + "error": e}, + metricExpiry=60) + return True + + def handleGeoredQueue(self): + try: + georedQueue = self.redisMessaging.getNextQueue(pattern='geored-*') + georedMessage = self.redisMessaging.getMessage(queue=georedQueue) + assert(len(georedMessage)) + self.georedLogger.debug(f"[Geored] Queue: {georedQueue}") + self.georedLogger.debug(f"[Geored] Message: {georedMessage}") + + georedDict = json.loads(georedMessage) + georedOperation = georedDict['operation'] + georedBody = georedDict['body'] + + for remotePeer in self.remotePeers: + self.sendGeored(url=remotePeer+'/geored/', operation=georedOperation, body=georedBody) + + except Exception as e: + return False + +if __name__ == '__main__': + georedService = GeoredService() + while True: + georedService.handleGeoredQueue() \ No newline at end of file diff --git a/services/hssService.py b/services/hssService.py index 7fdcfc9..368680b 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -1,7 +1,7 @@ import os, sys, json, yaml -import time, logging +import time, asyncio sys.path.append(os.path.realpath('../lib')) -from messaging import RedisMessaging +from messagingAsync import RedisMessagingAsync from diameter import Diameter from banners import Banners from logtool import LogTool @@ -16,7 +16,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): except: print(f"[HSS] Fatal Error - config.yaml not found, exiting.") quit() - self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) self.logTool = LogTool() self.banners = Banners() self.hssLogger = self.logTool.setupLogger(loggerName='HSS', config=self.config) @@ -28,44 +28,45 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.diameterLibrary = Diameter(originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.hssLogger.info(self.banners.hssService()) + async def handleRequestQueue(self): + while True: + try: + requestQueue = await(self.redisMessaging.getNextQueue(pattern='diameter-request*')) + requestMessage = await(self.redisMessaging.getMessage(queue=requestQueue)) + assert(len(requestMessage)) + self.hssLogger.debug(f"[HSS] Inbound Diameter Request Queue: {requestQueue}") + self.hssLogger.debug(f"[HSS] Inbound Diameter Request: {requestMessage}") + requestDict = json.loads(requestMessage) + requestBinary = bytes.fromhex(next(iter(requestDict.values()))) + requestSplit = str(requestQueue).split('-') + requestHost = requestSplit[2] + requestPort = requestSplit[3] + requestTimestamp = requestSplit[4] - def handleOutboundResponse(self, queue: str, diameterResponse: str): - self.redisMessaging.sendMessage(queue=queue, message=diameterResponse, queueExpiry=60) + try: + diameterResponse = self.diameterLibrary.generateDiameterResponse(requestBinaryData=requestBinary) + except Exception as e: + self.hssLogger.warn(f"[HSS] Failed to generate diameter response: {e}") + continue + + self.hssLogger.debug(f"[HSS] Generated Diameter Response: {diameterResponse}") + if not len(diameterResponse) > 0: + continue + + outboundResponseQueue = 
f"diameter-response-{requestHost}-{requestPort}-{requestTimestamp}" + outboundResponse = json.dumps({"diameter-response": diameterResponse}) - def handleRequestQueue(self): - try: - requestQueue = self.redisMessaging.getNextQueue(pattern='diameter-request*') - requestMessage = self.redisMessaging.getMessage(queue=requestQueue) - assert(len(requestMessage)) - self.hssLogger.debug(f"[HSS] Inbound Diameter Request Queue: {requestQueue}") - self.hssLogger.debug(f"[HSS] Inbound Diameter Request: {requestMessage}") - - requestDict = json.loads(requestMessage) - requestBinary = bytes.fromhex(next(iter(requestDict.values()))) - requestHost = str(requestQueue).split('-')[2] - requestPort = str(requestQueue).split('-')[3] - requestTimestamp = str(requestQueue).split('-')[4] - - diameterResponse = self.diameterLibrary.generateDiameterResponse(requestBinaryData=requestBinary) - self.hssLogger.debug(f"[HSS] Generated Diameter Response: {diameterResponse}") - if not len(diameterResponse) > 0: - return False - - outboundResponseQueue = f"diameter-response-{requestHost}-{requestPort}-{requestTimestamp}" - outboundResponse = json.dumps({"diameter-response": diameterResponse}) - - self.hssLogger.debug(f"[HSS] Outbound Diameter Response Queue: {outboundResponseQueue}") - self.hssLogger.debug(f"[HSS] Outbound Diameter Response: {outboundResponse}") + self.hssLogger.debug(f"[HSS] Outbound Diameter Response Queue: {outboundResponseQueue}") + self.hssLogger.debug(f"[HSS] Outbound Diameter Response: {outboundResponse}") - self.handleOutboundResponse(queue=outboundResponseQueue, diameterResponse=outboundResponse) - time.sleep(0.005) + asyncio.ensure_future(self.redisMessaging.sendMessage(queue=outboundResponseQueue, message=outboundResponse, queueExpiry=60)) - except Exception as e: - return False + except Exception as e: + continue + if __name__ == '__main__': hssService = HssService() - while True: - hssService.handleRequestQueue() \ No newline at end of file + asyncio.run(hssService.handleRequestQueue()) \ No newline at end of file diff --git a/services/metricService.py b/services/metricService.py new file mode 100644 index 0000000..f0ea88f --- /dev/null +++ b/services/metricService.py @@ -0,0 +1,100 @@ +import asyncio +import sys, os, json +import time, json, yaml +from prometheus_client import make_wsgi_app, start_http_server, Counter, Gauge, Summary, Histogram, CollectorRegistry +from werkzeug.middleware.dispatcher import DispatcherMiddleware +from flask import Flask +import threading +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from banners import Banners +from logtool import LogTool + +class MetricService: + + def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[HSS] Fatal Error - config.yaml not found, exiting.") + quit() + + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.banners = Banners() + self.logTool = LogTool() + self.registry = CollectorRegistry(auto_describe=True) + self.metricLogger = self.logTool.setupLogger(loggerName='Metric', config=self.config) + self.metricLogger.info(self.banners.metricService()) + + + def handleMetrics(self): + try: + actions = {'inc': 'inc', 'dec': 'dec', 'set':'set'} + prometheusTypes = {'counter': Counter, 'gauge': Gauge, 'histogram': Histogram, 'summary': Summary} + + metricQueue = self.redisMessaging.getNextQueue(pattern='metric-*') + metric = 
self.redisMessaging.getMessage(queue=metricQueue) + if not (len(metric) > 0): + return + self.metricLogger.info(f"Received Metric: {metric}") + prometheusJsonList = json.loads(metric) + for prometheusJson in prometheusJsonList: + self.metricLogger.debug(prometheusJson) + if not all(key in prometheusJson for key in ('NAME', 'TYPE', 'ACTION', 'VALUE')): + raise ValueError('All fields are not available for parsing') + counterName = prometheusJson['NAME'] + counterType = prometheusTypes.get(prometheusJson['TYPE'].lower()) + counterAction = prometheusJson['ACTION'].lower() + counterValue = float(prometheusJson['VALUE']) + counterHelp = prometheusJson.get('HELP', '') + counterLabels = prometheusJson.get('LABELS', {}) + + if isinstance(counterLabels, list): + counterLabels = dict() + + if counterType is not None: + try: + counterRecord = counterType(counterName, counterHelp, labelnames=counterLabels.keys(), registry=self.registry) + if counterLabels: + counterRecord = counterRecord.labels(*counterLabels.values()) + except ValueError as e: + counterRecord = self.registry._names_to_collectors.get(counterName) + if counterLabels and counterRecord: + counterRecord = counterRecord.labels(*counterLabels.values()) + action = actions.get(counterAction) + if action is not None: + # Here we dynamically lookup the class from prometheus_client, and grab the matched function name called 'action'. + prometheusMethod = getattr(counterRecord, action) + prometheusMethod(counterValue) + else: + self.metricLogger.debug(f"Invalid action `{counterAction}` in message: {metric}, skipping.") + continue + else: + self.metricLogger.debug(f"Invalid type `{counterType}` in message: {metric}, skipping.") + continue + + except Exception as e: + self.metricLogger.error(f"Unable to parse message: {metric}, due to {e}. Skipping.") + return + + + def getMetrics(self): + while True: + self.handleMetrics() + + +if __name__ == '__main__': + + metricService = MetricService() + metricServiceThread = threading.Thread(target=metricService.getMetrics) + metricServiceThread.start() + + prometheusWebClient = Flask(__name__) + prometheusWebClient.wsgi_app = DispatcherMiddleware(prometheusWebClient.wsgi_app, { + '/metrics': make_wsgi_app(registry=metricService.registry) + }) + + #Uncomment the statement below to run a local testing instance. 
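+ # (The run() call below is live as written, despite the comment above.
+ # Since start_http_server is already imported from prometheus_client,
+ # start_http_server(9191, registry=metricService.registry) would serve
+ # /metrics directly, without the Flask/Werkzeug dispatcher.)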
+ + prometheusWebClient.run(host='0.0.0.0', port=9191) \ No newline at end of file diff --git a/services/prometheusService.py b/services/prometheusService.py deleted file mode 100644 index e69de29..0000000 From b606aca8b77368bc6873649c054a4f064b668cbf Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Thu, 24 Aug 2023 14:57:34 +1000 Subject: [PATCH 04/43] Add diameterAsync --- lib/diameter.py | 70 +++++++----- lib/diameterAsync.py | 215 ++++++++++++++++++++++++++++++++++++ lib/logtool.py | 27 +++-- lib/messagingAsync.py | 29 +++-- services/diameterService.py | 201 +++++++++++++++++++++------------ services/hssService.py | 57 +++++----- 6 files changed, 448 insertions(+), 151 deletions(-) create mode 100644 lib/diameterAsync.py diff --git a/lib/diameter.py b/lib/diameter.py index 1be7c4c..0be72ae 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -13,12 +13,7 @@ import traceback import database import yaml - -# #Setup Logging -# import logtool -# from logtool import * -# logtool = logtool.LogTool() -# logtool.setup_logger('DiameterLogger', self.yaml_config['logging']['logfiles']['diameter_logging_file'], level=self.yaml_config['logging']['level']) +from typing import Literal class Diameter: @@ -36,6 +31,25 @@ def __init__(self, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999. self.diameterLibLogger.info("Initialized Diameter for " + str(self.OriginHost) + " at Realm " + str(self.OriginRealm) + " serving as Product Name " + str(self.ProductName)) self.diameterLibLogger.info("PLMN is " + str(self.MCC) + "/" + str(self.MNC)) + self.diameterCommandList = [ + {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, + {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, + {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, + {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, + {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, + {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": 
"NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, + {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, + {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"}, + {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, + {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 ,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, + {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, + {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, + {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, + {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, + ] + #Generates rounding for calculating padding def myround(self, n, base=4): if(n > 0): @@ -272,6 +286,7 @@ def generate_diameter_packet(self, packet_version, packet_flags, packet_command_ packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp return packet_hex + def decode_diameter_packet(self, data): packet_vars = {} avps = [] @@ -301,6 +316,7 @@ def decode_diameter_packet(self, data): pass return packet_vars, avps + def decode_avp_packet(self, data): if len(data) <= 8: @@ -367,6 +383,7 @@ def get_avp_data(self, avps, avp_code): #Loops through list of dic misc_data.append(keys['misc_data']) return misc_data + def decode_diameter_packet_length(self, data): packet_vars = {} data = data.hex() @@ -377,31 +394,28 @@ def decode_diameter_packet_length(self, data): else: return False - def generateDiameterResponse(self, requestBinaryData: str) -> str: - packet_vars, avps = self.decode_diameter_packet(requestBinaryData) + def getDiameterMessageType(self, binaryData: str) -> dict: + packet_vars, avps = self.decode_diameter_packet(binaryData) + response = {} + + for diameterApplication in self.diameterCommandList: + try: + assert(packet_vars["command_code"] == diameterApplication["commandCode"]) + assert(packet_vars["ApplicationId"] == 
diameterApplication["applicationId"]) + response['inbound'] = diameterApplication["requestAcronym"] + response['outbound'] = diameterApplication["responseAcronym"] + self.diameterLibLogger.debug(f"[diameter.py] Successfully generated response: {response}") + except Exception as e: + continue + + return response + + def generateDiameterResponse(self, binaryData: str) -> str: + packet_vars, avps = self.decode_diameter_packet(binaryData) origin_host = self.get_avp_data(avps, 264)[0] origin_host = binascii.unhexlify(origin_host).decode("utf-8") response = '' - diameterList = [ - {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, - {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, - {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, - {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, - {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, - {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, - {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, - {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, - {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, - {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"}, - {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, - {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 
,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, - {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, - {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, - {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, - {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, - ] - self.diameterLibLogger.debug(f"Generating a diameter response") # Drop packet if it's a response packet: @@ -410,7 +424,7 @@ def generateDiameterResponse(self, requestBinaryData: str) -> str: self.diameterLibLogger.debug(packet_vars) return - for diameterApplication in diameterList: + for diameterApplication in self.diameterCommandList: try: assert(packet_vars["command_code"] == diameterApplication["commandCode"]) assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py new file mode 100644 index 0000000..e9fd657 --- /dev/null +++ b/lib/diameterAsync.py @@ -0,0 +1,215 @@ +#Diameter Packet Decoder / Encoder & Tools +import math +import asyncio + +class DiameterAsync: + + def __init__(self, logger): + self.diameterCommandList = [ + {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, + {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, + {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, + {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", 
"responseName": "Update Location Answer"}, + {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, + {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, + {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, + {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"}, + {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, + {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 ,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, + {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, + {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, + {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, + {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, + ] + + self.logger = logger + + #Generates rounding for calculating padding + async def myRoundAsync(self, n, base=4): + if(n > 0): + return math.ceil(n/4.0) * 4 + elif( n < 0): + return math.floor(n/4.0) * 4 + else: + return 4 + + async def getAvpDataAsync(self, avps, avp_code): + #Loops through list of dicts generated by the packet decoder, and returns the data for a specific AVP code in list (May be more than one AVP with same code but different data) + misc_data = [] + for keys in avps: + if keys['avp_code'] == avp_code: + misc_data.append(keys['misc_data']) + return misc_data + + async def decodeDiameterPacketAsync(self, data): + packet_vars = {} + avps = [] + + if type(data) is bytes: + data = data.hex() + + packet_vars['packet_version'] = data[0:2] + packet_vars['length'] = int(data[2:8], 16) + packet_vars['flags'] = data[8:10] + packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8) + 
packet_vars['command_code'] = int(data[10:16], 16) + packet_vars['ApplicationId'] = int(data[16:24], 16) + packet_vars['hop-by-hop-identifier'] = data[24:32] + packet_vars['end-to-end-identifier'] = data[32:40] + + avp_sum = data[40:] + + avp_vars, remaining_avps = await(self.decodeAvpPacketAsync(avp_sum)) + avps.append(avp_vars) + + while len(remaining_avps) > 0: + avp_vars, remaining_avps = await(self.decodeAvpPacketAsync(remaining_avps)) + avps.append(avp_vars) + else: + pass + return packet_vars, avps + + async def decodeAvpPacketAsync(self, data): + + if len(data) <= 8: + #if length is less than 8 it is too short to be an AVP and is most likely the data from the last AVP being attempted to be parsed as another AVP + raise ValueError("Length of data is too short to be valid AVP") + + avp_vars = {} + avp_vars['avp_code'] = int(data[0:8], 16) + + avp_vars['avp_flags'] = data[8:10] + avp_vars['avp_length'] = int(data[10:16], 16) + if avp_vars['avp_flags'] == "c0": + #If c0 is present AVP is Vendor AVP + avp_vars['vendor_id'] = int(data[16:24], 16) + avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] + else: + #if is not a vendor AVP + avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] + + if avp_vars['avp_length'] % 4 == 0: + #Multiple of 4 - No Padding needed + avp_vars['padding'] = 0 + else: + #Not multiple of 4 - Padding needed + rounded_value = await(self.myRoundAsync(avp_vars['avp_length'])) + avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 + avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] + + + #If body of avp_vars['misc_data'] contains AVPs, then decode each of them as a list of dicts like avp_vars['misc_data'] = [avp_vars, avp_vars] + try: + sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacketAsync(avp_vars['misc_data'])) + #Sanity check - If the avp code is greater than 9999 it's probably not an AVP after all... 
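+ # (Grouped AVPs are detected opportunistically: every misc_data payload
+ # is fed back through decodeAvpPacketAsync, and the 9999 cap is a
+ # heuristic guard, since a payload that decodes to an implausibly large
+ # AVP code was almost certainly plain data rather than a nested AVP.)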
+ if int(sub_avp_vars['avp_code']) > 9999: + pass + else: + #If the decoded AVP is valid store it + avp_vars['misc_data'] = [] + avp_vars['misc_data'].append(sub_avp_vars) + #While there are more AVPs to be decoded, decode them: + while len(sub_remaining_avps) > 0: + sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacketAsync(sub_remaining_avps)) + avp_vars['misc_data'].append(sub_avp_vars) + + except Exception as e: + if str(e) == "invalid literal for int() with base 16: ''": + pass + elif str(e) == "Length of data is too short to be valid AVP": + pass + else: + self.logger.warn("[Diameter] [decodeAvpPacketAsync] failed to decode sub-avp - error: " + str(e)) + pass + + remaining_avps = data[(avp_vars['avp_length']*2)+avp_vars['padding']:] #returns remaining data in avp string back for processing again + return avp_vars, remaining_avps + + + async def getDiameterMessageTypeAsync(self, binaryData: str) -> dict: + packet_vars, avps = await(self.decodeDiameterPacketAsync(binaryData)) + response = {} + + for diameterApplication in self.diameterCommandList: + try: + assert(packet_vars["command_code"] == diameterApplication["commandCode"]) + assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) + response['inbound'] = diameterApplication["requestAcronym"] + response['outbound'] = diameterApplication["responseAcronym"] + self.logger.debug(f"[Diameter] [getDiameterMessageTypeAsync] Successfully got message type: {response}") + except Exception as e: + continue + + return response + + async def generateDiameterResponseAsync(self, binaryData: str) -> str: + packet_vars, avps = await(self.decodeDiameterPacketAsync(binaryData)) + response = '' + + # Drop packet if it's a response packet: + if packet_vars["flags_bin"][0:1] == "0": + self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Got a Response, not a request - dropping it: {packet_vars}") + return + + for diameterApplication in self.diameterCommandList: + try: + assert(packet_vars["command_code"] == diameterApplication["commandCode"]) + assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) + if 'flags' in diameterApplication: + assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) + response = diameterApplication["responseMethod"](packet_vars, avps) + self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Successfully generated response: {response}") + except Exception as e: + continue + + return response + + async def Answer_257(self): + pass + + async def Answer_16777238_272(self): + pass + + async def Answer_280(self): + pass + + async def Answer_282(self): + pass + + async def Answer_16777251_318(self): + pass + + async def Answer_16777251_316(self): + pass + + async def Answer_16777251_321(self): + pass + + async def Answer_16777251_323(self): + pass + + async def Answer_16777216_300(self): + pass + + async def Answer_16777216_301(self): + pass + + async def Answer_16777216_302(self): + pass + + async def Answer_16777216_303(self): + pass + + async def Answer_16777217_306(self): + pass + + async def Answer_16777217_307(self): + pass + + async def Answer_16777252_324(self): + pass + + async def Answer_16777291_8388622(self): + pass \ No newline at end of file diff --git a/lib/logtool.py b/lib/logtool.py index 20cb414..ec53ce6 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -1,29 +1,28 @@ import logging import logging.handlers as handlers -import os -import sys +import os, sys sys.path.append(os.path.realpath('../')) class LogTool: def setupLogger(self, 
loggerName: str, config: dict): - logFile = config.get('logging', {}).get('logfiles', {}).get(f'{loggerName.lower()}_logging_file', '/var/log/pyhss_general.log') + # logFile = config.get('logging', {}).get('logfiles', {}).get(f'{loggerName.lower()}_logging_file', '/var/log/pyhss_general.log') logLevel = config.get('logging', {}).get('level', 'INFO') logger = logging.getLogger(loggerName) formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s {%(pathname)s:%(lineno)d} %(message)s", datefmt="%m/%d/%Y %H:%M:%S %Z") - try: - rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) - except PermissionError: - logFileName = logFile.split('/')[-1] - pyhssRootDir = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) - print(f"[LogTool] Warning - Unable to write to {logFile}, using {pyhssRootDir}/log/{logFileName} instead.") - logFile = f"{pyhssRootDir}/log/{logFileName}" - rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) - pass + # try: + # rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) + # except PermissionError: + # logFileName = logFile.split('/')[-1] + # pyhssRootDir = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + # print(f"[LogTool] Warning - Unable to write to {logFile}, using {pyhssRootDir}/log/{logFileName} instead.") + # logFile = f"{pyhssRootDir}/log/{logFileName}" + # rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) + # pass streamHandler = logging.StreamHandler() streamHandler.setFormatter(formatter) - rolloverHandler.setFormatter(formatter) + # rolloverHandler.setFormatter(formatter) logger.setLevel(logLevel) logger.addHandler(streamHandler) - logger.addHandler(rolloverHandler) + # logger.addHandler(rolloverHandler) return logger \ No newline at end of file diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 33314bd..7aceafe 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -17,9 +17,10 @@ async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> """ try: async with self.redisClient.pipeline(transaction=True) as redisPipe: - sendMessageResult = await(redisPipe.rpush(queue, message).execute()) + await redisPipe.rpush(queue, message) if queueExpiry is not None: - expireKeyResult = await(redisPipe.expire(queue, queueExpiry).execute()) + await redisPipe.expire(queue, queueExpiry) + sendMessageResult, expireKeyResult = await redisPipe.execute() return f'{message} stored in {queue} successfully.' except Exception as e: return '' @@ -45,9 +46,10 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m try: async with self.redisClient.pipeline(transaction=True) as redisPipe: - sendMetricResult = await(redisPipe.rpush(metricQueueName, prometheusMetricBody).execute()) - if metricExpiry is not None: - expireKeyResult = await(redisPipe.expire(metricQueueName, metricExpiry).execute()) + await(redisPipe.rpush(metricQueueName, prometheusMetricBody).execute()) + if metricExpiry is not None: + await(redisPipe.expire(metricQueueName, metricExpiry).execute()) + sendMetricResult, expireKeyResult = await redisPipe.execute() return f'Succesfully stored metric called: {metricName}, with value of: {metricType}' except Exception as e: return '' @@ -57,8 +59,7 @@ async def getMessage(self, queue: str) -> str: Gets the oldest message from a given Queue (Key) asynchronously, while removing it from the key as well. 
Deletes the key if the last message is being removed. """ try: - async with self.redisClient.pipeline(transaction=True) as redisPipe: - message = await(redisPipe.lpop(queue).execute()) + message = await(self.redisClient.lpop(queue)) if message is None: message = '' else: @@ -79,9 +80,9 @@ async def getQueues(self, pattern: str='*') -> list: Returns all Queues (Keys) in the database, asynchronously. """ try: - async with self.redisClient.pipeline(transaction=True) as redisPipe: - allQueues = await(redisPipe.keys(pattern).execute()) - return [x.decode() for x in allQueues[0]] + allQueuesBinary = await(self.redisClient.keys(pattern)) + allQueues = [x.decode() for x in allQueuesBinary] + return allQueues except Exception as e: return [] @@ -90,9 +91,8 @@ async def getNextQueue(self, pattern: str='*') -> str: Returns the next Queue (Key) in the list, asynchronously. """ try: - async with self.redisClient.pipeline(transaction=True) as redisPipe: - nextQueue = await(redisPipe.keys(pattern).execute()) - return nextQueue[0][0].decode() + nextQueue = await(self.redisClient.keys(pattern)) + return nextQueue[0].decode() except Exception as e: return '' @@ -101,8 +101,7 @@ async def deleteQueue(self, queue: str) -> bool: Deletes the given Queue (Key) asynchronously. """ try: - async with self.redisClient.pipeline(transaction=True) as redisPipe: - await(redisPipe.delete(queue).execute()) + deleteQueueResult = await(self.redisClient.delete(queue)) return True except Exception as e: return False diff --git a/services/diameterService.py b/services/diameterService.py index 0ca2e63..1a167dc 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -1,18 +1,17 @@ import asyncio -import sctp, socket import sys, os, json import time, yaml, uuid sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync -from diameter import Diameter +from diameterAsync import DiameterAsync from banners import Banners from logtool import LogTool class DiameterService: """ PyHSS Diameter Service - A class for handling diameter requests and replies on Port 3868, via TCP or SCTP. - Functions in this class are high-performance, please edit with care. Last benchmarked on 23-08-2023. + A class for handling diameter inbounds and replies on Port 3868, via TCP. + Functions in this class are high-performance, please edit with care. Last benchmarked on 24-08-2023. """ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): @@ -24,114 +23,180 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): quit() self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) - self.diameterLibrary = Diameter() self.banners = Banners() self.logTool = LogTool() self.diameterLogger = self.logTool.setupLogger(loggerName='Diameter', config=self.config) - self.socketTimeout = int(self.config.get('hss', {}).get('client_socket_timeout', 300)) + self.diameterLibrary = DiameterAsync(logger=self.diameterLogger) + self.activeConnections = set() - async def validateDiameterRequest(self, requestData) -> bool: + async def validateDiameterInbound(self, inboundData) -> bool: """ - Asynchronously validates a given diameter request, and increments the 'Number of Diameter Requests' metric. + Asynchronously validates a given diameter inbound, and increments the 'Number of Diameter Inbounds' metric. 
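The metric push in this method is deliberately fire-and-forget: asyncio.ensure_future schedules sendMetric without adding Redis latency to the validation hot path. The one caveat with that pattern is that the event loop keeps only weak references to tasks, so a scheduled task can in principle be garbage-collected before it completes; the asyncio docs recommend holding a strong reference, sketched below with a hypothetical fireAndForget helper:

    import asyncio

    backgroundTasks = set()

    def fireAndForget(coroutine) -> asyncio.Task:
        # Hold a strong reference until completion so the task cannot be
        # garbage-collected mid-flight, then let the set drop it.
        task = asyncio.ensure_future(coroutine)
        backgroundTasks.add(task)
        task.add_done_callback(backgroundTasks.discard)
        return task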
""" try: - packetVars, avps = self.diameterLibrary.decode_diameter_packet(requestData) - originHost = self.diameterLibrary.get_avp_data(avps, 264)[0] + packetVars, avps = await(self.diameterLibrary.decodeDiameterPacketAsync(inboundData)) + originHost = (await self.diameterLibrary.getAvpDataAsync(avps, 264))[0] originHost = bytes.fromhex(originHost).decode("utf-8") - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_request_count', + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_inbound_count', metricType='counter', metricAction='inc', - metricValue=1.0, metricHelp='Number of Diameter Requests', + metricValue=1.0, metricHelp='Number of Diameter Inbounds', metricLabels={ "diameter_application_id": str(packetVars["ApplicationId"]), "diameter_cmd_code": str(packetVars["command_code"]), "endpoint": originHost, - "type": "request"}, + "type": "inbound"}, metricExpiry=60)) except Exception as e: + print(e) return False return True - async def readRequestData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: + async def logActiveConnections(self): + """ + Logs the number of active connections on a rolling basis. + """ + while True: + activeConnections = self.activeConnections + if not len(activeConnections) > 0: + activeConnections = '' + self.diameterLogger.info(f"[Diameter] [logActiveConnections] {len(self.activeConnections)} Active Connections {activeConnections}") + await(asyncio.sleep(60)) + + async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ - Reads and parses incoming data from a connected client. Terminates the connection if diameter traffic is not received. + Reads and parses incoming data from a connected client. Validated diameter messages are sent to the redis queue for processing. + Terminates the connection if diameter traffic is not received, or if the client disconnects. 
""" - self.diameterLogger.info(f"[Diameter] [readRequestData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}") + self.diameterLogger.info(f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}") while True: try: - requestData = await asyncio.wait_for(reader.read(1024), timeout=socketTimeout) - if len(requestData) > 0: - self.diameterLogger.debug(f"[Diameter] [readRequestData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}") + + inboundData = await(asyncio.wait_for(reader.read(1024), timeout=socketTimeout)) + + if reader.at_eof(): + return False + + if len(inboundData) > 0: + self.diameterLogger.debug(f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}") - if not await(self.validateDiameterRequest(requestData)): - self.diameterLogger.debug(f"[Diameter] [readRequestData] [{coroutineUuid}] Invalid Diameter Request, terminating connection.") + if not await(self.validateDiameterInbound(inboundData)): + self.diameterLogger.debug(f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, terminating connection.") return False + + diameterMessageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(binaryData=inboundData)) + diameterMessageType = diameterMessageType.get('inbound', '') - requestQueueName = f"diameter-request-{clientAddress}-{clientPort}-{time.time_ns()}" - requestHexString = json.dumps({f"diameter-request": requestData.hex()}) - self.diameterLogger.debug(f"[Diameter] [readRequestData] [{coroutineUuid}] Queueing {requestHexString}") - asyncio.ensure_future(self.redisMessaging.sendMessage(queue=requestQueueName, message=requestHexString, queueExpiry=60)) - except asyncio.TimeoutError: - self.diameterLogger.info(f"[Diameter] [readRequestData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.") + inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{time.time_ns()}" + inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()}) + self.diameterLogger.debug(f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}") + asyncio.ensure_future(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=60)) + + except Exception as e: + self.diameterLogger.info(f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.") + self.diameterLogger.debug(e) return False - async def writeResponseData(self, writer, clientAddress: str, clientPort: str, coroutineUuid: str) -> bool: - self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] writeResponseData with host {clientAddress} on port {clientPort}") + async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: + """ + Continually polls the Redis queue for outbound messages. Received messages from the queue are validated against the connected client, and sent. 
+ """ + self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}") while True: try: - pendingResponseQueues = await(self.redisMessaging.getQueues()) - if not len(pendingResponseQueues) > 0: - assert() - for responseQueue in pendingResponseQueues: - responseQueueSplit = str(responseQueue).split('-') - queuedMessageType = responseQueueSplit[1] - diameterResponseHost = responseQueueSplit[2] - diameterResponsePort = responseQueueSplit[3] - if str(diameterResponseHost) == str(clientAddress) and str(diameterResponsePort) == str(clientPort) and queuedMessageType == 'response': - self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] Matched {responseQueue} to host {clientAddress} on port {clientPort}") - try: - diameterResponse = json.loads(await(self.redisMessaging.getMessage(queue=responseQueue))) - self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] Attempting to send outbound response to {clientAddress} on {clientPort}.") - diameterResponseBinary = bytes.fromhex(next(iter(diameterResponse.values()))) - self.diameterLogger.debug(f"[Diameter] [writeResponseData] [{coroutineUuid}] Sending: {diameterResponseBinary.hex()} to to {clientAddress} on {clientPort}.") - writer.write(diameterResponseBinary) - await writer.drain() - except Exception as e: - print(e) - except ConnectionError: - self.diameterLogger.info(f"[Diameter] [writeResponseData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.") + + if writer.transport.is_closing(): + return False + + pendingOutboundQueues = await(self.redisMessaging.getQueues(pattern='diameter-outbound*')) + if not len(pendingOutboundQueues) > 0: + await(asyncio.sleep(0)) + continue + + self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queues: {pendingOutboundQueues}") + for outboundQueue in pendingOutboundQueues: + outboundQueueSplit = str(outboundQueue).split('-') + queuedMessageType = outboundQueueSplit[1] + diameterOutboundHost = outboundQueueSplit[2] + diameterOutboundPort = outboundQueueSplit[3] + + if str(diameterOutboundHost) == str(clientAddress) and str(diameterOutboundPort) == str(clientPort) and queuedMessageType == 'outbound': + self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {outboundQueue} to host {clientAddress} on port {clientPort}") + diameterOutbound = json.loads(await(self.redisMessaging.getMessage(queue=outboundQueue))) + diameterOutboundBinary = bytes.fromhex(next(iter(diameterOutbound.values()))) + diameterMessageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(binaryData=diameterOutboundBinary)) + diameterMessageType = diameterMessageType.get('outbound', '') + self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.") + writer.write(diameterOutboundBinary) + await(writer.drain()) + await(asyncio.sleep(0)) + + except Exception: + self.diameterLogger.info(f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.") return False - except Exception as e: - continue + await(asyncio.sleep(0)) async def handleConnection(self, reader, writer): - (clientAddress, clientPort) = writer.get_extra_info('peername') - self.diameterLogger.debug(f"[Diameter] Initial Connection from: {clientAddress} on port 
{clientPort}") - coroutineUuid = uuid.uuid4() + """ + For each new connection on port 3868, create an asynchronous reader and writer. If a reader or writer returns false, ensure that the connection is torn down entirely. + """ + try: + (clientAddress, clientPort) = writer.get_extra_info('peername') + self.diameterLogger.debug(f"[Diameter] Initial Connection from: {clientAddress} on port {clientPort}") + coroutineUuid = str(uuid.uuid4()) + self.activeConnections.add((clientAddress, clientPort, coroutineUuid)) - if False in await asyncio.gather(self.readRequestData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid), - self.writeResponseData(writer=writer, clientAddress=clientAddress, clientPort=clientPort, coroutineUuid=coroutineUuid)): - self.diameterLogger.debug(f"[Diameter] Closing Writer for {clientAddress} on port {clientPort}.") + readTask = asyncio.create_task(self.readInboundData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) + writeTask = asyncio.create_task(self.writeOutboundData(writer=writer, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) + + completeTasks, pendingTasks = await(asyncio.wait([readTask, writeTask], return_when=asyncio.FIRST_COMPLETED)) + + for pendingTask in pendingTasks: + try: + pendingTask.cancel() + await(asyncio.sleep(0)) + except asyncio.CancelledError: + pass + writer.close() - await writer.wait_closed() - self.diameterLogger.debug(f"[Diameter] Closed Writer for {clientAddress} on port {clientPort}.") + await(writer.wait_closed()) + self.activeConnections.discard((clientAddress, clientPort, coroutineUuid)) + + return + + except Exception as e: + self.diameterLogger.warning(f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}") return - async def startServer(self, host: str='0.0.0.0', port: int=3868, type: str='TCP'): + async def startServer(self, host: str=None, port: int=None, type: str=None): + """ + Start a server with the given parameters and handle new clients with self.handleConnection. + Also create a single instance of self.logActiveConnections. 
+ """ + + if host is None: + host=str(self.config.get('hss', {}).get('bind_ip', '0.0.0.0')[0]) + + if port is None: + port=int(self.config.get('hss', {}).get('bind_port', 3868)) + + if type is None: + type=str(self.config.get('hss', {}).get('transport', 'TCP')) + + self.socketTimeout = int(self.config.get('hss', {}).get('client_socket_timeout', 300)) + if type.upper() == 'TCP': - server = await asyncio.start_server(self.handleConnection, host, port) - elif type.upper() == 'SCTP': - sctpSocket = sctp.sctpsocket_tcp(socket.AF_INET) - server = await asyncio.start_server(self.handleConnection, host, port, socket=sctpSocket) + server = await(asyncio.start_server(self.handleConnection, host, port)) else: return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) self.diameterLogger.info(self.banners.diameterService()) self.diameterLogger.info(f'[Diameter] Serving on {servingAddresses}') - + asyncio.create_task(self.logActiveConnections()) + async with server: - await server.serve_forever() + await(server.serve_forever()) if __name__ == '__main__': diameterService = DiameterService() - asyncio.run(diameterService.startServer()) \ No newline at end of file + asyncio.run(diameterService.startServer(), debug=True) \ No newline at end of file diff --git a/services/hssService.py b/services/hssService.py index 368680b..f69f1bc 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -1,7 +1,6 @@ import os, sys, json, yaml -import time, asyncio sys.path.append(os.path.realpath('../lib')) -from messagingAsync import RedisMessagingAsync +from messaging import RedisMessaging from diameter import Diameter from banners import Banners from logtool import LogTool @@ -16,7 +15,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): except: print(f"[HSS] Fatal Error - config.yaml not found, exiting.") quit() - self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) self.logTool = LogTool() self.banners = Banners() self.hssLogger = self.logTool.setupLogger(loggerName='HSS', config=self.config) @@ -28,39 +27,45 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.diameterLibrary = Diameter(originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.hssLogger.info(self.banners.hssService()) - async def handleRequestQueue(self): + def handleQueue(self): + """ + Gets and parses inbound diameter requests, processes them and queues the response. 
+ """ while True: try: - requestQueue = await(self.redisMessaging.getNextQueue(pattern='diameter-request*')) - requestMessage = await(self.redisMessaging.getMessage(queue=requestQueue)) - assert(len(requestMessage)) - self.hssLogger.debug(f"[HSS] Inbound Diameter Request Queue: {requestQueue}") - self.hssLogger.debug(f"[HSS] Inbound Diameter Request: {requestMessage}") + inboundQueue = self.redisMessaging.getNextQueue(pattern='diameter-inbound*') + inboundMessage = self.redisMessaging.getMessage(queue=inboundQueue) + assert(len(inboundMessage)) - requestDict = json.loads(requestMessage) - requestBinary = bytes.fromhex(next(iter(requestDict.values()))) - requestSplit = str(requestQueue).split('-') - requestHost = requestSplit[2] - requestPort = requestSplit[3] - requestTimestamp = requestSplit[4] + inboundDict = json.loads(inboundMessage) + inboundBinary = bytes.fromhex(next(iter(inboundDict.values()))) + inboundSplit = str(inboundQueue).split('-') + inboundHost = inboundSplit[2] + inboundPort = inboundSplit[3] + inboundTimestamp = inboundSplit[4] try: - diameterResponse = self.diameterLibrary.generateDiameterResponse(requestBinaryData=requestBinary) + diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) + diameterMessageTypeDict = self.diameterLibrary.getDiameterMessageType(binaryData=inboundBinary) + diameterMessageTypeInbound = diameterMessageTypeDict.get('inbound', '') + diameterMessageTypeOutbound = diameterMessageTypeDict.get('outbound', '') except Exception as e: - self.hssLogger.warn(f"[HSS] Failed to generate diameter response: {e}") + self.hssLogger.warn(f"[HSS] [handleInboundQueue] Failed to generate diameter outbound: {e}") continue - - self.hssLogger.debug(f"[HSS] Generated Diameter Response: {diameterResponse}") - if not len(diameterResponse) > 0: + + self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound Queue: {inboundQueue}") + self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}") + if not len(diameterOutbound) > 0: continue - outboundResponseQueue = f"diameter-response-{requestHost}-{requestPort}-{requestTimestamp}" - outboundResponse = json.dumps({"diameter-response": diameterResponse}) + outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}-{inboundTimestamp}" + outboundMessage = json.dumps({"diameter-outbound": diameterOutbound}) - self.hssLogger.debug(f"[HSS] Outbound Diameter Response Queue: {outboundResponseQueue}") - self.hssLogger.debug(f"[HSS] Outbound Diameter Response: {outboundResponse}") + self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}") + self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}") + self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}") - asyncio.ensure_future(self.redisMessaging.sendMessage(queue=outboundResponseQueue, message=outboundResponse, queueExpiry=60)) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) except Exception as e: continue @@ -69,4 +74,4 @@ async def handleRequestQueue(self): if __name__ == '__main__': hssService = HssService() - asyncio.run(hssService.handleRequestQueue()) \ No newline at end of file + hssService.handleQueue() \ No newline at end of file From 
c9f48b235e6639cfb4d6ddb7fb6f32521ad5209b Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Fri, 25 Aug 2023 13:53:12 +1000
Subject: [PATCH 05/43] Add LogService

---
 lib/banners.py              | 18 +++++++
 lib/database.py             |  5 +-
 lib/diameterAsync.py        | 11 ++---
 lib/logtool.py              | 98 +++++++++++++++++++++++++++++--------
 lib/messaging.py            | 14 ++++++
 lib/messagingAsync.py       | 16 ++++++
 services/diameterService.py | 56 ++++++++++-----------
 services/hssService.py      | 16 +++---
 services/logService.py      | 82 +++++++++++++++++++++++++++++++
 9 files changed, 251 insertions(+), 65 deletions(-)
 create mode 100644 services/logService.py

diff --git a/lib/banners.py b/lib/banners.py
index 0c3f51b..933a361 100644
--- a/lib/banners.py
+++ b/lib/banners.py
@@ -70,5 +70,23 @@ def metricService(self) -> str:
     Metric Service

+"""
+        return bannerText
+
+    def logService(self) -> str:
+        bannerText = """
+
+  ######  ##  ##   #####    #####
+  ##  ##  ##  ##  ##  ##  ##  ##
+  ##  ##  ##  ##  ##  ##  ##  ##
+  ######  ##  ##  #######  #####  #####
+  ##      ##  ##  ##  ##      ##
+  ##      ##  ##  ##  ##  ##  ##  ##
+  ##       #####  ##  ##  #####  #####
+  ##
+  ####
+
+    Log Service
+
 """
         return bannerText
\ No newline at end of file
diff --git a/lib/database.py b/lib/database.py
index aa400d6..5228a00 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -22,14 +22,15 @@
 import requests
 import threading
 from logtool import LogTool
+import logging
 from messaging import RedisMessaging

 import yaml
 with open("../config.yaml", 'r') as stream:
     yaml_config = (yaml.safe_load(stream))

-logTool = LogTool()
-dbLogger = logTool.setupLogger(loggerName='Database', config=yaml_config)
+logTool = LogTool(yaml_config)
+dbLogger = logging.getLogger('Database')
 dbLogger.info("DB Log Initialised.")

 redisMessaging = RedisMessaging()
diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py
index e9fd657..a8a60c7 100644
--- a/lib/diameterAsync.py
+++ b/lib/diameterAsync.py
@@ -4,7 +4,7 @@

 class DiameterAsync:

-    def __init__(self, logger):
+    def __init__(self):
         self.diameterCommandList = [
             {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilities Exchange Request", "responseName": "Capabilities Exchange Answer"},
             {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"},
@@ -24,7 +24,6 @@ def __init__(self, logger):
             {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"},
         ]
-        self.logger = logger

     #Generates rounding for calculating padding
     async def myRoundAsync(self, n, base=4):
@@ -121,7 +120,7 @@ async def decodeAvpPacketAsync(self, data):
             elif str(e) == "Length of data is too short to be valid AVP":
                 pass
             else:
-                self.logger.warn("[Diameter] [decodeAvpPacketAsync] failed to decode sub-avp - error: " + str(e))
+                #self.logger.warn("[Diameter] [decodeAvpPacketAsync] failed to decode sub-avp - error: " + str(e))
                 pass

         remaining_avps = data[(avp_vars['avp_length']*2)+avp_vars['padding']:]   #returns remaining data in avp string back for processing again
@@ -138,7 +137,7 @@ async def getDiameterMessageTypeAsync(self, binaryData: str) -> dict:
                 assert(packet_vars["ApplicationId"] ==
diameterApplication["applicationId"]) response['inbound'] = diameterApplication["requestAcronym"] response['outbound'] = diameterApplication["responseAcronym"] - self.logger.debug(f"[Diameter] [getDiameterMessageTypeAsync] Successfully got message type: {response}") + #self.logger.debug(f"[Diameter] [getDiameterMessageTypeAsync] Successfully got message type: {response}") except Exception as e: continue @@ -150,7 +149,7 @@ async def generateDiameterResponseAsync(self, binaryData: str) -> str: # Drop packet if it's a response packet: if packet_vars["flags_bin"][0:1] == "0": - self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Got a Response, not a request - dropping it: {packet_vars}") + #self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Got a Response, not a request - dropping it: {packet_vars}") return for diameterApplication in self.diameterCommandList: @@ -160,7 +159,7 @@ async def generateDiameterResponseAsync(self, binaryData: str) -> str: if 'flags' in diameterApplication: assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) response = diameterApplication["responseMethod"](packet_vars, avps) - self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Successfully generated response: {response}") + #self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Successfully generated response: {response}") except Exception as e: continue diff --git a/lib/logtool.py b/lib/logtool.py index ec53ce6..668a30e 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -1,28 +1,84 @@ import logging import logging.handlers as handlers -import os, sys +import os, sys, time +from datetime import datetime sys.path.append(os.path.realpath('../')) +import asyncio + +class TimestampFilter (logging.Filter): + """ + Logging filter which checks for a `timestamp` attribute on a + given LogRecord, and if present it will override the LogRecord creation time. + Expects time.time() or equivalent integer. + """ + + def filter(self, record): + if hasattr(record, 'timestamp'): + record.created = record.timestamp + return True class LogTool: + """ + Reusable logging class, providing both asynchronous and synchronous logging functions. + """ + def __init__(self, config: dict): + self.logLevels = { + 'CRITICAL': {'verbosity': 1, 'logging': logging.CRITICAL}, + 'ERROR': {'verbosity': 2, 'logging': logging.ERROR}, + 'WARNING': {'verbosity': 3, 'logging': logging.WARNING}, + 'INFO': {'verbosity': 4, 'logging': logging.INFO}, + 'DEBUG': {'verbosity': 5, 'logging': logging.DEBUG}, + 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, + } + self.logLevel = config.get('logging', {}).get('level', 'INFO') + + async def logAsync(self, service: str, level: str, message: str, redisClient) -> bool: + """ + Tests loglevel, prints to console and queues a log message to an asynchronous redis messaging client. 
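        Log suppression is a plain verbosity comparison; a sketch mirroring
        self.logLevels (the table below restates the verbosity values above):

            logLevels = {'CRITICAL': 1, 'ERROR': 2, 'WARNING': 3,
                         'INFO': 4, 'DEBUG': 5, 'NOTSET': 6}
            configuredLevel, messageLevel = 'INFO', 'DEBUG'
            # DEBUG (5) is noisier than the configured INFO (4), so it is dropped:
            shouldLog = logLevels[messageLevel] <= logLevels[configuredLevel]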
+ """ + configLogLevelVerbosity = self.logLevels.get(self.logLevel.upper(), {}).get('verbosity', 4) + messageLogLevelVerbosity = self.logLevels.get(level.upper(), {}).get('verbosity', 4) + if not messageLogLevelVerbosity <= configLogLevelVerbosity: + return False + timestamp = time.time() + dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip() + print(f"[{dateTimeString}] [{level.upper()}] {message}") + asyncio.ensure_future(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)) + return True + + def log(self, service: str, level: str, message: str, redisClient) -> bool: + """ + Tests loglevel, prints to console and queues a log message to a synchronous redis messaging client. + """ + configLogLevelVerbosity = self.logLevels.get(self.logLevel.upper(), {}).get('verbosity', 4) + messageLogLevelVerbosity = self.logLevels.get(level.upper(), {}).get('verbosity', 4) + if not messageLogLevelVerbosity <= configLogLevelVerbosity: + return False + timestamp = time.time() + dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip() + print(f"[{dateTimeString}] [{level.upper()}] {message}") + redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60) + return True - def setupLogger(self, loggerName: str, config: dict): - # logFile = config.get('logging', {}).get('logfiles', {}).get(f'{loggerName.lower()}_logging_file', '/var/log/pyhss_general.log') - logLevel = config.get('logging', {}).get('level', 'INFO') - logger = logging.getLogger(loggerName) + def setupFileLogger(self, loggerName: str, logFilePath: str): + """ + Sets up and returns a file logger, given a loggerName and logFilePath. + Defaults to {pyhssRootDir}/log/{logFileName} if the configured file location is not writable. 
+ """ + try: + rolloverHandler = handlers.RotatingFileHandler(logFilePath, maxBytes=50000000, backupCount=5) + except PermissionError: + logFileName = logFilePath.split('/')[-1] + pyhssRootDir = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + print(f"[LogTool] Warning - Unable to write to {logFilePath}, using {pyhssRootDir}/log/{logFileName} instead.") + logFilePath = f"{pyhssRootDir}/log/{logFileName}" + rolloverHandler = handlers.RotatingFileHandler(logFilePath, maxBytes=50000000, backupCount=5) + fileLogger = logging.getLogger(loggerName) + print(logFilePath) formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s {%(pathname)s:%(lineno)d} %(message)s", datefmt="%m/%d/%Y %H:%M:%S %Z") - # try: - # rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) - # except PermissionError: - # logFileName = logFile.split('/')[-1] - # pyhssRootDir = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) - # print(f"[LogTool] Warning - Unable to write to {logFile}, using {pyhssRootDir}/log/{logFileName} instead.") - # logFile = f"{pyhssRootDir}/log/{logFileName}" - # rolloverHandler = handlers.RotatingFileHandler(logFile, maxBytes=50000000, backupCount=5) - # pass - streamHandler = logging.StreamHandler() - streamHandler.setFormatter(formatter) - # rolloverHandler.setFormatter(formatter) - logger.setLevel(logLevel) - logger.addHandler(streamHandler) - # logger.addHandler(rolloverHandler) - return logger \ No newline at end of file + filter = TimestampFilter() + fileLogger.addFilter(filter) + rolloverHandler.setFormatter(formatter) + fileLogger.addHandler(rolloverHandler) + fileLogger.setLevel(logging.DEBUG) + return fileLogger \ No newline at end of file diff --git a/lib/messaging.py b/lib/messaging.py index 33975c7..3adb7a5 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -50,6 +50,20 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA except Exception as e: return '' + def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + """ + Stores a message in a given Queue (Key). + """ + try: + logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}" + logMessage = json.dumps({"message": message}) + self.redisClient.rpush(logQueueName, logMessage) + if logExpiry is not None: + self.redisClient.expire(logQueueName, logExpiry) + return f'{message} stored in {logQueueName} successfully.' + except Exception as e: + return '' + def getMessage(self, queue: str) -> str: """ Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed. diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 7aceafe..843fb87 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -54,6 +54,22 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m except Exception as e: return '' + async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + """ + Stores a log message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. 
+ """ + try: + logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}" + logMessage = json.dumps({"message": message}) + async with self.redisClient.pipeline(transaction=True) as redisPipe: + await redisPipe.rpush(logQueueName, logMessage) + if logExpiry is not None: + await redisPipe.expire(logQueueName, logExpiry) + sendMessageResult, expireKeyResult = await redisPipe.execute() + return f'{message} stored in {logQueueName} successfully.' + except Exception as e: + return '' + async def getMessage(self, queue: str) -> str: """ Gets the oldest message from a given Queue (Key) asynchronously, while removing it from the key as well. Deletes the key if the last message is being removed. diff --git a/services/diameterService.py b/services/diameterService.py index 1a167dc..8ddcdda 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -1,6 +1,8 @@ import asyncio import sys, os, json import time, yaml, uuid +import concurrent.futures +import logging sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from diameterAsync import DiameterAsync @@ -24,11 +26,10 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) self.banners = Banners() - self.logTool = LogTool() - self.diameterLogger = self.logTool.setupLogger(loggerName='Diameter', config=self.config) - self.diameterLibrary = DiameterAsync(logger=self.diameterLogger) + self.logTool = LogTool(config=self.config) + self.diameterLibrary = DiameterAsync() self.activeConnections = set() - + async def validateDiameterInbound(self, inboundData) -> bool: """ Asynchronously validates a given diameter inbound, and increments the 'Number of Diameter Inbounds' metric. @@ -55,32 +56,31 @@ async def logActiveConnections(self): """ Logs the number of active connections on a rolling basis. """ - while True: - activeConnections = self.activeConnections - if not len(activeConnections) > 0: - activeConnections = '' - self.diameterLogger.info(f"[Diameter] [logActiveConnections] {len(self.activeConnections)} Active Connections {activeConnections}") - await(asyncio.sleep(60)) + activeConnections = self.activeConnections + if not len(activeConnections) > 0: + activeConnections = '' + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActiveConnections] {len(self.activeConnections)} Active Connections {activeConnections}", redisClient=self.redisMessaging)) async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ Reads and parses incoming data from a connected client. Validated diameter messages are sent to the redis queue for processing. Terminates the connection if diameter traffic is not received, or if the client disconnects. 
""" - self.diameterLogger.info(f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) while True: try: inboundData = await(asyncio.wait_for(reader.read(1024), timeout=socketTimeout)) if reader.at_eof(): + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.", redisClient=self.redisMessaging)) return False if len(inboundData) > 0: - self.diameterLogger.debug(f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) if not await(self.validateDiameterInbound(inboundData)): - self.diameterLogger.debug(f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, terminating connection.") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, terminating connection.", redisClient=self.redisMessaging)) return False diameterMessageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(binaryData=inboundData)) @@ -88,19 +88,18 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{time.time_ns()}" inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()}) - self.diameterLogger.debug(f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}", redisClient=self.redisMessaging)) asyncio.ensure_future(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=60)) except Exception as e: - self.diameterLogger.info(f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.") - self.diameterLogger.debug(e) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}", redisClient=self.redisMessaging)) return False async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ Continually polls the Redis queue for outbound messages. Received messages from the queue are validated against the connected client, and sent. 
""" - self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) while True: try: @@ -111,8 +110,7 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s if not len(pendingOutboundQueues) > 0: await(asyncio.sleep(0)) continue - - self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queues: {pendingOutboundQueues}") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queues: {pendingOutboundQueues}", redisClient=self.redisMessaging)) for outboundQueue in pendingOutboundQueues: outboundQueueSplit = str(outboundQueue).split('-') queuedMessageType = outboundQueueSplit[1] @@ -120,18 +118,18 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s diameterOutboundPort = outboundQueueSplit[3] if str(diameterOutboundHost) == str(clientAddress) and str(diameterOutboundPort) == str(clientPort) and queuedMessageType == 'outbound': - self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {outboundQueue} to host {clientAddress} on port {clientPort}") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {outboundQueue} to host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) diameterOutbound = json.loads(await(self.redisMessaging.getMessage(queue=outboundQueue))) diameterOutboundBinary = bytes.fromhex(next(iter(diameterOutbound.values()))) diameterMessageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(binaryData=diameterOutboundBinary)) diameterMessageType = diameterMessageType.get('outbound', '') - self.diameterLogger.debug(f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.") + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.", redisClient=self.redisMessaging)) writer.write(diameterOutboundBinary) await(writer.drain()) await(asyncio.sleep(0)) except Exception: - self.diameterLogger.info(f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.") + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.", redisClient=self.redisMessaging)) return False await(asyncio.sleep(0)) @@ -140,10 +138,11 @@ async def handleConnection(self, reader, writer): For each new connection on port 3868, create an asynchronous reader and writer. If a reader or writer returns false, ensure that the connection is torn down entirely. 
""" try: - (clientAddress, clientPort) = writer.get_extra_info('peername') - self.diameterLogger.debug(f"[Diameter] Initial Connection from: {clientAddress} on port {clientPort}") coroutineUuid = str(uuid.uuid4()) + (clientAddress, clientPort) = writer.get_extra_info('peername') + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) self.activeConnections.add((clientAddress, clientPort, coroutineUuid)) + await(self.logActiveConnections()) readTask = asyncio.create_task(self.readInboundData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) writeTask = asyncio.create_task(self.writeOutboundData(writer=writer, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) @@ -160,11 +159,14 @@ async def handleConnection(self, reader, writer): writer.close() await(writer.wait_closed()) self.activeConnections.discard((clientAddress, clientPort, coroutineUuid)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}.", redisClient=self.redisMessaging)) + await(self.logActiveConnections()) + return except Exception as e: - self.diameterLogger.warning(f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}") + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}", redisClient=self.redisMessaging)) return async def startServer(self, host: str=None, port: int=None, type: str=None): @@ -189,9 +191,7 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): else: return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) - self.diameterLogger.info(self.banners.diameterService()) - self.diameterLogger.info(f'[Diameter] Serving on {servingAddresses}') - asyncio.create_task(self.logActiveConnections()) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"{self.banners.diameterService()}\n[Diameter] Serving on {servingAddresses}", redisClient=self.redisMessaging)) async with server: await(server.serve_forever()) diff --git a/services/hssService.py b/services/hssService.py index f69f1bc..beaedb9 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -16,16 +16,15 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): print(f"[HSS] Fatal Error - config.yaml not found, exiting.") quit() self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) - self.logTool = LogTool() + self.logTool = LogTool(config=self.config) self.banners = Banners() - self.hssLogger = self.logTool.setupLogger(loggerName='HSS', config=self.config) self.mnc = self.config.get('hss', {}).get('MNC', '999') self.mcc = self.config.get('hss', {}).get('MCC', '999') self.originRealm = self.config.get('hss', {}).get('OriginRealm', f'mnc{self.mnc}.mcc{self.mcc}.3gppnetwork.org') self.originHost = self.config.get('hss', {}).get('OriginHost', f'hss01') self.productName = self.config.get('hss', {}).get('ProductName', f'PyHSS') self.diameterLibrary = Diameter(originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) - 
self.hssLogger.info(self.banners.hssService()) + self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) def handleQueue(self): """ @@ -53,17 +52,18 @@ def handleQueue(self): self.hssLogger.warn(f"[HSS] [handleInboundQueue] Failed to generate diameter outbound: {e}") continue - self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound Queue: {inboundQueue}") - self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}") + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound Queue: {inboundQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) + if not len(diameterOutbound) > 0: continue outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}-{inboundTimestamp}" outboundMessage = json.dumps({"diameter-outbound": diameterOutbound}) - self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}") - self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}") - self.hssLogger.debug(f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}") + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) diff --git a/services/logService.py b/services/logService.py new file mode 100644 index 0000000..568c0d7 --- /dev/null +++ b/services/logService.py @@ -0,0 +1,82 @@ +import os, sys, json, yaml +from datetime import datetime +import logging +sys.path.append(os.path.realpath('../lib')) +from messaging import RedisMessaging +from banners import Banners +from logtool import LogTool + +class LogService: + """ + PyHSS Log Service + A class for handling queued log entries in the Redis DB. + This class is synchronous and not high-performance. 
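        Queue keys are decomposed before the entry is written to disk; sketch
        (the key shown is an illustrative value):

            logQueue = "log-hss-error-1692400000.123456"
            logService, logLevel, logTimestamp = logQueue.split('-')[1:4]
            # 'hss' selects the file logger, 'error' the level, and the timestamp
            # backdates the record via the TimestampFilter 'timestamp' extra.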
+ """ + + def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + try: + with open("../config.yaml", "r") as self.configFile: + self.config = yaml.safe_load(self.configFile) + except: + print(f"[Log] Fatal Error - config.yaml not found, exiting.") + quit() + self.logTool = LogTool(config=self.config) + self.banners = Banners() + self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.logFilePaths = self.config.get('logging', {}).get('logfiles', {}) + self.logLevels = { + 'CRITICAL': {'verbosity': 1, 'logging': logging.CRITICAL}, + 'ERROR': {'verbosity': 2, 'logging': logging.ERROR}, + 'WARNING': {'verbosity': 3, 'logging': logging.WARNING}, + 'INFO': {'verbosity': 4, 'logging': logging.INFO}, + 'DEBUG': {'verbosity': 5, 'logging': logging.DEBUG}, + 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, + } + print(f"{self.banners.logService()}") + + + def handleLogs(self): + """ + Continually polls the Redis DB for queued log files. Parses and writes log files to disk, using LogTool. + """ + activeLoggers = {} + while True: + try: + logQueue = self.redisMessaging.getNextQueue(pattern='log-*') + logMessage = self.redisMessaging.getMessage(queue=logQueue) + + if not len(logMessage) > 0: + continue + + print(f"[Log] Queue: {logQueue}") + print(f"[Log] Message: {logMessage}") + + logSplit = logQueue.split('-') + logService = logSplit[1].lower() + logLevel = logSplit[2].upper() + logTimestamp = logSplit[3] + + logDict = json.loads(logMessage) + logFileMessage = logDict['message'] + + + if f"{logService}_logging_file" not in self.logFilePaths: + continue + + logFileName = f"{logService}_logging_file" + logFilePath = self.logFilePaths.get(logFileName, '/var/log/pyhss.log') + + if logService not in activeLoggers: + activeLoggers[logService] = self.logTool.setupFileLogger(loggerName=logService, logFilePath=logFilePath) + + fileLogger = activeLoggers[logService] + fileLogger.log(self.logLevels.get(logLevel.upper(), {}).get('logging', logging.INFO), logFileMessage, extra={'timestamp': float(logTimestamp)}) + + + except Exception as e: + self.logTool.log(service='Log', level='error', message=f"[Log] Error: {e}", redisClient=self.redisMessaging) + continue + +if __name__ == '__main__': + logService = LogService() + logService.handleLogs() \ No newline at end of file From 018e88ab90bd0fc66312d34beba797562be5df5c Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Mon, 28 Aug 2023 11:06:32 +1000 Subject: [PATCH 06/43] Database partial refactor --- lib/database.py | 3284 +++++++++++++++++------------------ lib/diameter.py | 835 ++++----- lib/messaging.py | 8 +- lib/messagingAsync.py | 6 +- services/diameterService.py | 2 - services/georedService.py | 18 +- services/hssService.py | 14 +- services/metricService.py | 18 +- 8 files changed, 2111 insertions(+), 2074 deletions(-) diff --git a/lib/database.py b/lib/database.py index 5228a00..d7cfb02 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1,47 +1,24 @@ -from sqlalchemy import Column, Integer, String, MetaData, Table, Boolean, ForeignKey, select, UniqueConstraint, DateTime, BigInteger, event, Text, DateTime, Float +from sqlalchemy import Column, Integer, String, MetaData, Table, Boolean, ForeignKey, select, UniqueConstraint, DateTime, BigInteger, Text, DateTime, Float from sqlalchemy import create_engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql import desc, func from sqlalchemy_utils import database_exists, create_database from sqlalchemy.orm import sessionmaker, relationship, Session, 
class_mapper from sqlalchemy.orm.attributes import History, get_history -import sys, os -from functools import wraps -import json +from sqlalchemy.ext.declarative import declarative_base +import os import datetime, time from datetime import timezone import re import binascii import uuid import socket -import traceback -from contextlib import contextmanager import pprint -from construct import Default import S6a_crypt -import requests import threading -from logtool import LogTool -import logging from messaging import RedisMessaging - import yaml -with open("../config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) - -logTool = LogTool(yaml_config) -dbLogger = logging.getLogger('Database') -dbLogger.info("DB Log Initialised.") -redisMessaging = RedisMessaging() - -db_string = 'mysql://' + str(yaml_config['database']['username']) + ':' + str(yaml_config['database']['password']) + '@' + str(yaml_config['database']['server']) + '/' + str(yaml_config['database']['database'] + "?autocommit=true") -engine = create_engine( - db_string, - echo = yaml_config['logging'].get('sqlalchemy_sql_echo', True), - pool_recycle=yaml_config['logging'].get('sqlalchemy_pool_recycle', 5), - pool_size=yaml_config['logging'].get('sqlalchemy_pool_size', 30), - max_overflow=yaml_config['logging'].get('sqlalchemy_max_overflow', 0)) -from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() class OPERATION_LOG_BASE(Base): @@ -278,1755 +255,1776 @@ class SUBSCRIBER_ATTRIBUTES(Base): value = Column(String(12000), doc='Arbitrary value') operation_logs = relationship("SUBSCRIBER_ATTRIBUTES_OPERATION_LOG", back_populates="subscriber_attributes") -# Create database if it does not exist. -if not database_exists(engine.url): - dbLogger.debug("Creating database") - create_database(engine.url) - Base.metadata.create_all(engine) -else: - dbLogger.debug("Database already created") +class Database: -def load_IMEI_database_into_Redis(): - try: - dbLogger.info("Reading IMEI TAC database CSV from " + str(yaml_config['eir']['tac_database_csv'])) - csvfile = open(str(yaml_config['eir']['tac_database_csv'])) - dbLogger.info("This may take a few seconds to buffer into Redis...") - except: - dbLogger.error("Failed to read CSV file of IMEI TAC database") - return - try: - count = 0 - for line in csvfile: - line = line.replace('"', '') #Strip excess invered commas - line = line.replace("'", '') #Strip excess invered commas - line = line.rstrip() #Strip newlines - result = line.split(',') - tac_prefix = result[0] - name = result[1].lstrip() - model = result[2].lstrip() - if count == 0: - dbLogger.info("Checking to see if entries are already present...") - redis_imei_result = redisMessaging.getMessage(key=str(tac_prefix)) - if len(redis_imei_result) != 0: - dbLogger.info("IMEI TAC Database already loaded into Redis - Skipping reading from file...") - break - else: - dbLogger.info("No data loaded into Redis, proceeding to load...") - imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} - redisMessaging.sendMessage(key=str(tac_prefix), value_dict=imei_result) - count = count +1 - dbLogger.info("Loaded " + str(count) + " IMEI TAC entries into Redis") - except Exception as E: - dbLogger.error("Failed to load IMEI Database into Redis due to error: " + (str(E))) - return + def __init__(self, logTool, redisMessaging): + with open("../config.yaml", 'r') as stream: + self.config = (yaml.safe_load(stream)) -#Load IMEI TAC database into Redis if enabled -if ('tac_database_csv' in 
yaml_config['eir']) and (yaml_config['redis']['enabled'] == True): - load_IMEI_database_into_Redis() -else: - dbLogger.info("Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config") + self.logTool = logTool + self.redisMessaging = redisMessaging + db_string = 'mysql://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database'] + "?autocommit=true") + self.engine = create_engine( + db_string, + echo = self.config['logging'].get('sqlalchemy_sql_echo', True), + pool_recycle=self.config['logging'].get('sqlalchemy_pool_recycle', 5), + pool_size=self.config['logging'].get('sqlalchemy_pool_size', 30), + max_overflow=self.config['logging'].get('sqlalchemy_max_overflow', 0)) -def safe_rollback(session): - try: - if session.is_active: - session.rollback() - except Exception as E: - dbLogger.error(f"Failed to rollback session, error: {E}") + # Create database if it does not exist. + if not database_exists(self.engine.url): + self.logTool.log(service='Database', level='debug', message="Creating database", redisClient=self.redisMessaging) + create_database(self.engine.url) + Base.metadata.create_all(self.engine) + else: + self.logTool.log(service='Database', level='debug', message="Database already created", redisClient=self.redisMessaging) -def safe_close(session): - try: - if session.is_active: - session.close() - except Exception as E: - dbLogger.error(f"Failed to run safe_close on session, error: {E}") - -def sqlalchemy_type_to_json_schema_type(sqlalchemy_type): - """ - Map SQLAlchemy types to JSON Schema types. - """ - if isinstance(sqlalchemy_type, Integer): - return "integer" - elif isinstance(sqlalchemy_type, String): - return "string" - elif isinstance(sqlalchemy_type, Boolean): - return "boolean" - elif isinstance(sqlalchemy_type, DateTime): - return "string" - elif isinstance(sqlalchemy_type, Float): - return "number" - else: - return "string" # Default to string for unsupported types. - -def generate_json_schema(model_class, required=None): - properties = {} - required = required or [] - - for column in model_class.__table__.columns: - prop_type = sqlalchemy_type_to_json_schema_type(column.type) - prop_dict = { - "type": prop_type, - "description": column.doc - } - if prop_type == "string": - if hasattr(column.type, 'length'): - prop_dict["maxLength"] = column.type.length - if isinstance(column.type, DateTime): - prop_dict["format"] = "date-time" - if not column.nullable: - required.append(column.name) - properties[column.name] = prop_dict - - return {"type": "object", "title" : str(model_class.__name__), "properties": properties, "required": required} - -# Create individual tables if they do not exist. 
-inspector = Inspector.from_engine(engine) -for table_name in Base.metadata.tables.keys(): - if table_name not in inspector.get_table_names(): - dbLogger.debug(f"Creating table {table_name}") - Base.metadata.tables[table_name].create(bind=engine) - else: - dbLogger.debug(f"Table {table_name} already exists") - -def update_old_record(session, operation_log): - oldest_log = session.query(OPERATION_LOG_BASE).order_by(OPERATION_LOG_BASE.timestamp.asc()).first() - if oldest_log is not None: - for attr in class_mapper(oldest_log.__class__).column_attrs: - if attr.key != 'id' and hasattr(operation_log, attr.key): - setattr(oldest_log, attr.key, getattr(operation_log, attr.key)) - oldest_log.timestamp = datetime.datetime.now(tz=timezone.utc) - session.flush() - else: - raise ValueError("Unable to find record to update") - -def log_change(session, item_id, operation, changes, table_name, operation_id, generated_id=None): - # We don't want to log rollback operations - if session.info.get("operation") == 'ROLLBACK': - return - max_records = 1000 - count = session.query(OPERATION_LOG_BASE).count() - - # Combine all changes into a single string with their types - changes_string = '\r\n\r\n'.join(f"{column_name}: [{type(old_value).__name__}] {old_value} ----> [{type(new_value).__name__}] {new_value}" for column_name, old_value, new_value in changes) - - change = OPERATION_LOG_BASE( - item_id=item_id or generated_id, - operation_id=operation_id, - operation=operation, - last_modified=datetime.datetime.now(tz=timezone.utc), - changes=changes_string, - table_name=table_name - ) + #Load IMEI TAC database into Redis if enabled + if ('tac_database_csv' in self.config['eir']) and (self.config['redis']['enabled'] == True): + self.load_IMEI_database_into_Redis() + else: + self.logTool.log(service='Database', level='info', message="Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config", redisClient=self.redisMessaging) + + # Create individual tables if they do not exist. 
+        inspector = Inspector.from_engine(self.engine)
+        for table_name in Base.metadata.tables.keys():
+            if table_name not in inspector.get_table_names():
+                self.logTool.log(service='Database', level='debug', message=f"Creating table {table_name}", redisClient=self.redisMessaging)
+                Base.metadata.tables[table_name].create(bind=self.engine)
+            else:
+                self.logTool.log(service='Database', level='debug', message=f"Table {table_name} already exists", redisClient=self.redisMessaging)

-    if count >= max_records:
-        update_old_record(session, change)
-    else:
+
+    def load_IMEI_database_into_Redis(self):
         try:
-            session.add(change)
-            session.flush()
+            self.logTool.log(service='Database', level='info', message="Reading IMEI TAC database CSV from " + str(self.config['eir']['tac_database_csv']), redisClient=self.redisMessaging)
+            csvfile = open(str(self.config['eir']['tac_database_csv']))
+            self.logTool.log(service='Database', level='info', message="This may take a few seconds to buffer into Redis...", redisClient=self.redisMessaging)
+        except:
+            self.logTool.log(service='Database', level='error', message="Failed to read CSV file of IMEI TAC database", redisClient=self.redisMessaging)
+            return
+        try:
+            count = 0
+            for line in csvfile:
+                line = line.replace('"', '') #Strip excess inverted commas
+                line = line.replace("'", '') #Strip excess inverted commas
+                line = line.rstrip() #Strip newlines
+                result = line.split(',')
+                tac_prefix = result[0]
+                name = result[1].lstrip()
+                model = result[2].lstrip()
+                if count == 0:
+                    self.logTool.log(service='Database', level='info', message="Checking to see if entries are already present...", redisClient=self.redisMessaging)
+                    redis_imei_result = self.redisMessaging.getMessage(key=str(tac_prefix))
+                    if len(redis_imei_result) != 0:
+                        self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging)
+                        break
+                    else:
+                        self.logTool.log(service='Database', level='info', message="No data loaded into Redis, proceeding to load...", redisClient=self.redisMessaging)
+                imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model}
+                self.redisMessaging.sendMessage(key=str(tac_prefix), value_dict=imei_result)
+                count = count +1
+            self.logTool.log(service='Database', level='info', message="Loaded " + str(count) + " IMEI TAC entries into Redis", redisClient=self.redisMessaging)
         except Exception as E:
-            dbLogger.error("Failed to commit changelog, error: " + str(E))
-            safe_rollback(session)
-            safe_close(session)
-            raise ValueError(E)
-    return operation_id
+            self.logTool.log(service='Database', level='error', message="Failed to load IMEI Database into Redis due to error: " + (str(E)), redisClient=self.redisMessaging)
+            return

+    def safe_rollback(self, session):
+        try:
+            if session.is_active:
+                session.rollback()
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message=f"Failed to rollback session, error: {E}", redisClient=self.redisMessaging)

-def log_changes_before_commit(session):
+    def safe_close(self, session):
+        try:
+            if session.is_active:
+                session.close()
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message=f"Failed to run safe_close on session, error: {E}", redisClient=self.redisMessaging)
+
+    def sqlalchemy_type_to_json_schema_type(self, sqlalchemy_type):
+        """
+        Map SQLAlchemy types to JSON Schema types.
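        Sketch of the handled mappings (unlisted types fall back to "string"):

            from sqlalchemy import Integer, String, Boolean, DateTime, Float
            # Integer  -> "integer"     String -> "string"
            # Boolean  -> "boolean"     Float  -> "number"
            # DateTime -> "string" (the caller adds "format": "date-time")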
+ """ + if isinstance(sqlalchemy_type, Integer): + return "integer" + elif isinstance(sqlalchemy_type, String): + return "string" + elif isinstance(sqlalchemy_type, Boolean): + return "boolean" + elif isinstance(sqlalchemy_type, DateTime): + return "string" + elif isinstance(sqlalchemy_type, Float): + return "number" + else: + return "string" # Default to string for unsupported types. + + def generate_json_schema(self, model_class, required=None): + properties = {} + required = required or [] + + for column in model_class.__table__.columns: + prop_type = self.sqlalchemy_type_to_json_schema_type(column.type) + prop_dict = { + "type": prop_type, + "description": column.doc + } + if prop_type == "string": + if hasattr(column.type, 'length'): + prop_dict["maxLength"] = column.type.length + if isinstance(column.type, DateTime): + prop_dict["format"] = "date-time" + if not column.nullable: + required.append(column.name) + properties[column.name] = prop_dict + + return {"type": "object", "title" : str(model_class.__name__), "properties": properties, "required": required} + + def update_old_record(self, session, operation_log): + oldest_log = session.query(OPERATION_LOG_BASE).order_by(OPERATION_LOG_BASE.timestamp.asc()).first() + if oldest_log is not None: + for attr in class_mapper(oldest_log.__class__).column_attrs: + if attr.key != 'id' and hasattr(operation_log, attr.key): + setattr(oldest_log, attr.key, getattr(operation_log, attr.key)) + oldest_log.timestamp = datetime.datetime.now(tz=timezone.utc) + session.flush() + else: + raise ValueError("Unable to find record to update") + + def log_change(self, session, item_id, operation, changes, table_name, operation_id, generated_id=None): + # We don't want to log rollback operations + if session.info.get("operation") == 'ROLLBACK': + return + max_records = 1000 + count = session.query(OPERATION_LOG_BASE).count() + + # Combine all changes into a single string with their types + changes_string = '\r\n\r\n'.join(f"{column_name}: [{type(old_value).__name__}] {old_value} ----> [{type(new_value).__name__}] {new_value}" for column_name, old_value, new_value in changes) + + change = OPERATION_LOG_BASE( + item_id=item_id or generated_id, + operation_id=operation_id, + operation=operation, + last_modified=datetime.datetime.now(tz=timezone.utc), + changes=changes_string, + table_name=table_name + ) + + if count >= max_records: + self.update_old_record(session, change) + else: + try: + session.add(change) + session.flush() + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to commit changelog, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + return operation_id + + + def log_changes_before_commit(self, session): + + operation_id = session.info.get("operation_id", None) or str(uuid.uuid4()) + if session.info.get("operation") == 'ROLLBACK': + return + + changelog_pending = any(isinstance(obj, OPERATION_LOG_BASE) for obj in session.new) + if changelog_pending: + return # Skip if there are pending OPERATION_LOG_BASE objects + + for state, operation in [ + (session.new, 'INSERT'), + (session.dirty, 'UPDATE'), + (session.deleted, 'DELETE') + ]: + for obj in state: + if isinstance(obj, OPERATION_LOG_BASE): + continue # Skip change log entries + + item_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) + generated_id = None + + #Avoid logging rollback operations + if operation == 'ROLLBACK': + return + + # Flush the 
session to generate primary key for new objects + if operation == 'INSERT': + session.flush() + + if operation == 'UPDATE': + changes = [] + for attr in class_mapper(obj.__class__).column_attrs: + hist = get_history(obj, attr.key) + self.logTool.log(service='Database', level='info', message=f"History {hist}", redisClient=self.redisMessaging) + if hist.has_changes() and hist.added and hist.deleted: + old_value, new_value = hist.deleted[0], hist.added[0] + self.logTool.log(service='Database', level='info', message=f"Old Value {old_value}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='info', message=f"New Value {new_value}", redisClient=self.redisMessaging) + changes.append((attr.key, old_value, new_value)) + continue + + if not changes: + continue - operation_id = session.info.get("operation_id", None) or str(uuid.uuid4()) - if session.info.get("operation") == 'ROLLBACK': - return + operation_id = self.log_change(session, item_id, operation, changes, obj.__table__.name, operation_id) + + elif operation in ['INSERT', 'DELETE']: + changes = [] + for column in obj.__table__.columns: + column_name = column.name + value = getattr(obj, column_name) + if operation == 'INSERT': + old_value, new_value = None, value + if item_id is None: + generated_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) + elif operation == 'DELETE': + old_value, new_value = value, None + changes.append((column_name, old_value, new_value)) + operation_id = self.log_change(session, item_id, operation, changes, obj.__table__.name, operation_id, generated_id) + + def get_class_by_tablename(self, base, tablename): + """ + Returns a class object based on the given tablename. + + :param base: Base class of SQLAlchemy models + :param tablename: Name of the table to retrieve the class for + :return: Class object or None if not found + """ + for mapper in base.registry.mappers: + cls = mapper.class_ + if hasattr(cls, '__tablename__') and cls.__tablename__ == tablename: + return cls + return None - changelog_pending = any(isinstance(obj, OPERATION_LOG_BASE) for obj in session.new) - if changelog_pending: - return # Skip if there are pending OPERATION_LOG_BASE objects + def str_to_type(self, type_str, value_str): + if type_str == 'int': + return int(value_str) + elif type_str == 'float': + return float(value_str) + elif type_str == 'str': + return value_str + elif type_str == 'bool': + return value_str == 'True' + elif type_str == 'NoneType': + return None + else: + raise ValueError(f'Cannot convert to type: {type_str}') - for state, operation in [ - (session.new, 'INSERT'), - (session.dirty, 'UPDATE'), - (session.deleted, 'DELETE') - ]: - for obj in state: - if isinstance(obj, OPERATION_LOG_BASE): - continue # Skip change log entries - item_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) - generated_id = None + def rollback_last_change(self, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession - #Avoid logging rollback operations - if operation == 'ROLLBACK': - return + try: + # Get the most recent operation + last_operation = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).first() + + if last_operation is None: + return "No operations to roll back." 
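+            # A fresh operation_id is generated below so the rollback itself is logged as its own operation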
+
+            rollback_messages = []
+            operation_id = str(uuid.uuid4())
+
+            target_class = self.get_class_by_tablename(Base, last_operation.table_name)
+            if not target_class:
+                return f"Error: Could not find table {last_operation.table_name}"
+
+            primary_key_col = target_class.__mapper__.primary_key[0].key
+            filter_by_kwargs = {primary_key_col: last_operation.item_id}
+            target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none()
+
+            if last_operation.operation == 'UPDATE':
+                if not target_item:
+                    return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table"
+
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    # Revert the change
+                    setattr(target_item, column_name, old_value)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes"
+                )
+                rollback_messages.append(rollback_message)
+
+            elif last_operation.operation == 'INSERT':
+                if target_item:
+                    session.delete(target_item)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item"
+                )
+                rollback_messages.append(rollback_message)
+
+            elif last_operation.operation == 'DELETE':
+                # Aggregate old values of all columns into a single dictionary
+                old_values_dict = {}
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    self.logTool.log(service='Database', level='error', message=f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}", redisClient=self.redisMessaging)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    old_values_dict[column_name] = old_value
+                    self.logTool.log(service='Database', level='error', message="old_value_dict: " + str(old_values_dict), redisClient=self.redisMessaging)
+
+                if not target_item:
+                    try:
+                        # Create the target item using the aggregated old values
+                        target_item = target_class(**old_values_dict)
+                        session.add(target_item)
+                    except Exception as e:
+                        return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}"
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item"
+                )
+                rollback_messages.append(rollback_message)
 
-            # Flush the session to generate primary key for new objects
-            if operation == 'INSERT':
-                session.flush()
+            else:
+                return f"Error: Unknown operation {last_operation.operation}"
 
-            if operation == 'UPDATE':
-                changes = []
-                for attr in class_mapper(obj.__class__).column_attrs:
-                    hist = get_history(obj, attr.key)
-                    dbLogger.info(f"History {hist}")
-                    if hist.has_changes() and hist.added and hist.deleted:
-                        old_value, new_value = hist.deleted[0], hist.added[0]
-                        dbLogger.info(f"Old Value {old_value}")
-                        dbLogger.info(f"New Value {new_value}")
-                        changes.append((attr.key, old_value, new_value)) 
- continue + try: + session.commit() + self.safe_close(session) + except Exception as E: + self.logTool.log(service='Database', level='error', message="rollback_last_change error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - if not changes: - continue + return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) - operation_id = log_change(session, item_id, operation, changes, obj.__table__.name, operation_id) - - elif operation in ['INSERT', 'DELETE']: - changes = [] - for column in obj.__table__.columns: - column_name = column.name - value = getattr(obj, column_name) - if operation == 'INSERT': - old_value, new_value = None, value - if item_id is None: - generated_id = getattr(obj, list(obj.__table__.primary_key.columns.keys())[0]) - elif operation == 'DELETE': - old_value, new_value = value, None - changes.append((column_name, old_value, new_value)) - operation_id = log_change(session, item_id, operation, changes, obj.__table__.name, operation_id, generated_id) - -def get_class_by_tablename(base, tablename): - """ - Returns a class object based on the given tablename. - - :param base: Base class of SQLAlchemy models - :param tablename: Name of the table to retrieve the class for - :return: Class object or None if not found - """ - for mapper in base.registry.mappers: - cls = mapper.class_ - if hasattr(cls, '__tablename__') and cls.__tablename__ == tablename: - return cls - return None - -def str_to_type(type_str, value_str): - if type_str == 'int': - return int(value_str) - elif type_str == 'float': - return float(value_str) - elif type_str == 'str': - return value_str - elif type_str == 'bool': - return value_str == 'True' - elif type_str == 'NoneType': - return None - else: - raise ValueError(f'Cannot convert to type: {type_str}') + except Exception as E: + self.logTool.log(service='Database', level='error', message="rollback_last_change error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + def rollback_change_by_operation_id(self, operation_id, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession -def rollback_last_change(existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession + try: + # Get the most recent operation + last_operation = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id).order_by(desc(OPERATION_LOG_BASE.timestamp)).first() + + if last_operation is None: + return "No operation to roll back." 
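+            # Note: only the most recent changelog entry carrying this operation_id is rolled back here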
+
+            rollback_messages = []
+            operation_id = str(uuid.uuid4())
+
+            target_class = self.get_class_by_tablename(Base, last_operation.table_name)
+            if not target_class:
+                return f"Error: Could not find table {last_operation.table_name}"
+
+            primary_key_col = target_class.__mapper__.primary_key[0].key
+            filter_by_kwargs = {primary_key_col: last_operation.item_id}
+            target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none()
+
+            if last_operation.operation == 'UPDATE':
+                if not target_item:
+                    return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table"
+
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    # Revert the change
+                    setattr(target_item, column_name, old_value)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes"
+                )
+                rollback_messages.append(rollback_message)
+
+            elif last_operation.operation == 'INSERT':
+                if target_item:
+                    session.delete(target_item)
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item"
+                )
+                rollback_messages.append(rollback_message)
+
+            elif last_operation.operation == 'DELETE':
+                # Aggregate old values of all columns into a single dictionary
+                old_values_dict = {}
+                # Split the changes string into separate changes
+                changes = last_operation.changes.split('\r\n\r\n')
+                for change in changes:
+                    column_name, old_new_values = change.split(": ", 1)
+                    old_value_str, new_value_str = old_new_values.split(" ----> ", 1)
+
+                    # Extract type and value
+                    old_type_str, old_value_repr = old_value_str[1:].split("] ", 1)
+                    self.logTool.log(service='Database', level='error', message=f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}", redisClient=self.redisMessaging)
+                    old_value = self.str_to_type(old_type_str, old_value_repr)
+
+                    old_values_dict[column_name] = old_value
+                    self.logTool.log(service='Database', level='error', message="old_value_dict: " + str(old_values_dict), redisClient=self.redisMessaging)
+
+                if not target_item:
+                    try:
+                        # Create the target item using the aggregated old values
+                        target_item = target_class(**old_values_dict)
+                        session.add(target_item)
+                    except Exception as e:
+                        return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}"
+
+                rollback_message = (
+                    f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item"
+                )
+                rollback_messages.append(rollback_message)
 
-    try:
-        # Get the most recent operation
-        last_operation = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).first()
-        
-        if last_operation is None:
-            return "No operations to roll back." 
- - rollback_messages = [] - operation_id = str(uuid.uuid4()) - - target_class = get_class_by_tablename(Base, last_operation.table_name) - if not target_class: - return f"Error: Could not find table {last_operation.table_name}" - - primary_key_col = target_class.__mapper__.primary_key[0].key - filter_by_kwargs = {primary_key_col: last_operation.item_id} - target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none() - - if last_operation.operation == 'UPDATE': - if not target_item: - return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table" - - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:-1].split("] ", 1) - old_value = str_to_type(old_type_str, old_value_repr) - - # Revert the change - setattr(target_item, column_name, old_value) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes" - ) - - elif last_operation.operation == 'INSERT': - if target_item: - session.delete(target_item) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item" - ) - - elif last_operation.operation == 'DELETE': - # Aggregate old values of all columns into a single dictionary - old_values_dict = {} - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:].split("] ", 1) - dbLogger.error(f"running str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") - old_value = str_to_type(old_type_str, old_value_repr) - - old_values_dict[column_name] = old_value - dbLogger.error("old_value_dict: " + str(old_values_dict)) - - if not target_item: - try: - # Create the target item using the aggregated old values - target_item = target_class(**old_values_dict) - session.add(target_item) - except Exception as e: - return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}" + else: + return f"Error: Unknown operation {last_operation.operation}" - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item" - ) + try: + session.commit() + self.safe_close(session) + except Exception as E: + self.logTool.log(service='Database', level='error', message="rollback_last_change error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - else: - return f"Error: Unknown operation {last_operation.operation}" + return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) - try: - session.commit() - safe_close(session) except Exception as E: - dbLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) + self.logTool.log(service='Database', level='error', 
message="rollback_last_change error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) raise ValueError(E) - return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) + def get_all_operation_logs(self, page=0, page_size=100, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession - except Exception as E: - dbLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) + try: + # Get all distinct operation_ids ordered by max timestamp (descending order) + operation_ids = session.query(OPERATION_LOG_BASE.operation_id).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp))) -def rollback_change_by_operation_id(operation_id, existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession + operation_ids = operation_ids.limit(page_size).offset(page * page_size) - try: - # Get the most recent operation - last_operation = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id).order_by(desc(OPERATION_LOG_BASE.timestamp)).first() - - if last_operation is None: - return "No operation to roll back." - - rollback_messages = [] - operation_id = str(uuid.uuid4()) - - target_class = get_class_by_tablename(Base, last_operation.table_name) - if not target_class: - return f"Error: Could not find table {last_operation.table_name}" - - primary_key_col = target_class.__mapper__.primary_key[0].key - filter_by_kwargs = {primary_key_col: last_operation.item_id} - target_item = session.query(target_class).filter_by(**filter_by_kwargs).one_or_none() - - if last_operation.operation == 'UPDATE': - if not target_item: - return f"Error: Could not find item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table" - - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:-1].split("] ", 1) - old_value = str_to_type(old_type_str, old_value_repr) - - # Revert the change - setattr(target_item, column_name, old_value) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Reverted changes" - ) - - elif last_operation.operation == 'INSERT': - if target_item: - session.delete(target_item) - - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Deleted item" - ) - - elif last_operation.operation == 'DELETE': - # Aggregate old values of all columns into a single dictionary - old_values_dict = {} - # Split the changes string into separate changes - changes = last_operation.changes.split('\r\n\r\n') - for change in changes: - column_name, old_new_values = change.split(": ", 1) - old_value_str, new_value_str = old_new_values.split(" ----> ", 1) - - # Extract type and value - old_type_str, old_value_repr = old_value_str[1:].split("] ", 1) - dbLogger.error(f"running 
str_to_type for: {str(old_type_str)}, {str(old_value_repr)}") - old_value = str_to_type(old_type_str, old_value_repr) - - old_values_dict[column_name] = old_value - dbLogger.error("old_value_dict: " + str(old_values_dict)) - - if not target_item: - try: - # Create the target item using the aggregated old values - target_item = target_class(**old_values_dict) - session.add(target_item) - except Exception as e: - return f"Error: Failed to recreate item with ID {last_operation.item_id} in {last_operation.table_name.upper()} table - {str(e)}" + operation_ids = operation_ids.all() - rollback_message = ( - f"Rolled back '{last_operation.operation}' operation on {last_operation.table_name.upper()} table (ID: {last_operation.item_id}): Re-inserted item" - ) + all_operations = [] - else: - return f"Error: Unknown operation {last_operation.operation}" + for operation_id in operation_ids: + operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first() - try: - session.commit() - safe_close(session) + if operation_log is not None: + # Convert the object to dictionary + obj_dict = operation_log.__dict__ + obj_dict.pop('_sa_instance_state') + sanitized_obj_dict = self.Sanitize_Datetime(obj_dict) + all_operations.append(sanitized_obj_dict) + + self.safe_close(session) + return all_operations except Exception as E: - dbLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) + self.logTool.log(service='Database', level='error', message=f"get_all_operation_logs error: {E}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='error', message=E, redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) raise ValueError(E) - return f"Rolled back operation with operation_id: {operation_id}\n" + "\n".join(rollback_messages) + def get_all_operation_logs_by_table(self, table_name, page=0, page_size=100, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession - except Exception as E: - dbLogger.error("rollback_last_change error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) + try: + # Get all distinct operation_ids ordered by max timestamp (descending order) + operation_ids = session.query(OPERATION_LOG_BASE.operation_id).filter(OPERATION_LOG_BASE.table_name == table_name).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp))) -def get_all_operation_logs(page=0, page_size=yaml_config['api'].get('page_size', 100), existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession + operation_ids = operation_ids.limit(page_size).offset(page * page_size) - try: - # Get all distinct operation_ids ordered by max timestamp (descending order) - operation_ids = session.query(OPERATION_LOG_BASE.operation_id).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp))) + operation_ids = operation_ids.all() - operation_ids = operation_ids.limit(page_size).offset(page * page_size) + all_operations = [] - operation_ids = operation_ids.all() + for operation_id in operation_ids: + operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == 
operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first() - all_operations = [] + if operation_log is not None: + # Convert the object to dictionary + obj_dict = operation_log.__dict__ + obj_dict.pop('_sa_instance_state') + sanitized_obj_dict = self.Sanitize_Datetime(obj_dict) + all_operations.append(sanitized_obj_dict) - for operation_id in operation_ids: - operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first() + self.safe_close(session) + return all_operations + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"get_all_operation_logs_by_table error: {E}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='error', message=E, redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - if operation_log is not None: - # Convert the object to dictionary - obj_dict = operation_log.__dict__ - obj_dict.pop('_sa_instance_state') - sanitized_obj_dict = Sanitize_Datetime(obj_dict) - all_operations.append(sanitized_obj_dict) - - safe_close(session) - return all_operations - except Exception as E: - dbLogger.error(f"get_all_operation_logs error: {E}") - dbLogger.error(E) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - -def get_all_operation_logs_by_table(table_name, page=0, page_size=yaml_config['api'].get('page_size', 100), existingSession=None): - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession + def get_last_operation_log(self, existingSession=None): + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() + else: + session = existingSession - try: - # Get all distinct operation_ids ordered by max timestamp (descending order) - operation_ids = session.query(OPERATION_LOG_BASE.operation_id).filter(OPERATION_LOG_BASE.table_name == table_name).group_by(OPERATION_LOG_BASE.operation_id).order_by(desc(func.max(OPERATION_LOG_BASE.timestamp))) + try: + # Get the top 100 records ordered by timestamp (descending order) + top_100_records = session.query(OPERATION_LOG_BASE).order_by(desc(OPERATION_LOG_BASE.timestamp)).limit(100) - operation_ids = operation_ids.limit(page_size).offset(page * page_size) + # Get the most recent operation_id + most_recent_operation_log = top_100_records.first() - operation_ids = operation_ids.all() + # Convert the object to dictionary + if most_recent_operation_log is not None: + obj_dict = most_recent_operation_log.__dict__ + obj_dict.pop('_sa_instance_state') + sanitized_obj_dict = self.Sanitize_Datetime(obj_dict) + return sanitized_obj_dict - all_operations = [] + self.safe_close(session) + return None + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"get_last_operation_log error: {E}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='error', message=E, redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - for operation_id in operation_ids: - operation_log = session.query(OPERATION_LOG_BASE).filter(OPERATION_LOG_BASE.operation_id == operation_id[0]).order_by(OPERATION_LOG_BASE.id.asc()).first() + def handleGeored(self, jsonData): + try: + if self.config.get('geored', {}).get('enabled', False): + if self.config.get('geored', 
{}).get('sync_endpoints', []) is not None and len(self.config.get('geored', {}).get('sync_endpoints', [])) > 0:
+                    transaction_id = str(uuid.uuid4())
+                    self.logTool.log(service='Database', level='debug', message="Queueing Geored transaction " + str(transaction_id), redisClient=self.redisMessaging)
+                    self.redisMessaging.sendMessage(queue=f'geored-{transaction_id}-{time.time_ns()}', message=jsonData, queueExpiry=120)
+                    self.logTool.log(service='Database', level='debug', message="Queued Geored transaction " + str(transaction_id), redisClient=self.redisMessaging)
+        except Exception as E:
+            self.logTool.log(service='Database', level='warning', message="Failed to send Geored message due to error: " + str(E), redisClient=self.redisMessaging)
+
+    def handleWebhook(self, objectData, operation):
+        external_webhook_notification_enabled = self.config.get('external', {}).get('external_webhook_notification_enabled', False)
+        external_webhook_notification_url = self.config.get('external', {}).get('external_webhook_notification_url', '')
+
+        if not external_webhook_notification_enabled:
+            return False
+        if not external_webhook_notification_url:
+            self.logTool.log(service='Database', level='error', message="External webhook notification enabled, but external_webhook_notification_url is not defined.", redisClient=self.redisMessaging)
+            return False
+
+        externalNotification = self.Sanitize_Datetime(objectData)
+        externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()}
+        externalNotification['headers'] = externalNotificationHeaders
+        self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=externalNotification, queueExpiry=120)
+        return True
+
+    def Sanitize_Datetime(self, result):
+        for keys in result:
+            if "timestamp" in keys:
+                if result[keys] is None:
+                    continue
+                else:
+                    self.logTool.log(service='Database', level='debug', message="Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String", redisClient=self.redisMessaging)
+                    result[keys] = str(result[keys])
+        return result
 
-        if operation_log is not None:
-            # Convert the object to dictionary
-            obj_dict = operation_log.__dict__
-            obj_dict.pop('_sa_instance_state')
-            sanitized_obj_dict = Sanitize_Datetime(obj_dict)
-            all_operations.append(sanitized_obj_dict)
-    
-        safe_close(session)
-        return all_operations
-    except Exception as E:
-        dbLogger.error(f"get_all_operation_logs_by_table error: {E}")
-        dbLogger.error(E)
-        safe_rollback(session)
-        safe_close(session)
-        raise ValueError(E)
-
-def get_last_operation_log(existingSession=None):
-    if not existingSession:
-        Base.metadata.create_all(engine)
-        Session = sessionmaker(bind=engine)
+    def Sanitize_Keys(self, result):
+        names_to_strip = ['opc', 'ki', 'des', 'kid', 'psk', 'adm1']
+        for name_to_strip in names_to_strip:
+            try:
+                result.pop(name_to_strip)
+            except KeyError:
+                pass
+        return result
+
+    def GetObj(self, obj_type, obj_id=None, page=None, page_size=None):
+        self.logTool.log(service='Database', level='debug', message="Called GetObj for type " + str(obj_type), redisClient=self.redisMessaging)
+
+        Base.metadata.create_all(self.engine)
+        Session = sessionmaker(bind=self.engine)
         session = Session()
 
+        try:
+            if obj_id is not None:
+                result = session.query(obj_type).get(obj_id)
+                if result is None:
+                    raise ValueError(f"No {obj_type} found with id {obj_id}")
+
+                result 
= result.__dict__ + result.pop('_sa_instance_state') + result = self.Sanitize_Datetime(result) + elif page is not None and page_size is not None: + if page < 1 or page_size < 1: + raise ValueError("page and page_size should be positive integers") + + offset = (page - 1) * page_size + results = ( + session.query(obj_type) + .order_by(obj_type.id) # Assuming obj_type has an attribute 'id' + .offset(offset) + .limit(page_size) + .all() + ) + + result = [] + for item in results: + item_dict = item.__dict__ + item_dict.pop('_sa_instance_state') + result.append(self.Sanitize_Datetime(item_dict)) + else: + raise ValueError("Provide either obj_id or both page and page_size") - # Get the most recent operation_id - most_recent_operation_log = top_100_records.first() + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - # Convert the object to dictionary - if most_recent_operation_log is not None: - obj_dict = most_recent_operation_log.__dict__ - obj_dict.pop('_sa_instance_state') - sanitized_obj_dict = Sanitize_Datetime(obj_dict) - return sanitized_obj_dict + self.safe_close(session) + return result - safe_close(session) - return None - except Exception as E: - dbLogger.error(f"get_last_operation_log error: {E}") - dbLogger.error(E) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - -def handleGeored(jsonData): - try: - if yaml_config.get('geored', {}).get('enabled', False): - if yaml_config.get('geored', {}).get('sync_endpoints', []) is not None and len(yaml_config.get('geored', {}).get('sync_endpoints', [])) > 0: - transaction_id = str(uuid.uuid4()) - redisMessaging.sendMessage(queue=f'geored-{time.time_ns()}', message=jsonData, queueExpiry=120) - except Exception as E: - dbLogger.warning("Failed to send Geored message due to error: " + str(E)) - -def handleWebhook(objectData, operation): - external_webhook_notification_enabled = yaml_config.get('external', {}).get('external_webhook_notification_enabled', False) - external_webhook_notification_url = yaml_config.get('external', {}).get('external_webhook_notification_url', '') - if not external_webhook_notification_enabled: - return False - if not external_webhook_notification_url: - dbLogger.error("External webhook notification enabled, but external_webhook_notification_url is not defined.") - - externalNotification = Sanitize_Datetime(objectData) - externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} - #@@Fixme - redisMessaging.sendMessage(queue=f'webhook-{time.time_ns()}', message=jsonData, queueExpiry=120) - return True - -def Sanitize_Datetime(result): - for keys in result: - if "timestamp" in keys: - if result[keys] == None: - continue - else: - dbLogger.debug("Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String") - result[keys] = str(result[keys]) - return result + def GetAll(self, obj_type): + self.logTool.log(service='Database', level='debug', message="Called GetAll for type " + str(obj_type), redisClient=self.redisMessaging) -def Sanitize_Keys(result): - names_to_strip = ['opc', 'ki', 'des', 'kid', 'psk', 'adm1'] - for name_to_strip in names_to_strip: - try: - result.pop(name_to_strip) - except: - pass - return result + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind = self.engine) + session = Session() + final_result_list = 
[] -def GetObj(obj_type, obj_id=None, page=None, page_size=None): - dbLogger.debug("Called GetObj for type " + str(obj_type)) + try: + result = session.query(obj_type) + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + for record in result: + record = record.__dict__ + record.pop('_sa_instance_state') + record = self.Sanitize_Datetime(record) + record = self.Sanitize_Keys(record) + final_result_list.append(record) - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() + self.safe_close(session) + return final_result_list - try: - if obj_id is not None: - result = session.query(obj_type).get(obj_id) - if result is None: - raise ValueError(f"No {obj_type} found with id {obj_id}") + def getAllPaginated(self, obj_type, page=0, page_size=0, existingSession=None): + self.logTool.log(service='Database', level='debug', message="Called getAllPaginated for type " + str(obj_type), redisClient=self.redisMessaging) - result = result.__dict__ - result.pop('_sa_instance_state') - result = Sanitize_Datetime(result) - elif page is not None and page_size is not None: - if page < 1 or page_size < 1: - raise ValueError("page and page_size should be positive integers") - - offset = (page - 1) * page_size - results = ( - session.query(obj_type) - .order_by(obj_type.id) # Assuming obj_type has an attribute 'id' - .offset(offset) - .limit(page_size) - .all() - ) - - result = [] - for item in results: - item_dict = item.__dict__ - item_dict.pop('_sa_instance_state') - result.append(Sanitize_Datetime(item_dict)) + if not existingSession: + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + session = Session() else: - raise ValueError("Provide either obj_id or both page and page_size") + session = existingSession - except Exception as E: - dbLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) + final_result_list = [] - safe_close(session) - return result + try: + # Query object type + result = session.query(obj_type) -def GetAll(obj_type): - dbLogger.debug("Called GetAll for type " + str(obj_type)) + # Apply pagination + if page_size != 0: + result = result.limit(page_size).offset(page * page_size) + + result = result.all() + + for record in result: + record = record.__dict__ + record.pop('_sa_instance_state') + record = self.Sanitize_Datetime(record) + record = self.Sanitize_Keys(record) + final_result_list.append(record) + + self.safe_close(session) + return final_result_list - Base.metadata.create_all(engine) - Session = sessionmaker(bind = engine) - session = Session() - final_result_list = [] + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - try: - result = session.query(obj_type) - except Exception as E: - dbLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - for record in result: - record = record.__dict__ - record.pop('_sa_instance_state') - record = Sanitize_Datetime(record) - record = Sanitize_Keys(record) - final_result_list.append(record) - - safe_close(session) - return final_result_list - -def getAllPaginated(obj_type, page=0, 
page_size=0, existingSession=None): - dbLogger.debug("Called getAllPaginated for type " + str(obj_type)) - - if not existingSession: - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - session = Session() - else: - session = existingSession - final_result_list = [] + def GetAllByTable(self, obj_type, table): + self.logTool.log(service='Database', level='debug', message=f"Called GetAll for type {str(obj_type)} and table {table}", redisClient=self.redisMessaging) - try: - # Query object type - result = session.query(obj_type) + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind = self.engine) + session = Session() + final_result_list = [] - # Apply pagination - if page_size != 0: - result = result.limit(page_size).offset(page * page_size) + try: + result = session.query(obj_type).filter_by(table_name=str(table)) + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to query, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - result = result.all() - for record in result: record = record.__dict__ record.pop('_sa_instance_state') - record = Sanitize_Datetime(record) - record = Sanitize_Keys(record) + record = self.Sanitize_Datetime(record) + record = self.Sanitize_Keys(record) final_result_list.append(record) - - safe_close(session) + + self.safe_close(session) return final_result_list - except Exception as E: - dbLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) + def UpdateObj(self, obj_type, json_data, obj_id, disable_logging=False, operation_id=None): + self.logTool.log(service='Database', level='debug', message=f"Called UpdateObj() for type {obj_type} id {obj_id} with JSON data: {json_data} and operation_id: {operation_id}", redisClient=self.redisMessaging) + Session = sessionmaker(bind=self.engine) + session = Session() + obj_type_str = str(obj_type.__table__.name).upper() + self.logTool.log(service='Database', level='debug', message=f"obj_type_str is {obj_type_str}", redisClient=self.redisMessaging) + filter_input = eval(obj_type_str + "." 
+ obj_type_str.lower() + "_id==obj_id") + try: + obj = session.query(obj_type).filter(filter_input).one() + for key, value in json_data.items(): + if hasattr(obj, key): + setattr(obj, key, value) + setattr(obj, "last_modified", datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z') + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to query or update object, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + try: + session.info["operation_id"] = operation_id # Pass the operation id + try: + if not disable_logging: + self.log_changes_before_commit(session) + objectData = self.GetObj(obj_type, obj_id) + session.commit() + self.handleWebhook(objectData, 'UPDATE') + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) + self.safe_rollback(session) + raise ValueError(E) + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Exception in UpdateObj, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + finally: + self.safe_close(session) + return self.GetObj(obj_type, obj_id) -def GetAllByTable(obj_type, table): - dbLogger.debug(f"Called GetAll for type {str(obj_type)} and table {table}") + def DeleteObj(self, obj_type, obj_id, disable_logging=False, operation_id=None): + self.logTool.log(service='Database', level='debug', message=f"Called DeleteObj for type {obj_type} with id {obj_id}", redisClient=self.redisMessaging) - Base.metadata.create_all(engine) - Session = sessionmaker(bind = engine) - session = Session() - final_result_list = [] + Session = sessionmaker(bind=self.engine) + session = Session() - try: - result = session.query(obj_type).filter_by(table_name=str(table)) - except Exception as E: - dbLogger.error("Failed to query, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - for record in result: - record = record.__dict__ - record.pop('_sa_instance_state') - record = Sanitize_Datetime(record) - record = Sanitize_Keys(record) - final_result_list.append(record) - - safe_close(session) - return final_result_list - -def UpdateObj(obj_type, json_data, obj_id, disable_logging=False, operation_id=None): - dbLogger.debug(f"Called UpdateObj() for type {obj_type} id {obj_id} with JSON data: {json_data} and operation_id: {operation_id}") - Session = sessionmaker(bind=engine) - session = Session() - obj_type_str = str(obj_type.__table__.name).upper() - dbLogger.debug(f"obj_type_str is {obj_type_str}") - filter_input = eval(obj_type_str + "." 
+ obj_type_str.lower() + "_id==obj_id") - try: - obj = session.query(obj_type).filter(filter_input).one() - for key, value in json_data.items(): - if hasattr(obj, key): - setattr(obj, key, value) - setattr(obj, "last_modified", datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z') - except Exception as E: - dbLogger.error(f"Failed to query or update object, error: {E}") - raise ValueError(E) - try: + try: + res = session.query(obj_type).get(obj_id) + if res is None: + raise ValueError("The specified row does not exist") + objectData = self.GetObj(obj_type, obj_id) + session.delete(res) session.info["operation_id"] = operation_id # Pass the operation id try: if not disable_logging: - log_changes_before_commit(session) - objectData = GetObj(obj_type, obj_id) + self.log_changes_before_commit(session) session.commit() - handleWebhook(objectData, 'UPDATE') + self.handleWebhook(objectData, 'DELETE') except Exception as E: - dbLogger.error(f"Failed to commit session, error: {E}") - safe_rollback(session) + self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) + self.safe_rollback(session) raise ValueError(E) - except Exception as E: - dbLogger.error(f"Exception in UpdateObj, error: {E}") - raise ValueError(E) - finally: - safe_close(session) - return GetObj(obj_type, obj_id) + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Exception in DeleteObj, error: {E}", redisClient=self.redisMessaging) + raise ValueError(E) + finally: + self.safe_close(session) + + return {'Result': 'OK'} -def DeleteObj(obj_type, obj_id, disable_logging=False, operation_id=None): - dbLogger.debug(f"Called DeleteObj for type {obj_type} with id {obj_id}") - Session = sessionmaker(bind=engine) - session = Session() + def CreateObj(self, obj_type, json_data, disable_logging=False, operation_id=None): + self.logTool.log(service='Database', level='debug', message="Called CreateObj to create " + str(obj_type) + " with value: " + str(json_data), redisClient=self.redisMessaging) + last_modified_value = datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z' + json_data["last_modified"] = last_modified_value # set last_modified value in json_data + newObj = obj_type(**json_data) + Session = sessionmaker(bind=self.engine) + session = Session() - try: - res = session.query(obj_type).get(obj_id) - if res is None: - raise ValueError("The specified row does not exist") - objectData = GetObj(obj_type, obj_id) - session.delete(res) - session.info["operation_id"] = operation_id # Pass the operation id + session.add(newObj) try: - if not disable_logging: - log_changes_before_commit(session) - session.commit() - handleWebhook(objectData, 'DELETE') + session.info["operation_id"] = operation_id # Pass the operation id + try: + if not disable_logging: + self.log_changes_before_commit(session) + session.commit() + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) + self.safe_rollback(session) + raise ValueError(E) + session.refresh(newObj) + result = newObj.__dict__ + result.pop('_sa_instance_state') + self.handleWebhook(result, 'CREATE') + return result except Exception as E: - dbLogger.error(f"Failed to commit session, error: {E}") - safe_rollback(session) + self.logTool.log(service='Database', level='error', message=f"Exception in CreateObj, error: {E}", 
redisClient=self.redisMessaging)
             raise ValueError(E)
+        finally:
+            self.safe_close(session)
 
-    except Exception as E:
-        dbLogger.error(f"Exception in DeleteObj, error: {E}")
-        raise ValueError(E)
-    finally:
-        safe_close(session)
+    def Generate_JSON_Model_for_Flask(self, obj_type):
+        self.logTool.log(service='Database', level='debug', message="Generating JSON model for Flask for object type: " + str(obj_type), redisClient=self.redisMessaging)
 
-    return {'Result': 'OK'}
+        dictty = dict(self.generate_json_schema(obj_type))
+        pprint.pprint(dictty)
 
-def CreateObj(obj_type, json_data, disable_logging=False, operation_id=None):
-    dbLogger.debug("Called CreateObj to create " + str(obj_type) + " with value: " + str(json_data))
-    last_modified_value = datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S') + 'Z'
-    json_data["last_modified"] = last_modified_value # set last_modified value in json_data
-    newObj = obj_type(**json_data)
-    Session = sessionmaker(bind=engine)
-    session = Session()
+        #dictty['properties'] = dict(dictty['properties'])
 
-    session.add(newObj)
-    try:
-        session.info["operation_id"] = operation_id # Pass the operation id
-        try:
-            if not disable_logging:
-                log_changes_before_commit(session)
-            session.commit()
-        except Exception as E:
-            dbLogger.error(f"Failed to commit session, error: {E}")
-            safe_rollback(session)
-            raise ValueError(E)
-        session.refresh(newObj)
-        result = newObj.__dict__
+        # Exclude the internal 'discriminator' and 'last_modified' columns from the properties
+        if 'properties' in dictty:
+            dictty['properties'].pop('discriminator', None)
+            dictty['properties'].pop('last_modified', None)
+
+
+        # Set the ID Object to not required
+        obj_type_str = str(dictty['title']).lower()
+        dictty['required'].remove(obj_type_str + '_id')
+
+        return dictty
+
+    def Get_AuC(self, **kwargs):
+        #Get AuC data by IMSI or ICCID
+
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        if 'iccid' in kwargs:
+            self.logTool.log(service='Database', level='debug', message="Get_AuC for iccid " + str(kwargs['iccid']), redisClient=self.redisMessaging)
+            try:
+                result = session.query(AUC).filter_by(iccid=str(kwargs['iccid'])).one()
+            except Exception as E:
+                self.safe_close(session)
+                raise ValueError(E)
+        elif 'imsi' in kwargs:
+            self.logTool.log(service='Database', level='debug', message="Get_AuC for imsi " + str(kwargs['imsi']), redisClient=self.redisMessaging)
+            try:
+                result = session.query(AUC).filter_by(imsi=str(kwargs['imsi'])).one()
+            except Exception as E:
+                self.safe_close(session)
+                raise ValueError(E)
+
+        result = result.__dict__
+        result = self.Sanitize_Datetime(result)
+        result.pop('_sa_instance_state')
+
+        self.logTool.log(service='Database', level='debug', message="Got back result: " + str(result), redisClient=self.redisMessaging)
+        self.safe_close(session)
+        return result
 
-def Generate_JSON_Model_for_Flask(obj_type):
-    dbLogger.debug("Generating JSON model for Flask for object type: " + str(obj_type))
+    def Get_IMS_Subscriber(self, **kwargs):
+        #Get subscriber by IMSI or MSISDN
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+        if 'msisdn' in kwargs:
+            self.logTool.log(service='Database', level='debug', message="Get_IMS_Subscriber for msisdn " + str(kwargs['msisdn']), redisClient=self.redisMessaging)
+            try:
+                result = session.query(IMS_SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one()
+            except 
Exception as E: + self.safe_close(session) + raise ValueError(E) + elif 'imsi' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_IMS_Subscriber for imsi " + str(kwargs['imsi']), redisClient=self.redisMessaging) + try: + result = session.query(IMS_SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + self.logTool.log(service='Database', level='debug', message="Converting result to dict", redisClient=self.redisMessaging) + result = result.__dict__ + try: + result.pop('_sa_instance_state') + except: + pass + result = self.Sanitize_Datetime(result) + self.logTool.log(service='Database', level='debug', message="Returning IMS Subscriber Data: " + str(result), redisClient=self.redisMessaging) + self.safe_close(session) + return result - dictty = dict(generate_json_schema(obj_type)) - pprint.pprint(dictty) + def Get_Subscriber(self, **kwargs): + #Get subscriber by IMSI or MSISDN + Session = sessionmaker(bind = self.engine) + session = Session() - #dictty['properties'] = dict(dictty['properties']) + if 'msisdn' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_Subscriber for msisdn " + str(kwargs['msisdn']), redisClient=self.redisMessaging) + try: + result = session.query(SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + elif 'imsi' in kwargs: + self.logTool.log(service='Database', level='debug', message="Get_Subscriber for imsi " + str(kwargs['imsi']), redisClient=self.redisMessaging) + try: + result = session.query(SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) - # Exclude 'table_name' column from the properties - if 'properties' in dictty: - dictty['properties'].pop('discriminator', None) - dictty['properties'].pop('last_modified', None) + result = result.__dict__ + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + if 'get_attributes' in kwargs: + if kwargs['get_attributes'] == True: + attributes = self.Get_Subscriber_Attributes(result['subscriber_id']) + result['attributes'] = attributes - # Set the ID Object to not required - obj_type_str = str(dictty['title']).lower() - dictty['required'].remove(obj_type_str + '_id') - - return dictty - -def Get_AuC(**kwargs): - #Get AuC data by IMSI or ICCID + self.logTool.log(service='Database', level='debug', message="Got back result: " + str(result), redisClient=self.redisMessaging) + self.safe_close(session) + return result - Session = sessionmaker(bind = engine) - session = Session() + def Get_SUBSCRIBER_ROUTING(self, subscriber_id, apn_id): + Session = sessionmaker(bind = self.engine) + session = Session() - if 'iccid' in kwargs: - dbLogger.debug("Get_AuC for iccid " + str(kwargs['iccid'])) + self.logTool.log(service='Database', level='debug', message="Get_SUBSCRIBER_ROUTING for subscriber_id " + str(subscriber_id) + " and apn_id " + str(apn_id), redisClient=self.redisMessaging) try: - result = session.query(AUC).filter_by(iccid=str(kwargs['iccid'])).one() + result = session.query(SUBSCRIBER_ROUTING).filter_by(subscriber_id=subscriber_id, apn_id=apn_id).one() except Exception as E: - safe_close(session) + self.safe_close(session) raise ValueError(E) - elif 'imsi' in kwargs: - dbLogger.debug("Get_AuC for imsi " + str(kwargs['imsi'])) + + result = result.__dict__ + result = self.Sanitize_Datetime(result) + 
result.pop('_sa_instance_state')
+
+        self.logTool.log(service='Database', level='debug', message="Got back result: " + str(result), redisClient=self.redisMessaging)
+        self.safe_close(session)
+        return result
+
+    def Get_Subscriber_Attributes(self, subscriber_id):
+        #Get subscriber attributes
+
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        self.logTool.log(service='Database', level='debug', message="Get_Subscriber_Attributes for subscriber_id " + str(subscriber_id), redisClient=self.redisMessaging)
         try:
-            result = session.query(AUC).filter_by(imsi=str(kwargs['imsi'])).one()
+            result = session.query(SUBSCRIBER_ATTRIBUTES).filter_by(subscriber_id=subscriber_id)
         except Exception as E:
-            safe_close(session)
+            self.safe_close(session)
             raise ValueError(E)
+        final_res = []
+        for record in result:
+            result = record.__dict__
+            result = self.Sanitize_Datetime(result)
+            result.pop('_sa_instance_state')
+            final_res.append(result)
+        self.logTool.log(service='Database', level='debug', message="Got back result: " + str(final_res), redisClient=self.redisMessaging)
+        self.safe_close(session)
+        return final_res
+
 
-    result = result.__dict__
-    result = Sanitize_Datetime(result)
-    result.pop('_sa_instance_state')
+    def Get_Served_Subscribers(self, get_local_users_only=False):
+        self.logTool.log(service='Database', level='debug', message="Getting all subscribers served by this HSS", redisClient=self.redisMessaging)
 
-    dbLogger.debug("Got back result: " + str(result))
-    safe_close(session)
-    return result
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
 
-def Get_IMS_Subscriber(**kwargs):
-    #Get subscriber by IMSI or MSISDN
-    Session = sessionmaker(bind = engine)
-    session = Session()
-    if 'msisdn' in kwargs:
-        dbLogger.debug("Get_IMS_Subscriber for msisdn " + str(kwargs['msisdn']))
+        Served_Subs = {}
         try:
-            result = session.query(IMS_SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one()
+            results = session.query(SUBSCRIBER).filter(SUBSCRIBER.serving_mme.isnot(None))
+            for result in results:
+                result = result.__dict__
+                self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging)
+                result = self.Sanitize_Datetime(result)
+                result.pop('_sa_instance_state')
+
+                if get_local_users_only == True:
+                    self.logTool.log(service='Database', level='debug', message="Filtering to locally served Subscribers only", redisClient=self.redisMessaging)
+                    try:
+                        serving_hss = result['serving_mme_peer'].split(';')[1]
+                        self.logTool.log(service='Database', level='debug', message="Serving HSS: " + str(serving_hss) + " and this is: " + str(self.config['hss']['OriginHost']), redisClient=self.redisMessaging)
+                        if serving_hss == self.config['hss']['OriginHost']:
+                            self.logTool.log(service='Database', level='debug', message="Serving HSS matches local HSS", redisClient=self.redisMessaging)
+                            Served_Subs[result['imsi']] = {}
+                            Served_Subs[result['imsi']] = result
+                            #self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+                            continue
+                        else:
+                            self.logTool.log(service='Database', level='debug', message="Sub is served by remote HSS: " + str(serving_hss), redisClient=self.redisMessaging)
+                    except Exception as E:
+                        self.logTool.log(service='Database', level='debug', message="Error in filtering Get_Served_Subscribers to local peer only: " + str(E), redisClient=self.redisMessaging)
+                        continue
+                else:
+                    Served_Subs[result['imsi']] = result
+                    
self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+
+
         except Exception as E:
-            safe_close(session)
+            self.safe_close(session)
             raise ValueError(E)
-    elif 'imsi' in kwargs:
-        dbLogger.debug("Get_IMS_Subscriber for imsi " + str(kwargs['imsi']))
+        self.logTool.log(service='Database', level='debug', message="Final Served_Subs: " + str(Served_Subs), redisClient=self.redisMessaging)
+        self.safe_close(session)
+        return Served_Subs
+
+
+    def Get_Served_IMS_Subscribers(self, get_local_users_only=False):
+        self.logTool.log(service='Database', level='debug', message="Getting all subscribers served by this IMS-HSS", redisClient=self.redisMessaging)
+        Session = sessionmaker(bind=self.engine)
+        session = Session()
+
+        Served_Subs = {}
         try:
-            result = session.query(IMS_SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one()
+
+            results = session.query(IMS_SUBSCRIBER).filter(
+                IMS_SUBSCRIBER.scscf.isnot(None))
+            for result in results:
+                result = result.__dict__
+                self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging)
+                result = self.Sanitize_Datetime(result)
+                result.pop('_sa_instance_state')
+                if get_local_users_only == True:
+                    self.logTool.log(service='Database', level='debug', message="Filtering Get_Served_IMS_Subscribers to locally served IMS Subs only", redisClient=self.redisMessaging)
+                    try:
+                        serving_ims_hss = result['scscf_peer'].split(';')[1]
+                        self.logTool.log(service='Database', level='debug', message="Serving IMS-HSS: " + str(serving_ims_hss) + " and this is: " + str(self.config['hss']['OriginHost']), redisClient=self.redisMessaging)
+                        if serving_ims_hss == self.config['hss']['OriginHost']:
+                            self.logTool.log(service='Database', level='debug', message="Serving IMS-HSS matches local HSS for " + str(result['imsi']), redisClient=self.redisMessaging)
+                            Served_Subs[result['imsi']] = {}
+                            Served_Subs[result['imsi']] = result
+                            self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+                            continue
+                        else:
+                            self.logTool.log(service='Database', level='debug', message="Sub is served by remote IMS-HSS: " + str(serving_ims_hss), redisClient=self.redisMessaging)
+                    except Exception as E:
+                        self.logTool.log(service='Database', level='debug', message="Error in filtering to local peer only: " + str(E), redisClient=self.redisMessaging)
+                        continue
+                else:
+                    Served_Subs[result['imsi']] = result
+                    self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging)
+
         except Exception as E:
-            safe_close(session)
+            self.safe_close(session)
             raise ValueError(E)
-    dbLogger.debug("Converting result to dict")
-    result = result.__dict__
-    try:
-        result.pop('_sa_instance_state')
-    except:
-        pass
-    result = Sanitize_Datetime(result)
-    dbLogger.debug("Returning IMS Subscriber Data: " + str(result))
-    safe_close(session)
-    return result
-
-def Get_Subscriber(**kwargs):
-    #Get subscriber by IMSI or MSISDN
+        self.logTool.log(service='Database', level='debug', message="Final Served_Subs: " + str(Served_Subs), redisClient=self.redisMessaging)
+        self.safe_close(session)
+        return Served_Subs
 
-    Session = sessionmaker(bind = engine)
-    session = Session()
-    if 'msisdn' in kwargs:
-        dbLogger.debug("Get_Subscriber for msisdn " + str(kwargs['msisdn']))
+    def Get_Served_PCRF_Subscribers(self, get_local_users_only=False):
+        self.logTool.log(service='Database', level='debug', 
message="Getting all subscribers served by this PCRF", redisClient=self.redisMessaging) + Session = sessionmaker(bind=self.engine) + session = Session() + Served_Subs = {} try: - result = session.query(SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one() + results = session.query(SERVING_APN).all() + for result in results: + result = result.__dict__ + self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging) + result = self.Sanitize_Datetime(result) + result.pop('_sa_instance_state') + + if get_local_users_only == True: + self.logTool.log(service='Database', level='debug', message="Filtering to locally served IMS Subs only", redisClient=self.redisMessaging) + try: + serving_pcrf = result['serving_pgw_peer'].split(';')[1] + self.logTool.log(service='Database', level='debug', message="Serving PCRF: " + str(serving_pcrf) + " and this is: " + str(self.config['hss']['OriginHost']), redisClient=self.redisMessaging) + if serving_pcrf == self.config['hss']['OriginHost']: + self.logTool.log(service='Database', level='debug', message="Serving PCRF matches local PCRF", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging) + + else: + self.logTool.log(service='Database', level='debug', message="Sub is served by remote PCRF: " + str(serving_pcrf), redisClient=self.redisMessaging) + continue + except Exception as E: + self.logTool.log(service='Database', level='debug', message="Error in filtering Get_Served_PCRF_Subscribers to local peer only: " + str(E), redisClient=self.redisMessaging) + continue + + # Get APN Info + apn_info = self.GetObj(APN, result['apn']) + #self.logTool.log(service='Database', level='debug', message="Got APN Info: " + str(apn_info), redisClient=self.redisMessaging) + result['apn_info'] = apn_info + + # Get Subscriber Info + subscriber_info = self.GetObj(SUBSCRIBER, result['subscriber_id']) + result['subscriber_info'] = subscriber_info + + #self.logTool.log(service='Database', level='debug', message="Got Subscriber Info: " + str(subscriber_info), redisClient=self.redisMessaging) + + Served_Subs[subscriber_info['imsi']] = result + self.logTool.log(service='Database', level='debug', message="Processed result", redisClient=self.redisMessaging) except Exception as E: - safe_close(session) raise ValueError(E) - elif 'imsi' in kwargs: - dbLogger.debug("Get_Subscriber for imsi " + str(kwargs['imsi'])) + #self.logTool.log(service='Database', level='debug', message="Final SERVING_APN: " + str(Served_Subs), redisClient=self.redisMessaging) + self.safe_close(session) + return Served_Subs + + def Get_Vectors_AuC(self, auc_id, action, **kwargs): + self.logTool.log(service='Database', level='debug', message="Getting Vectors for auc_id " + str(auc_id) + " with action " + str(action), redisClient=self.redisMessaging) + key_data = self.GetObj(AUC, auc_id) + vector_dict = {} + + if action == "air": + rand, xres, autn, kasme = S6a_crypt.generate_eutran_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn']) + vector_dict['rand'] = rand + vector_dict['xres'] = xres + vector_dict['autn'] = autn + vector_dict['kasme'] = kasme + + #Incriment SQN + self.Update_AuC(auc_id, sqn=key_data['sqn']+100) + + return vector_dict + + elif action == "sqn_resync": + self.logTool.log(service='Database', level='debug', message="Resync SQN", redisClient=self.redisMessaging) + rand = kwargs['rand'] 
+            sqn, mac_s = S6a_crypt.generate_resync_s6a(key_data['ki'], key_data['opc'], key_data['amf'], kwargs['auts'], rand)
+            self.logTool.log(service='Database', level='debug', message="SQN from resync: " + str(sqn) + " SQN in DB is " + str(key_data['sqn']) + " (Difference of " + str(int(sqn) - int(key_data['sqn'])) + ")", redisClient=self.redisMessaging)
+            self.Update_AuC(auc_id, sqn=sqn+100)
+            return
+
+        elif action == "sip_auth":
+            rand, autn, xres, ck, ik = S6a_crypt.generate_maa_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn'])
+            self.logTool.log(service='Database', level='debug', message="RAND is: " + str(rand), redisClient=self.redisMessaging)
+            self.logTool.log(service='Database', level='debug', message="AUTN is: " + str(autn), redisClient=self.redisMessaging)
+            vector_dict['SIP_Authenticate'] = rand + autn
+            vector_dict['xres'] = xres
+            vector_dict['ck'] = ck
+            vector_dict['ik'] = ik
+            self.Update_AuC(auc_id, sqn=key_data['sqn']+100)
+            return vector_dict
+
+        elif action == "Digest-MD5":
+            self.logTool.log(service='Database', level='debug', message="Generating Digest-MD5 Auth vectors", redisClient=self.redisMessaging)
+            self.logTool.log(service='Database', level='debug', message="key_data: " + str(key_data), redisClient=self.redisMessaging)
+            nonce = uuid.uuid4().hex
+            #nonce = "beef4d878f2642ed98afe491b943ca60"
+            vector_dict['nonce'] = nonce
+            vector_dict['SIP_Authenticate'] = key_data['ki']
+            return vector_dict
+
+    def Get_APN(self, apn_id):
+        self.logTool.log(service='Database', level='debug', message="Getting APN " + str(apn_id), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
         try:
-            result = session.query(SUBSCRIBER).filter_by(imsi=str(kwargs['imsi'])).one()
+            result = session.query(APN).filter_by(apn_id=apn_id).one()
         except Exception as E:
-            safe_close(session)
+            self.safe_close(session)
             raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        self.safe_close(session)
+        return result
-    result = result.__dict__
-    result = Sanitize_Datetime(result)
-    result.pop('_sa_instance_state')
-
-    if 'get_attributes' in kwargs:
-        if kwargs['get_attributes'] == True:
-            attributes = Get_Subscriber_Attributes(result['subscriber_id'])
-            result['attributes'] = attributes
+    def Get_APN_by_Name(self, apn):
+        self.logTool.log(service='Database', level='debug', message="Getting APN named " + str(apn), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+        try:
+            result = session.query(APN).filter_by(apn=str(apn)).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        self.safe_close(session)
+        return result
-    dbLogger.debug("Got back result: " + str(result))
-    safe_close(session)
-    return result
+    def Update_AuC(self, auc_id, sqn=1):
+        self.logTool.log(service='Database', level='debug', message="Updating AuC record for sub " + str(auc_id), redisClient=self.redisMessaging)
+        self.logTool.log(service='Database', level='debug', message=self.UpdateObj(AUC, {'sqn': sqn}, auc_id, True), redisClient=self.redisMessaging)
+        return
-def Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id):
-    Session = sessionmaker(bind = engine)
-    session = Session()
+    def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, propagate=True):
+        self.logTool.log(service='Database', level='debug', message="Updating Serving MME for sub " + str(imsi)
+ " to MME " + str(serving_mme), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + try: + result = session.query(SUBSCRIBER).filter_by(imsi=imsi).one() + if self.config['hss']['CancelLocationRequest_Enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Evaluating if we should trigger sending a CLR.", redisClient=self.redisMessaging) + serving_hss = str(result.serving_mme_peer).split(';',1)[1] + serving_mme_peer = str(result.serving_mme_peer).split(';',1)[0] + self.logTool.log(service='Database', level='debug', message="Subscriber is currently served by serving_mme: " + str(result.serving_mme) + " at realm " + str(result.serving_mme_realm) + " through Diameter peer " + str(result.serving_mme_peer), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Subscriber is now served by serving_mme: " + str(serving_mme) + " at realm " + str(serving_mme_realm) + " through Diameter peer " + str(serving_mme_peer), redisClient=self.redisMessaging) + #Evaluate if we need to send a CLR to the old MME + if result.serving_mme != None: + if str(result.serving_mme) == str(serving_mme): + self.logTool.log(service='Database', level='debug', message="This MME is unchanged (" + str(serving_mme) + ") - so no need to send a CLR", redisClient=self.redisMessaging) + elif (str(result.serving_mme) != str(serving_mme)): + self.logTool.log(service='Database', level='debug', message="There is a difference in serving MME, old MME is '" + str(result.serving_mme) + "' new MME is '" + str(serving_mme) + "' - We need to trigger sending a CLR", redisClient=self.redisMessaging) + if serving_hss != self.config['hss']['OriginHost']: + self.logTool.log(service='Database', level='debug', message="This subscriber is not served by this HSS it is served by HSS at " + serving_hss + " - We need to trigger sending a CLR on " + str(serving_hss), redisClient=self.redisMessaging) + URL = 'http://' + serving_hss + '.' + self.config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) + else: + self.logTool.log(service='Database', level='debug', message="This subscriber is served by this HSS we need to send a CLR to old MME from this HSS", redisClient=self.redisMessaging) + + URL = 'http://' + serving_hss + '.' 
+ self.config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) + self.logTool.log(service='Database', level='debug', message="Sending CLR to API at " + str(URL), redisClient=self.redisMessaging) + json_data = { + "DestinationRealm": result.serving_mme_realm, + "DestinationHost": result.serving_mme, + "cancellationType": 2, + "diameterPeer": serving_mme_peer, + } + + self.logTool.log(service='Database', level='debug', message="Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data), redisClient=self.redisMessaging) + transaction_id = str(uuid.uuid4()) + GeoRed_Push_thread = threading.Thread(target=self.GeoRed_Push_Request, args=(serving_hss, json_data, transaction_id, URL)) + GeoRed_Push_thread.start() + else: + #No currently serving MME - No action to take + self.logTool.log(service='Database', level='debug', message="No currently serving MME - No need to send CLR", redisClient=self.redisMessaging) + + if type(serving_mme) == str: + self.logTool.log(service='Database', level='debug', message="Updating serving MME & Timestamp", redisClient=self.redisMessaging) + result.serving_mme = serving_mme + result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc) + result.serving_mme_realm = serving_mme_realm + result.serving_mme_peer = serving_mme_peer + else: + #Clear values + self.logTool.log(service='Database', level='debug', message="Clearing serving MME", redisClient=self.redisMessaging) + result.serving_mme = None + result.serving_mme_timestamp = None + result.serving_mme_realm = None + result.serving_mme_peer = None - dbLogger.debug("Get_SUBSCRIBER_ROUTING for subscriber_id " + str(subscriber_id) + " and apn_id " + str(apn_id)) - try: - result = session.query(SUBSCRIBER_ROUTING).filter_by(subscriber_id=subscriber_id, apn_id=apn_id).one() - except Exception as E: - safe_close(session) - raise ValueError(E) + session.commit() + objectData = self.GetObj(SUBSCRIBER, result.subscriber_id) + self.handleWebhook(objectData, 'UPDATE') + + #Sync state change with geored + if propagate == True: + if 'HSS' in self.config['geored'].get('sync_actions', []) and self.config['geored'].get('enabled', False) == True: + self.logTool.log(service='Database', level='debug', message="Propagate MME changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored({ + "imsi": str(imsi), + "serving_mme": result.serving_mme, + "serving_mme_realm": str(result.serving_mme_realm), + "serving_mme_peer": str(result.serving_mme_peer) + }) + else: + self.logTool.log(service='Database', level='debug', message="Config does not allow sync of HSS events", redisClient=self.redisMessaging) + except Exception as E: + self.logTool.log(service='Database', level='error', message="Error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging) + raise + finally: + self.safe_close(session) - result = result.__dict__ - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - dbLogger.debug("Got back result: " + str(result)) - safe_close(session) - return result + def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=None, propagate=True): + self.logTool.log(service='Database', level='debug', message="Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() -def Get_Subscriber_Attributes(subscriber_id): - #Get subscriber attributes + 
try:
+            result = session.query(IMS_SUBSCRIBER).filter_by(imsi=imsi).one()
+            try:
+                assert(type(serving_cscf) == str)
+                assert(len(serving_cscf) > 0)
+                self.logTool.log(service='Database', level='debug', message="Setting serving CSCF", redisClient=self.redisMessaging)
+                #Strip duplicate SIP prefix before storing
+                serving_cscf = serving_cscf.replace("sip:sip:", "sip:")
+                result.scscf = serving_cscf
+                result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc)
+                result.scscf_realm = scscf_realm
+                result.scscf_peer = str(scscf_peer)
+            except:
+                #Clear values when serving_cscf is empty or not a string
+                self.logTool.log(service='Database', level='debug', message="Clearing serving CSCF", redisClient=self.redisMessaging)
+                result.scscf = None
+                result.scscf_timestamp = None
+                result.scscf_realm = None
+                result.scscf_peer = None
+
+            session.commit()
+            objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id)
+            self.handleWebhook(objectData, 'UPDATE')
+
+            #Sync state change with geored
+            if propagate == True:
+                if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True:
+                    self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging)
+                    self.handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_peer": str(result.scscf_peer)})
+                else:
+                    self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging)
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message="An error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging)
+            self.safe_rollback(session)
+            raise
+        finally:
+            self.safe_close(session)
+
+
+    def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, propagate=True):
+        self.logTool.log(service='Database', level='debug', message="Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn), redisClient=self.redisMessaging)
+        self.logTool.log(service='Database', level='debug', message="PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing), redisClient=self.redisMessaging)
+        self.logTool.log(service='Database', level='debug', message="Serving PGW Realm is: " + str(serving_pgw_realm) + " and peer is: " + str(serving_pgw_peer), redisClient=self.redisMessaging)
+        self.logTool.log(service='Database', level='debug', message="subscriber_routing: " + str(subscriber_routing), redisClient=self.redisMessaging)
+
+        #Get Subscriber ID from IMSI
+        subscriber_details = self.Get_Subscriber(imsi=str(imsi))
+        subscriber_id = subscriber_details['subscriber_id']
+
+        #Split the APN list into a list
+        apn_list = subscriber_details['apn_list'].split(',')
+        self.logTool.log(service='Database', level='debug', message="Current APN List: " + str(apn_list), redisClient=self.redisMessaging)
+        #Remove the default APN from the list
+        try:
+            apn_list.remove(str(subscriber_details['default_apn']))
+        except:
+            self.logTool.log(service='Database', level='debug', message="Failed to remove default APN (" + str(subscriber_details['default_apn']) + ") from APN List", redisClient=self.redisMessaging)
+            pass
+        #Add default APN in first position
+        apn_list.insert(0, str(subscriber_details['default_apn']))
+
+        #Get APN ID from APN
+        for apn_id in apn_list:
+            #Get each APN in List
+            apn_data = self.Get_APN(apn_id)
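+            # The break below leaves apn_id bound to the first APN whose name
+            # matches (case-insensitively); the code after the loop relies on that.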
+            self.logTool.log(service='Database', level='debug', message=apn_data, redisClient=self.redisMessaging)
+            if str(apn_data['apn']).lower() == str(apn).lower():
+                self.logTool.log(service='Database', level='debug', message="Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id), redisClient=self.redisMessaging)
+                break
+        self.logTool.log(service='Database', level='debug', message="APN ID is " + str(apn_id), redisClient=self.redisMessaging)
+
+        json_data = {
+            'apn' : apn_id,
+            'subscriber_id' : subscriber_id,
+            'pcrf_session_id' : str(pcrf_session_id),
+            'serving_pgw' : str(serving_pgw),
+            'serving_pgw_realm' : str(serving_pgw_realm),
+            'serving_pgw_peer' : str(serving_pgw_peer),
+            'serving_pgw_timestamp' : datetime.datetime.now(tz=timezone.utc),
+            'subscriber_routing' : str(subscriber_routing)
+        }
-    Session = sessionmaker(bind = engine)
-    session = Session()
+        try:
+            #Check if already a serving APN on record
+            self.logTool.log(service='Database', level='debug', message="Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id), redisClient=self.redisMessaging)
+            ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id)
+            self.logTool.log(service='Database', level='debug', message="Existing Serving APN ID on record, updating", redisClient=self.redisMessaging)
+            try:
+                assert(type(serving_pgw) == str)
+                assert(len(serving_pgw) > 0)
+                assert("None" not in serving_pgw)
+
+                self.UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True)
+                objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+                self.handleWebhook(objectData, 'UPDATE')
+            except:
+                self.logTool.log(service='Database', level='debug', message="Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']), redisClient=self.redisMessaging)
+                objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+                self.handleWebhook(objectData, 'DELETE')
+                self.DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True)
+        except Exception as E:
+            self.logTool.log(service='Database', level='info', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging)
+            #Create if does not exist, then refetch it - ServingAPN is not defined on this path
+            self.CreateObj(SERVING_APN, json_data, True)
+            ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id)
+            objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id'])
+            self.handleWebhook(objectData, 'CREATE')
-    dbLogger.debug("Get_Subscriber_Attributes for subscriber_id " + str(subscriber_id))
-    try:
-        result = session.query(SUBSCRIBER_ATTRIBUTES).filter_by(subscriber_id=subscriber_id)
-    except Exception as E:
-        safe_close(session)
-        raise ValueError(E)
-    final_res = []
-    for record in result:
-        result = record.__dict__
-        result = Sanitize_Datetime(result)
-        result.pop('_sa_instance_state')
-        final_res.append(result)
-    dbLogger.debug("Got back result: " + str(final_res))
-    safe_close(session)
-    return final_res
+        #Sync state change with geored
+        if propagate == True:
+            try:
+                if 'PCRF' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True:
+                    self.logTool.log(service='Database', level='debug', message="Propagate PCRF changes to Geographic PyHSS instances", redisClient=self.redisMessaging)
+                    self.handleGeored({"imsi": str(imsi),
+                                       'serving_apn' : str(apn),
+                                       'pcrf_session_id': str(pcrf_session_id),
+                                       'serving_pgw': str(serving_pgw),
+                                       'serving_pgw_realm': str(serving_pgw_realm),
+                                       'serving_pgw_peer': str(serving_pgw_peer),
+                                       'subscriber_routing': str(subscriber_routing)
+                                       })
+                else:
self.logTool.log(service='Database', level='debug', message="Config does not allow sync of PCRF events", redisClient=self.redisMessaging) + except Exception as E: + self.logTool.log(service='Database', level='debug', message="Nothing synced to Geographic PyHSS instances for event PCRF", redisClient=self.redisMessaging) -def Get_Served_Subscribers(get_local_users_only=False): - dbLogger.debug("Getting all subscribers served by this HSS") + return - Session = sessionmaker(bind = engine) - session = Session() + def Get_Serving_APN(self, subscriber_id, apn_id): + self.logTool.log(service='Database', level='debug', message="Getting Serving APN " + str(apn_id) + " with subscriber_id " + str(subscriber_id), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() - Served_Subs = {} - try: - results = session.query(SUBSCRIBER).filter(SUBSCRIBER.serving_mme.isnot(None)) - for result in results: - result = result.__dict__ - dbLogger.debug("Result: " + str(result) + " type: " + str(type(result))) - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') + try: + result = session.query(SERVING_APN).filter_by(subscriber_id=subscriber_id, apn=apn_id).first() + except Exception as E: + self.logTool.log(service='Database', level='debug', message=E, redisClient=self.redisMessaging) + self.safe_close(session) + raise ValueError(E) + result = result.__dict__ + result.pop('_sa_instance_state') + + self.safe_close(session) + return result - if get_local_users_only == True: - dbLogger.debug("Filtering to locally served IMS Subs only") - try: - serving_hss = result['serving_mme_peer'].split(';')[1] - dbLogger.debug("Serving HSS: " + str(serving_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) - if serving_hss == yaml_config['hss']['OriginHost']: - dbLogger.debug("Serving HSS matches local HSS") - Served_Subs[result['imsi']] = {} - Served_Subs[result['imsi']] = result - #dbLogger.debug("Processed result") - continue - else: - dbLogger.debug("Sub is served by remote HSS: " + str(serving_hss)) - except Exception as E: - dbLogger.debug("Error in filtering Get_Served_Subscribers to local peer only: " + str(E)) - continue - else: - Served_Subs[result['imsi']] = result - dbLogger.debug("Processed result") + def Get_Charging_Rule(self, charging_rule_id): + self.logTool.log(service='Database', level='debug', message="Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + #Get base Rule + ChargingRule = self.GetObj(CHARGING_RULE, charging_rule_id) + ChargingRule['tft'] = [] + #Get TFTs + try: + results = session.query(TFT).filter_by(tft_group_id=ChargingRule['tft_group_id']) + for result in results: + result = result.__dict__ + result.pop('_sa_instance_state') + ChargingRule['tft'].append(result) + except Exception as E: + self.safe_close(session) + raise ValueError(E) + self.safe_close(session) + return ChargingRule + + def Get_Charging_Rules(self, imsi, apn): + self.logTool.log(service='Database', level='debug', message="Called Get_Charging_Rules() for IMSI " + str(imsi) + " and APN " + str(apn), redisClient=self.redisMessaging) + #Get Subscriber ID from IMSI + subscriber_details = self.Get_Subscriber(imsi=str(imsi)) + + #Split the APN list into a list + apn_list = subscriber_details['apn_list'].split(',') + self.logTool.log(service='Database', level='debug', message="Current APN List: " + str(apn_list), 
redisClient=self.redisMessaging) + #Remove the default APN from the list + try: + apn_list.remove(str(subscriber_details['default_apn'])) + except: + self.logTool.log(service='Database', level='debug', message="Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List", redisClient=self.redisMessaging) + pass + #Add default APN in first position + apn_list.insert(0, str(subscriber_details['default_apn'])) + + #Get APN ID from APN + for apn_id in apn_list: + self.logTool.log(service='Database', level='debug', message="Getting APN ID " + str(apn_id) + " to see if it matches APN " + str(apn), redisClient=self.redisMessaging) + #Get each APN in List + apn_data = self.Get_APN(apn_id) + self.logTool.log(service='Database', level='debug', message=apn_data, redisClient=self.redisMessaging) + if str(apn_data['apn']).lower() == str(apn).lower(): + self.logTool.log(service='Database', level='debug', message="Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id), redisClient=self.redisMessaging) + + self.logTool.log(service='Database', level='debug', message="Getting charging rule list from " + str(apn_data['charging_rule_list']), redisClient=self.redisMessaging) + ChargingRule = {} + ChargingRule['charging_rule_list'] = str(apn_data['charging_rule_list']).split(',') + ChargingRule['apn_data'] = apn_data + + #Get Charging Rules list + if apn_data['charging_rule_list'] == None: + self.logTool.log(service='Database', level='debug', message="No Charging Rule associated with this APN", redisClient=self.redisMessaging) + ChargingRule['charging_rules'] = None + return ChargingRule + + self.logTool.log(service='Database', level='debug', message="ChargingRule['charging_rule_list'] is: " + str(ChargingRule['charging_rule_list']), redisClient=self.redisMessaging) + #Empty dict for the Charging Rules to go into + ChargingRule['charging_rules'] = [] + #Add each of the Charging Rules for the APN + for individual_charging_rule in ChargingRule['charging_rule_list']: + self.logTool.log(service='Database', level='debug', message="Getting Charging rule " + str(individual_charging_rule), redisClient=self.redisMessaging) + individual_charging_rule_complete = self.Get_Charging_Rule(individual_charging_rule) + self.logTool.log(service='Database', level='debug', message="Got individual_charging_rule_complete: " + str(individual_charging_rule_complete), redisClient=self.redisMessaging) + ChargingRule['charging_rules'].append(individual_charging_rule_complete) + self.logTool.log(service='Database', level='debug', message="Completed Get_Charging_Rules()", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=ChargingRule, redisClient=self.redisMessaging) + return ChargingRule + def Get_UE_by_IP(self, subscriber_routing): + self.logTool.log(service='Database', level='debug', message="Called Get_UE_by_IP() for IP " + str(subscriber_routing), redisClient=self.redisMessaging) - except Exception as E: - safe_close(session) - raise ValueError(E) - dbLogger.debug("Final Served_Subs: " + str(Served_Subs)) - safe_close(session) - return Served_Subs + Session = sessionmaker(bind = self.engine) + session = Session() + + try: + result = session.query(SERVING_APN).filter_by(subscriber_routing=subscriber_routing).one() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + result = result.__dict__ + result.pop('_sa_instance_state') + result = self.Sanitize_Datetime(result) + return result + #Get Subscriber ID from IMSI + 
subscriber_details = Get_Subscriber(imsi=str(imsi)) + + def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=True): + #IMSI 14-15 Digits + #IMEI 15 Digits + #IMEI-SV 2 Digits + self.logTool.log(service='Database', level='debug', message="Called Store_IMSI_IMEI_Binding() with IMSI: " + str(imsi) + " IMEI: " + str(imei) + " match_response_code: " + str(match_response_code), redisClient=self.redisMessaging) + if self.config['eir']['imsi_imei_logging'] != True: + self.logTool.log(service='Database', level='debug', message="Skipping storing binding", redisClient=self.redisMessaging) + return + #Concat IMEI + IMSI + imsi_imei = str(imsi) + "," + str(imei) + Session = sessionmaker(bind = self.engine) + session = Session() + #Check if exist already & update + try: + session.query(IMSI_IMEI_HISTORY).filter_by(imsi_imei=imsi_imei).one() + self.logTool.log(service='Database', level='debug', message="Entry already present for IMSI/IMEI Combo", redisClient=self.redisMessaging) + self.safe_close(session) + return + except Exception as E: + newObj = IMSI_IMEI_HISTORY(imsi_imei=imsi_imei, match_response_code=match_response_code, imsi_imei_timestamp = datetime.datetime.now(tz=timezone.utc)) + session.add(newObj) + try: + session.commit() + except Exception as E: + self.logTool.log(service='Database', level='error', message="Failed to commit session, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + self.safe_close(session) + self.logTool.log(service='Database', level='debug', message="Added new IMSI_IMEI_HISTORY binding", redisClient=self.redisMessaging) -def Get_Served_IMS_Subscribers(get_local_users_only=False): - dbLogger.debug("Getting all subscribers served by this IMS-HSS") - Session = sessionmaker(bind=engine) - session = Session() + if 'sim_swap_notify_webhook' in self.config['eir']: + self.logTool.log(service='Database', level='debug', message="Sending SIM Swap notification to Webhook", redisClient=self.redisMessaging) + try: + dictToSend = {'imei':imei, 'imsi': imsi, 'match_response_code': match_response_code} + self.handleWebhook(dictToSend) + except Exception as E: + self.logTool.log(service='Database', level='debug', message="Failed to post to Webhook", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=str(E), redisClient=self.redisMessaging) - Served_Subs = {} - try: - - results = session.query(IMS_SUBSCRIBER).filter( - IMS_SUBSCRIBER.scscf.isnot(None)) - for result in results: - result = result.__dict__ - dbLogger.debug("Result: " + str(result) + - " type: " + str(type(result))) - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') - if get_local_users_only == True: - dbLogger.debug("Filtering Get_Served_IMS_Subscribers to locally served IMS Subs only") + #Lookup Device Info + if 'tac_database_csv' in self.config['eir']: try: - serving_ims_hss = result['scscf_peer'].split(';')[1] - dbLogger.debug("Serving IMS-HSS: " + str(serving_ims_hss) + " and this is: " + str(yaml_config['hss']['OriginHost'])) - if serving_ims_hss == yaml_config['hss']['OriginHost']: - dbLogger.debug("Serving IMS-HSS matches local HSS for " + str(result['imsi'])) - Served_Subs[result['imsi']] = {} - Served_Subs[result['imsi']] = result - dbLogger.debug("Processed result") - continue - else: - dbLogger.debug("Sub is served by remote IMS-HSS: " + str(serving_ims_hss)) + device_info = self.get_device_info_from_TAC(imei=str(imei)) + 
self.logTool.log(service='Database', level='debug', message="Got Device Info: " + str(device_info), redisClient=self.redisMessaging) + #@@Fixme + # prom_eir_devices.labels( + # imei_prefix=device_info['tac_prefix'], + # device_type=device_info['name'], + # device_name=device_info['model'] + # ).inc() except Exception as E: - dbLogger.debug("Error in filtering to local peer only: " + str(E)) - continue + self.logTool.log(service='Database', level='debug', message="Failed to get device info from TAC", redisClient=self.redisMessaging) + # prom_eir_devices.labels( + # imei_prefix=str(imei)[0:8], + # device_type='Unknown', + # device_name='Unknown' + # ).inc() else: - Served_Subs[result['imsi']] = result - dbLogger.debug("Processed result") - - except Exception as E: - safe_close(session) - raise ValueError(E) - dbLogger.debug("Final Served_Subs: " + str(Served_Subs)) - safe_close(session) - return Served_Subs - - -def Get_Served_PCRF_Subscribers(get_local_users_only=False): - dbLogger.debug("Getting all subscribers served by this PCRF") - Session = sessionmaker(bind=engine) - session = Session() - Served_Subs = {} - try: - results = session.query(SERVING_APN).all() - for result in results: - result = result.__dict__ - dbLogger.debug("Result: " + str(result) + " type: " + str(type(result))) - result = Sanitize_Datetime(result) - result.pop('_sa_instance_state') + self.logTool.log(service='Database', level='debug', message="No TAC database configured, skipping device info lookup", redisClient=self.redisMessaging) - if get_local_users_only == True: - dbLogger.debug("Filtering to locally served IMS Subs only") + #Sync state change with geored + if propagate == True: try: - serving_pcrf = result['serving_pgw_peer'].split(';')[1] - dbLogger.debug("Serving PCRF: " + str(serving_pcrf) + " and this is: " + str(yaml_config['hss']['OriginHost'])) - if serving_pcrf == yaml_config['hss']['OriginHost']: - dbLogger.debug("Serving PCRF matches local PCRF") - dbLogger.debug("Processed result") - + if 'EIR' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Propagate EIR changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored( + {"imsi": str(imsi), + "imei": str(imei), + "match_response_code": str(match_response_code)} + ) else: - dbLogger.debug("Sub is served by remote PCRF: " + str(serving_pcrf)) - continue + self.logTool.log(service='Database', level='debug', message="Config does not allow sync of EIR events", redisClient=self.redisMessaging) except Exception as E: - dbLogger.debug("Error in filtering Get_Served_PCRF_Subscribers to local peer only: " + str(E)) - continue - - # Get APN Info - apn_info = GetObj(APN, result['apn']) - #dbLogger.debug("Got APN Info: " + str(apn_info)) - result['apn_info'] = apn_info - - # Get Subscriber Info - subscriber_info = GetObj(SUBSCRIBER, result['subscriber_id']) - result['subscriber_info'] = subscriber_info - - #dbLogger.debug("Got Subscriber Info: " + str(subscriber_info)) - - Served_Subs[subscriber_info['imsi']] = result - dbLogger.debug("Processed result") - except Exception as E: - raise ValueError(E) - #dbLogger.debug("Final SERVING_APN: " + str(Served_Subs)) - safe_close(session) - return Served_Subs - -def Get_Vectors_AuC(auc_id, action, **kwargs): - dbLogger.debug("Getting Vectors for auc_id " + str(auc_id) + " with action " + str(action)) - key_data = GetObj(AUC, auc_id) - vector_dict = {} - - if action == "air": - rand, xres, 
autn, kasme = S6a_crypt.generate_eutran_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn']) - vector_dict['rand'] = rand - vector_dict['xres'] = xres - vector_dict['autn'] = autn - vector_dict['kasme'] = kasme - - #Incriment SQN - Update_AuC(auc_id, sqn=key_data['sqn']+100) - - return vector_dict - - elif action == "sqn_resync": - dbLogger.debug("Resync SQN") - rand = kwargs['rand'] - sqn, mac_s = S6a_crypt.generate_resync_s6a(key_data['ki'], key_data['opc'], key_data['amf'], kwargs['auts'], rand) - dbLogger.debug("SQN from resync: " + str(sqn) + " SQN in DB is " + str(key_data['sqn']) + "(Difference of " + str(int(sqn) - int(key_data['sqn'])) + ")") - Update_AuC(auc_id, sqn=sqn+100) - return - - elif action == "sip_auth": - rand, autn, xres, ck, ik = S6a_crypt.generate_maa_vector(key_data['ki'], key_data['opc'], key_data['amf'], key_data['sqn'], kwargs['plmn']) - dbLogger.debug("RAND is: " + str(rand)) - dbLogger.debug("AUTN is: " + str(autn)) - vector_dict['SIP_Authenticate'] = rand + autn - vector_dict['xres'] = xres - vector_dict['ck'] = ck - vector_dict['ik'] = ik - Update_AuC(auc_id, sqn=key_data['sqn']+100) - return vector_dict - - elif action == "Digest-MD5": - dbLogger.debug("Generating Digest-MD5 Auth vectors") - dbLogger.debug("key_data: " + str(key_data)) - nonce = uuid.uuid4().hex - #nonce = "beef4d878f2642ed98afe491b943ca60" - vector_dict['nonce'] = nonce - vector_dict['SIP_Authenticate'] = key_data['ki'] - return vector_dict - -def Get_APN(apn_id): - dbLogger.debug("Getting APN " + str(apn_id)) - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = session.query(APN).filter_by(apn_id=apn_id).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - result = result.__dict__ - result.pop('_sa_instance_state') - safe_close(session) - return result - -def Get_APN_by_Name(apn): - dbLogger.debug("Getting APN named " + str(apn_id)) - Session = sessionmaker(bind = engine) - session = Session() - try: - result = session.query(APN).filter_by(apn=str(apn)).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - result = result.__dict__ - result.pop('_sa_instance_state') - safe_close(session) - return result - -def Update_AuC(auc_id, sqn=1): - dbLogger.debug("Updating AuC record for sub " + str(auc_id)) - dbLogger.debug(UpdateObj(AUC, {'sqn': sqn}, auc_id, True)) - return - -def Update_Serving_MME(imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, propagate=True): - dbLogger.debug("Updating Serving MME for sub " + str(imsi) + " to MME " + str(serving_mme)) - Session = sessionmaker(bind = engine) - session = Session() - try: - result = session.query(SUBSCRIBER).filter_by(imsi=imsi).one() - if yaml_config['hss']['CancelLocationRequest_Enabled'] == True: - dbLogger.debug("Evaluating if we should trigger sending a CLR.") - serving_hss = str(result.serving_mme_peer).split(';',1)[1] - serving_mme_peer = str(result.serving_mme_peer).split(';',1)[0] - dbLogger.debug("Subscriber is currently served by serving_mme: " + str(result.serving_mme) + " at realm " + str(result.serving_mme_realm) + " through Diameter peer " + str(result.serving_mme_peer)) - dbLogger.debug("Subscriber is now served by serving_mme: " + str(serving_mme) + " at realm " + str(serving_mme_realm) + " through Diameter peer " + str(serving_mme_peer)) - #Evaluate if we need to send a CLR to the old MME - if result.serving_mme != None: - if str(result.serving_mme) == str(serving_mme): - dbLogger.debug("This MME is 
unchanged (" + str(serving_mme) + ") - so no need to send a CLR") - elif (str(result.serving_mme) != str(serving_mme)): - dbLogger.debug("There is a difference in serving MME, old MME is '" + str(result.serving_mme) + "' new MME is '" + str(serving_mme) + "' - We need to trigger sending a CLR") - if serving_hss != yaml_config['hss']['OriginHost']: - dbLogger.debug("This subscriber is not served by this HSS it is served by HSS at " + serving_hss + " - We need to trigger sending a CLR on " + str(serving_hss)) - URL = 'http://' + serving_hss + '.' + yaml_config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) - else: - dbLogger.debug("This subscriber is served by this HSS we need to send a CLR to old MME from this HSS") - - URL = 'http://' + serving_hss + '.' + yaml_config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) - dbLogger.debug("Sending CLR to API at " + str(URL)) - json_data = { - "DestinationRealm": result.serving_mme_realm, - "DestinationHost": result.serving_mme, - "cancellationType": 2, - "diameterPeer": serving_mme_peer, - } - - dbLogger.debug("Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data)) - transaction_id = str(uuid.uuid4()) - GeoRed_Push_thread = threading.Thread(target=GeoRed_Push_Request, args=(serving_hss, json_data, transaction_id, URL)) - GeoRed_Push_thread.start() - else: - #No currently serving MME - No action to take - dbLogger.debug("No currently serving MME - No need to send CLR") - - if type(serving_mme) == str: - dbLogger.debug("Updating serving MME & Timestamp") - result.serving_mme = serving_mme - result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc) - result.serving_mme_realm = serving_mme_realm - result.serving_mme_peer = serving_mme_peer - else: - #Clear values - dbLogger.debug("Clearing serving MME") - result.serving_mme = None - result.serving_mme_timestamp = None - result.serving_mme_realm = None - result.serving_mme_peer = None - - session.commit() - objectData = GetObj(SUBSCRIBER, result.subscriber_id) - handleWebhook(objectData, 'UPDATE') - - #Sync state change with geored - if propagate == True: - if 'HSS' in yaml_config['geored'].get('sync_actions', []) and yaml_config['geored'].get('enabled', False) == True: - dbLogger.debug("Propagate MME changes to Geographic PyHSS instances") - handleGeored({ - "imsi": str(imsi), - "serving_mme": result.serving_mme, - "serving_mme_realm": str(result.serving_mme_realm), - "serving_mme_peer": str(result.serving_mme_peer) - }) - else: - dbLogger.debug("Config does not allow sync of HSS events") - except Exception as E: - dbLogger.error("Error occurred, rolling back session: " + str(E)) - raise - finally: - safe_close(session) - + self.logTool.log(service='Database', level='debug', message="Nothing synced to Geographic PyHSS instances for EIR event", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=E, redisClient=self.redisMessaging) -def Update_Serving_CSCF(imsi, serving_cscf, scscf_realm=None, scscf_peer=None, propagate=True): - dbLogger.debug("Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer)) - Session = sessionmaker(bind = engine) - session = Session() + return - try: - result = session.query(IMS_SUBSCRIBER).filter_by(imsi=imsi).one() + def Get_IMEI_IMSI_History(self, attribute): + self.logTool.log(service='Database', level='debug', message="Called Get_IMEI_IMSI_History() for entry matching " + 
str(attribute), redisClient=self.redisMessaging)
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+        result_array = []
+        try:
+            results = session.query(IMSI_IMEI_HISTORY).filter(IMSI_IMEI_HISTORY.imsi_imei.ilike("%" + str(attribute) + "%")).all()
+            for result in results:
+                result = result.__dict__
+                result.pop('_sa_instance_state')
+                result = 
self.Sanitize_Datetime(result) + try: + result['imsi'] = result['imsi_imei'].split(",")[0] + except: + continue + try: + result['imei'] = result['imsi_imei'].split(",")[1] + except: + continue + result_array.append(result) + self.safe_close(session) + return result_array + except Exception as E: + self.safe_close(session) + raise ValueError(E) - try: - #Check if already a serving APN on record - dbLogger.debug("Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id)) - ServingAPN = Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) - dbLogger.debug("Existing Serving APN ID on record, updating") + def Check_EIR(self, imsi, imei): + eir_response_code_table = {0 : 'Whitelist', 1: 'Blacklist', 2: 'Greylist'} + self.logTool.log(service='Database', level='debug', message="Called Check_EIR() for imsi " + str(imsi) + " and imei: " + str(imei), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + #Check for Exact Matches + self.logTool.log(service='Database', level='debug', message="Looking for exact matches", redisClient=self.redisMessaging) + #Check for exact Matches try: - assert(type(serving_pgw) == str) - assert(len(serving_pgw) > 0) - assert("None" not in serving_pgw) - - UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True) - objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - handleWebhook(objectData, 'UPDATE') - except: - dbLogger.debug("Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id'])) - objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - handleWebhook(objectData, 'DELETE') - DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True) - except Exception as E: - dbLogger.info("Failed to update existing APN " + str(E)) - #Create if does not exist - CreateObj(SERVING_APN, json_data, True) - objectData = GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - handleWebhook(objectData, 'CREATE') - - #Sync state change with geored - if propagate == True: + results = session.query(EIR).filter_by(imei=str(imei), regex_mode=0) + for result in results: + result = result.__dict__ + match_response_code = result['match_response_code'] + if result['imsi'] == '': + self.logTool.log(service='Database', level='debug', message="No IMSI specified in DB, so matching only on IMEI", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) + return match_response_code + elif result['imsi'] == str(imsi): + self.logTool.log(service='Database', level='debug', message="Matched on IMEI and IMSI", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) + return match_response_code + except Exception as E: + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + + self.logTool.log(service='Database', level='debug', message="Did not match any Exact Matches - Checking Regex", redisClient=self.redisMessaging) try: - if 'PCRF' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - dbLogger.debug("Propagate PCRF changes to Geographic PyHSS instances") - handleGeored({"imsi": str(imsi), - 'serving_apn' : str(apn), - 'pcrf_session_id': str(pcrf_session_id), - 'serving_pgw': str(serving_pgw), - 'serving_pgw_realm': str(serving_pgw_realm), - 'serving_pgw_peer': str(serving_pgw_peer), - 'subscriber_routing': str(subscriber_routing) - }) - 
else: - dbLogger.debug("Config does not allow sync of PCRF events") + results = session.query(EIR).filter_by(regex_mode=1) #Get all Regex records from DB + for result in results: + result = result.__dict__ + match_response_code = result['match_response_code'] + if re.match(result['imei'], imei): + self.logTool.log(service='Database', level='debug', message="IMEI matched " + str(result['imei']), redisClient=self.redisMessaging) + #Check if IMSI also specified + if len(result['imsi']) != 0: + self.logTool.log(service='Database', level='debug', message="With IMEI matched, now checking if IMSI matches regex", redisClient=self.redisMessaging) + if re.match(result['imsi'], imsi): + self.logTool.log(service='Database', level='debug', message="IMSI also matched, so match OK!", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) + return match_response_code + else: + self.logTool.log(service='Database', level='debug', message="No IMSI specified, so match OK!", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) + return match_response_code except Exception as E: - dbLogger.debug("Nothing synced to Geographic PyHSS instances for event PCRF") - - - return - -def Get_Serving_APN(subscriber_id, apn_id): - dbLogger.debug("Getting Serving APN " + str(apn_id) + " with subscriber_id " + str(subscriber_id)) - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = session.query(SERVING_APN).filter_by(subscriber_id=subscriber_id, apn=apn_id).first() - except Exception as E: - dbLogger.debug(E) - safe_close(session) - raise ValueError(E) - result = result.__dict__ - result.pop('_sa_instance_state') - - safe_close(session) - return result - -def Get_Charging_Rule(charging_rule_id): - dbLogger.debug("Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id)) - Session = sessionmaker(bind = engine) - session = Session() - #Get base Rule - ChargingRule = GetObj(CHARGING_RULE, charging_rule_id) - ChargingRule['tft'] = [] - #Get TFTs - try: - results = session.query(TFT).filter_by(tft_group_id=ChargingRule['tft_group_id']) - for result in results: - result = result.__dict__ - result.pop('_sa_instance_state') - ChargingRule['tft'].append(result) - except Exception as E: - safe_close(session) - raise ValueError(E) - safe_close(session) - return ChargingRule - -def Get_Charging_Rules(imsi, apn): - dbLogger.debug("Called Get_Charging_Rules() for IMSI " + str(imsi) + " and APN " + str(apn)) - #Get Subscriber ID from IMSI - subscriber_details = Get_Subscriber(imsi=str(imsi)) - - #Split the APN list into a list - apn_list = subscriber_details['apn_list'].split(',') - dbLogger.debug("Current APN List: " + str(apn_list)) - #Remove the default APN from the list - try: - apn_list.remove(str(subscriber_details['default_apn'])) - except: - dbLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List") - pass - #Add default APN in first position - apn_list.insert(0, str(subscriber_details['default_apn'])) - - #Get APN ID from APN - for apn_id in apn_list: - dbLogger.debug("Getting APN ID " + str(apn_id) + " to see if it matches APN " + str(apn)) - #Get each APN in List - apn_data = Get_APN(apn_id) - dbLogger.debug(apn_data) - if str(apn_data['apn']).lower() == str(apn).lower(): - dbLogger.debug("Matched named APN " + str(apn_data['apn']) + " with APN ID " + str(apn_id)) - - dbLogger.debug("Getting 
charging rule list from " + str(apn_data['charging_rule_list'])) - ChargingRule = {} - ChargingRule['charging_rule_list'] = str(apn_data['charging_rule_list']).split(',') - ChargingRule['apn_data'] = apn_data - - #Get Charging Rules list - if apn_data['charging_rule_list'] == None: - dbLogger.debug("No Charging Rule associated with this APN") - ChargingRule['charging_rules'] = None - return ChargingRule - - dbLogger.debug("ChargingRule['charging_rule_list'] is: " + str(ChargingRule['charging_rule_list'])) - #Empty dict for the Charging Rules to go into - ChargingRule['charging_rules'] = [] - #Add each of the Charging Rules for the APN - for individual_charging_rule in ChargingRule['charging_rule_list']: - dbLogger.debug("Getting Charging rule " + str(individual_charging_rule)) - individual_charging_rule_complete = Get_Charging_Rule(individual_charging_rule) - dbLogger.debug("Got individual_charging_rule_complete: " + str(individual_charging_rule_complete)) - ChargingRule['charging_rules'].append(individual_charging_rule_complete) - dbLogger.debug("Completed Get_Charging_Rules()") - dbLogger.debug(ChargingRule) - return ChargingRule - -def Get_UE_by_IP(subscriber_routing): - dbLogger.debug("Called Get_UE_by_IP() for IP " + str(subscriber_routing)) - - Session = sessionmaker(bind = engine) - session = Session() - - try: - result = session.query(SERVING_APN).filter_by(subscriber_routing=subscriber_routing).one() - except Exception as E: - safe_close(session) - raise ValueError(E) - result = result.__dict__ - result.pop('_sa_instance_state') - result = Sanitize_Datetime(result) - return result - #Get Subscriber ID from IMSI - subscriber_details = Get_Subscriber(imsi=str(imsi)) - -def Store_IMSI_IMEI_Binding(imsi, imei, match_response_code, propagate=True): - #IMSI 14-15 Digits - #IMEI 15 Digits - #IMEI-SV 2 Digits - dbLogger.debug("Called Store_IMSI_IMEI_Binding() with IMSI: " + str(imsi) + " IMEI: " + str(imei) + " match_response_code: " + str(match_response_code)) - if yaml_config['eir']['imsi_imei_logging'] != True: - dbLogger.debug("Skipping storing binding") - return - #Concat IMEI + IMSI - imsi_imei = str(imsi) + "," + str(imei) - Session = sessionmaker(bind = engine) - session = Session() + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) - #Check if exist already & update - try: - session.query(IMSI_IMEI_HISTORY).filter_by(imsi_imei=imsi_imei).one() - dbLogger.debug("Entry already present for IMSI/IMEI Combo") - safe_close(session) - return - except Exception as E: - newObj = IMSI_IMEI_HISTORY(imsi_imei=imsi_imei, match_response_code=match_response_code, imsi_imei_timestamp = datetime.datetime.now(tz=timezone.utc)) - session.add(newObj) try: session.commit() except Exception as E: - dbLogger.error("Failed to commit session, error: " + str(E)) - safe_rollback(session) - safe_close(session) + self.logTool.log(service='Database', level='error', message="Failed to commit session, error: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + self.safe_close(session) + raise ValueError(E) + self.logTool.log(service='Database', level='debug', message="No matches at all - Returning default response", redisClient=self.redisMessaging) + self.Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=self.config['eir']['no_match_response']) + self.safe_close(session) + return self.config['eir']['no_match_response'] + + def Get_EIR_Rules(self): + self.logTool.log(service='Database', level='debug', message="Getting all EIR Rules", 
redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + EIR_Rules = [] + try: + results = session.query(EIR) + for result in results: + result = result.__dict__ + result.pop('_sa_instance_state') + EIR_Rules.append(result) + except Exception as E: + self.safe_rollback(session) + self.safe_close(session) raise ValueError(E) - safe_close(session) - dbLogger.debug("Added new IMSI_IMEI_HISTORY binding") + self.logTool.log(service='Database', level='debug', message="Final EIR_Rules: " + str(EIR_Rules), redisClient=self.redisMessaging) + self.safe_close(session) + return EIR_Rules - if 'sim_swap_notify_webhook' in yaml_config['eir']: - dbLogger.debug("Sending SIM Swap notification to Webhook") - try: - dictToSend = {'imei':imei, 'imsi': imsi, 'match_response_code': match_response_code} - handleWebhook(dictToSend) - except Exception as E: - dbLogger.debug("Failed to post to Webhook") - dbLogger.debug(str(E)) - #Lookup Device Info - if 'tac_database_csv' in yaml_config['eir']: - try: - device_info = get_device_info_from_TAC(imei=str(imei)) - dbLogger.debug("Got Device Info: " + str(device_info)) - #@@Fixme - # prom_eir_devices.labels( - # imei_prefix=device_info['tac_prefix'], - # device_type=device_info['name'], - # device_name=device_info['model'] - # ).inc() - except Exception as E: - dbLogger.debug("Failed to get device info from TAC") - # prom_eir_devices.labels( - # imei_prefix=str(imei)[0:8], - # device_type='Unknown', - # device_name='Unknown' - # ).inc() - else: - dbLogger.debug("No TAC database configured, skipping device info lookup") + def dict_bytes_to_dict_string(self, dict_bytes): + dict_string = {} + for key, value in dict_bytes.items(): + dict_string[key.decode()] = value.decode() + return dict_string - #Sync state change with geored - if propagate == True: - try: - if 'EIR' in yaml_config['geored']['sync_actions'] and yaml_config['geored']['enabled'] == True: - dbLogger.debug("Propagate EIR changes to Geographic PyHSS instances") - handleGeored( - {"imsi": str(imsi), - "imei": str(imei), - "match_response_code": str(match_response_code)} - ) - else: - dbLogger.debug("Config does not allow sync of EIR events") - except Exception as E: - dbLogger.debug("Nothing synced to Geographic PyHSS instances for EIR event") - dbLogger.debug(E) - return - -def Get_IMEI_IMSI_History(attribute): - dbLogger.debug("Called Get_IMEI_IMSI_History() for entry matching " + str(Get_IMEI_IMSI_History)) - Session = sessionmaker(bind = engine) - session = Session() - result_array = [] - try: - results = session.query(IMSI_IMEI_HISTORY).filter(IMSI_IMEI_HISTORY.imsi_imei.ilike("%" + str(attribute) + "%")).all() - for result in results: - result = result.__dict__ - result.pop('_sa_instance_state') - result = Sanitize_Datetime(result) - try: - result['imsi'] = result['imsi_imei'].split(",")[0] - except: - continue - try: - result['imei'] = result['imsi_imei'].split(",")[1] - except: - continue - result_array.append(result) - safe_close(session) - return result_array - except Exception as E: - safe_close(session) - raise ValueError(E) - -def Check_EIR(imsi, imei): - eir_response_code_table = {0 : 'Whitelist', 1: 'Blacklist', 2: 'Greylist'} - dbLogger.debug("Called Check_EIR() for imsi " + str(imsi) + " and imei: " + str(imei)) - Session = sessionmaker(bind = engine) - session = Session() - #Check for Exact Matches - dbLogger.debug("Looking for exact matches") - #Check for exact Matches - try: - results = session.query(EIR).filter_by(imei=str(imei), regex_mode=0) - 
for result in results: - result = result.__dict__ - match_response_code = result['match_response_code'] - if result['imsi'] == '': - dbLogger.debug("No IMSI specified in DB, so matching only on IMEI") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - elif result['imsi'] == str(imsi): - dbLogger.debug("Matched on IMEI and IMSI") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - except Exception as E: - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - dbLogger.debug("Did not match any Exact Matches - Checking Regex") - try: - results = session.query(EIR).filter_by(regex_mode=1) #Get all Regex records from DB - for result in results: - result = result.__dict__ - match_response_code = result['match_response_code'] - if re.match(result['imei'], imei): - dbLogger.debug("IMEI matched " + str(result['imei'])) - #Check if IMSI also specified - if len(result['imsi']) != 0: - dbLogger.debug("With IMEI matched, now checking if IMSI matches regex") - if re.match(result['imsi'], imsi): - dbLogger.debug("IMSI also matched, so match OK!") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - else: - dbLogger.debug("No IMSI specified, so match OK!") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=match_response_code) - return match_response_code - except Exception as E: - safe_rollback(session) - safe_close(session) - raise ValueError(E) - - try: - session.commit() - except Exception as E: - dbLogger.error("Failed to commit session, error: " + str(E)) - safe_rollback(session) - safe_close(session) - raise ValueError(E) - dbLogger.debug("No matches at all - Returning default response") - Store_IMSI_IMEI_Binding(imsi=imsi, imei=imei, match_response_code=yaml_config['eir']['no_match_response']) - safe_close(session) - return yaml_config['eir']['no_match_response'] - -def Get_EIR_Rules(): - dbLogger.debug("Getting all EIR Rules") - Session = sessionmaker(bind = engine) - session = Session() - EIR_Rules = [] - try: - results = session.query(EIR) - for result in results: - result = result.__dict__ - result.pop('_sa_instance_state') - EIR_Rules.append(result) - except Exception as E: - safe_rollback(session) - safe_close(session) - raise ValueError(E) - dbLogger.debug("Final EIR_Rules: " + str(EIR_Rules)) - safe_close(session) - return EIR_Rules - - -def dict_bytes_to_dict_string(dict_bytes): - dict_string = {} - for key, value in dict_bytes.items(): - dict_string[key.decode()] = value.decode() - return dict_string - - -def get_device_info_from_TAC(imei): - dbLogger.debug("Getting Device Info from IMEI: " + str(imei)) - #Try 8 digit TAC - try: - dbLogger.debug("Trying to match on 8 Digit IMEI") - #@@Fixme - # imei_result = logtool.RedisHMGET(str(imei[0:8])) - # print("Got back: " + str(imei_result)) - # imei_result = dict_bytes_to_dict_string(imei_result) - # assert(len(imei_result) != 0) - # dbLogger.debug("Found match for IMEI " + str(imei) + " with result " + str(imei_result)) - # return imei_result - return "0" - except: - dbLogger.debug("Failed to match on 8 digit IMEI") - - try: - dbLogger.debug("Trying to match on 6 Digit IMEI") - #@@Fixme - # imei_result = logtool.RedisHMGET(str(imei[0:6])) - # print("Got back: " + str(imei_result)) - # imei_result = dict_bytes_to_dict_string(imei_result) - # assert(len(imei_result) != 0) - # dbLogger.debug("Found match for IMEI 
" + str(imei) + " with result " + str(imei_result)) - # return imei_result - return "0" - except: - dbLogger.debug("Failed to match on 6 digit IMEI") + def get_device_info_from_TAC(self, imei): + self.logTool.log(service='Database', level='debug', message="Getting Device Info from IMEI: " + str(imei), redisClient=self.redisMessaging) + #Try 8 digit TAC + try: + self.logTool.log(service='Database', level='debug', message="Trying to match on 8 Digit IMEI", redisClient=self.redisMessaging) + #@@Fixme + # imei_result = logtool.RedisHMGET(str(imei[0:8])) + # print("Got back: " + str(imei_result)) + # imei_result = dict_bytes_to_dict_string(imei_result) + # assert(len(imei_result) != 0) + # self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) + # return imei_result + return "0" + except: + self.logTool.log(service='Database', level='debug', message="Failed to match on 8 digit IMEI", redisClient=self.redisMessaging) + + try: + self.logTool.log(service='Database', level='debug', message="Trying to match on 6 Digit IMEI", redisClient=self.redisMessaging) + #@@Fixme + # imei_result = logtool.RedisHMGET(str(imei[0:6])) + # print("Got back: " + str(imei_result)) + # imei_result = dict_bytes_to_dict_string(imei_result) + # assert(len(imei_result) != 0) + # self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) + # return imei_result + return "0" + except: + self.logTool.log(service='Database', level='debug', message="Failed to match on 6 digit IMEI", redisClient=self.redisMessaging) - raise ValueError("No matching TAC in IMEI Database") + raise ValueError("No matching TAC in IMEI Database") if __name__ == "__main__": import binascii,os,pprint DeleteAfter = True + database = Database() #Define Charging Rule charging_rule = { @@ -2044,12 +2042,12 @@ def get_device_info_from_TAC(imei): 'rating_group' : 20000 } print("Creating Charging Rule A") - ChargingRule_newObj_A = CreateObj(CHARGING_RULE, charging_rule) + ChargingRule_newObj_A = database.CreateObj(CHARGING_RULE, charging_rule) print("ChargingRule_newObj A: " + str(ChargingRule_newObj_A)) charging_rule['gbr_ul'], charging_rule['gbr_dl'], charging_rule['mbr_ul'], charging_rule['mbr_dl'] = 256000, 256000, 256000, 256000 print("Creating Charging Rule B") charging_rule['rule_name'], charging_rule['precedence'], charging_rule['tft_group_id'] = 'charging_rule_B', 80, 2 - ChargingRule_newObj_B = CreateObj(CHARGING_RULE, charging_rule) + ChargingRule_newObj_B = database.CreateObj(CHARGING_RULE, charging_rule) print("ChargingRule_newObj B: " + str(ChargingRule_newObj_B)) #Define TFTs @@ -2064,8 +2062,8 @@ def get_device_info_from_TAC(imei): 'direction' : 2 } print("Creating TFT") - CreateObj(TFT, tft_template1) - CreateObj(TFT, tft_template2) + database.CreateObj(TFT, tft_template1) + database.CreateObj(TFT, tft_template2) tft_template3 = { 'tft_group_id' : 2, @@ -2078,8 +2076,8 @@ def get_device_info_from_TAC(imei): 'direction' : 2 } print("Creating TFT") - CreateObj(TFT, tft_template3) - CreateObj(TFT, tft_template4) + database.CreateObj(TFT, tft_template3) + database.CreateObj(TFT, tft_template4) apn2 = { @@ -2092,17 +2090,17 @@ def get_device_info_from_TAC(imei): 'charging_rule_list' : str(ChargingRule_newObj_A['charging_rule_id']) + "," + str(ChargingRule_newObj_B['charging_rule_id']) } print("Creating APN " + str(apn2['apn'])) - 
newObj = CreateObj(APN, apn2) + newObj = database.CreateObj(APN, apn2) print(newObj) print("Getting APN " + str(apn2['apn'])) - print(GetObj(APN, newObj['apn_id'])) + print(database.GetObj(APN, newObj['apn_id'])) apn_id = newObj['apn_id'] UpdatedObj = newObj UpdatedObj['apn'] = 'UpdatedInUnitTest' print("Updating APN " + str(apn2['apn'])) - newObj = UpdateObj(APN, UpdatedObj, newObj['apn_id']) + newObj = database.UpdateObj(APN, UpdatedObj, newObj['apn_id']) print(newObj) #Create AuC @@ -2114,28 +2112,28 @@ def get_device_info_from_TAC(imei): } print(auc_json) print("Creating AuC entry") - newObj = CreateObj(AUC, auc_json) + newObj = database.CreateObj(AUC, auc_json) print(newObj) #Get AuC print("Getting AuC entry") - newObj = GetObj(AUC, newObj['auc_id']) + newObj = database.GetObj(AUC, newObj['auc_id']) auc_id = newObj['auc_id'] print(newObj) #Update AuC print("Updating AuC entry") newObj['sqn'] = newObj['sqn'] + 10 - newObj = UpdateObj(AUC, newObj, auc_id) + newObj = database.UpdateObj(AUC, newObj, auc_id) #Generate Vectors print("Generating Vectors") - Get_Vectors_AuC(auc_id, "air", plmn='12ff') - print(Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) + database.Get_Vectors_AuC(auc_id, "air", plmn='12ff') + print(database.Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) #Update AuC - Update_AuC(auc_id, sqn=100) + database.Update_AuC(auc_id, sqn=100) #New Subscriber subscriber_json = { @@ -2153,37 +2151,37 @@ def get_device_info_from_TAC(imei): #Delete IMSI if already exists try: - existing_sub_data = Get_Subscriber(imsi=subscriber_json['imsi']) - DeleteObj(SUBSCRIBER, existing_sub_data['subscriber_id']) + existing_sub_data = database.Get_Subscriber(imsi=subscriber_json['imsi']) + database.DeleteObj(SUBSCRIBER, existing_sub_data['subscriber_id']) except: print("Did not find old sub to delete") print("Creating new Subscriber") print(subscriber_json) - newObj = CreateObj(SUBSCRIBER, subscriber_json) + newObj = database.CreateObj(SUBSCRIBER, subscriber_json) print(newObj) subscriber_id = newObj['subscriber_id'] #Get SUBSCRIBER print("Getting Subscriber") - newObj = GetObj(SUBSCRIBER, subscriber_id) + newObj = database.GetObj(SUBSCRIBER, subscriber_id) print(newObj) #Update SUBSCRIBER print("Updating Subscriber") newObj['ue_ambr_ul'] = 999995 - newObj = UpdateObj(SUBSCRIBER, newObj, subscriber_id) + newObj = database.UpdateObj(SUBSCRIBER, newObj, subscriber_id) #Set MME Location for Subscriber print("Updating Serving MME for Subscriber") - Update_Serving_MME(imsi=newObj['imsi'], serving_mme="Test123", serving_mme_peer="Test123", serving_mme_realm="TestRealm") + database.Update_Serving_MME(imsi=newObj['imsi'], serving_mme="Test123", serving_mme_peer="Test123", serving_mme_realm="TestRealm") #Update Serving APN for Subscriber print("Updating Serving APN for Subscriber") - Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], pcrf_session_id='kjsdlkjfd', serving_pgw='pgw.test.com', subscriber_routing='1.2.3.4') + database.Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], pcrf_session_id='kjsdlkjfd', serving_pgw='pgw.test.com', subscriber_routing='1.2.3.4') print("Getting Charging Rule for Subscriber / APN Combo") - ChargingRule = Get_Charging_Rules(imsi=newObj['imsi'], apn=apn2['apn']) + ChargingRule = database.Get_Charging_Rules(imsi=newObj['imsi'], apn=apn2['apn']) pprint.pprint(ChargingRule) #New IMS Subscriber @@ -2195,43 +2193,43 @@ def get_device_info_from_TAC(imei): "sh_profile" : "default_sh_user_data.xml" } print(ims_subscriber_json) - newObj = CreateObj(IMS_SUBSCRIBER, 
ims_subscriber_json) + newObj = database.CreateObj(IMS_SUBSCRIBER, ims_subscriber_json) print(newObj) ims_subscriber_id = newObj['ims_subscriber_id'] #Test Get Subscriber print("Test Getting Subscriber") - GetSubscriber_Result = Get_Subscriber(imsi=subscriber_json['imsi']) + GetSubscriber_Result = database.Get_Subscriber(imsi=subscriber_json['imsi']) print(GetSubscriber_Result) #Test IMS Get Subscriber print("Getting IMS Subscribers") - print(Get_IMS_Subscriber(imsi='001001000000006')) - print(Get_IMS_Subscriber(msisdn='12345678')) + print(database.Get_IMS_Subscriber(imsi='001001000000006')) + print(database.Get_IMS_Subscriber(msisdn='12345678')) #Set SCSCF for Subscriber - Update_Serving_CSCF(newObj['imsi'], "NickTestCSCF") + database.Update_Serving_CSCF(newObj['imsi'], "NickTestCSCF") #Get Served Subscriber List - print(Get_Served_IMS_Subscribers()) + print(database.Get_Served_IMS_Subscribers()) #Clear Serving PGW for PCRF Subscriber print("Clear Serving PGW for PCRF Subscriber") - Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], pcrf_session_id='sessionid123', serving_pgw=None, subscriber_routing=None) + database.Update_Serving_APN(imsi=newObj['imsi'], apn=apn2['apn'], pcrf_session_id='sessionid123', serving_pgw=None, subscriber_routing=None) #Clear MME Location for Subscriber print("Clear MME Location for Subscriber") - Update_Serving_MME(newObj['imsi'], None) + database.Update_Serving_MME(newObj['imsi'], None) #Generate Vectors for IMS Subscriber print("Generating Vectors for IMS Subscriber") - print(Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) + print(database.Get_Vectors_AuC(auc_id, "sip_auth", plmn='12ff')) #print("Generating Resync for IMS Subscriber") #print(Get_Vectors_AuC(auc_id, "sqn_resync", auts='7964347dfdfe432289522183fcfb', rand='1bc9f096002d3716c65e4e1f4c1c0d17')) #Test getting APNs - GetAPN_Result = Get_APN(GetSubscriber_Result['default_apn']) + GetAPN_Result = database.Get_APN(GetSubscriber_Result['default_apn']) print(GetAPN_Result) #handleGeored({"imsi": "001001000000006", "serving_mme": "abc123"}) @@ -2239,51 +2237,51 @@ def get_device_info_from_TAC(imei): if DeleteAfter == True: #Delete IMS Subscriber - print(DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id)) + print(database.DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id)) #Delete Subscriber - print(DeleteObj(SUBSCRIBER, subscriber_id)) + print(database.DeleteObj(SUBSCRIBER, subscriber_id)) #Delete AuC - print(DeleteObj(AUC, auc_id)) + print(database.DeleteObj(AUC, auc_id)) #Delete APN - print(DeleteObj(APN, apn_id)) + print(database.DeleteObj(APN, apn_id)) #Whitelist IMEI / IMSI Binding eir_template = {'imei': '1234', 'imsi': '567', 'regex_mode': 0, 'match_response_code': 0} - CreateObj(EIR, eir_template) + database.CreateObj(EIR, eir_template) #Blacklist Example eir_template = {'imei': '99881232', 'imsi': '', 'regex_mode': 0, 'match_response_code': 1} - CreateObj(EIR, eir_template) + database.CreateObj(EIR, eir_template) #IMEI Prefix Regex Example (Blacklist all IMEIs starting with 666) eir_template = {'imei': '^666.*', 'imsi': '', 'regex_mode': 1, 'match_response_code': 1} - CreateObj(EIR, eir_template) + database.CreateObj(EIR, eir_template) #IMEI Prefix Regex Example (Greylist response for IMEI starting with 777 and IMSI is 1234123412341234) eir_template = {'imei': '^777.*', 'imsi': '^1234123412341234$', 'regex_mode': 1, 'match_response_code': 2} - CreateObj(EIR, eir_template) + database.CreateObj(EIR, eir_template) print("\n\n\n\n") #Check Whitelist (No Match) - assert Check_EIR(imei='1234', imsi='') == 
2 + assert database.Check_EIR(imei='1234', imsi='') == 2 print("\n\n\n\n") #Check Whitelist (Matched) - assert Check_EIR(imei='1234', imsi='567') == 0 + assert database.Check_EIR(imei='1234', imsi='567') == 0 print("\n\n\n\n") #Check Blacklist (Match) - assert Check_EIR(imei='99881232', imsi='567') == 1 + assert database.Check_EIR(imei='99881232', imsi='567') == 1 print("\n\n\n\n") #IMEI Prefix Regex Example (Greylist response for IMEI starting with 777 and IMSI is 1234123412341234) - assert Check_EIR(imei='7771234', imsi='1234123412341234') == 2 + assert database.Check_EIR(imei='7771234', imsi='1234123412341234') == 2 - print(Get_IMEI_IMSI_History('1234123412')) + print(database.Get_IMEI_IMSI_History('1234123412')) print("\n\n\n") - print(Generate_JSON_Model_for_Flask(SUBSCRIBER)) + print(database.Generate_JSON_Model_for_Flask(SUBSCRIBER)) diff --git a/lib/diameter.py b/lib/diameter.py index 0be72ae..65c1936 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1,8 +1,5 @@ #Diameter Packet Decoder / Encoder & Tools -from multiprocessing.sharedctypes import Value import socket -import logging -import sys import binascii import math import uuid @@ -10,14 +7,12 @@ import random import ipaddress import jinja2 -import traceback -import database +from database import Database import yaml -from typing import Literal class Diameter: - def __init__(self, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999"): + def __init__(self, redisMessaging, logTool, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999"): with open("../config.yaml", 'r') as stream: self.yaml_config = (yaml.safe_load(stream)) @@ -26,10 +21,15 @@ def __init__(self, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999. 
self.ProductName = self.string_to_hex(productName) self.MNC = str(mnc) self.MCC = str(mcc) - self.diameterLibLogger = logging.getLogger('DiameterLibLogger') + self.logTool = logTool + self.redisMessaging=redisMessaging + self.database = Database(logTool=logTool, redisMessaging=redisMessaging) - self.diameterLibLogger.info("Initialized Diameter for " + str(self.OriginHost) + " at Realm " + str(self.OriginRealm) + " serving as Product Name " + str(self.ProductName)) - self.diameterLibLogger.info("PLMN is " + str(self.MCC) + "/" + str(self.MNC)) + self.logTool.log(service='HSS', level='info', message=f"Initialized Diameter Library", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"Origin Host: {str(originHost)}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"Realm: {str(originRealm)}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"Product Name: {str(productName)}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=f"PLMN: {str(self.MCC)}/{str(self.MNC)}", redisClient=self.redisMessaging) self.diameterCommandList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, @@ -72,7 +72,7 @@ def ip_to_hex(self, ip): else: ip_hex = "0002" #IPv6 ip_hex += format(ipaddress.IPv6Address(ip), 'X') - #self.diameterLibLogger.debug("Converted IP to hex - Input: " + str(ip) + " output: " + str(ip_hex)) + #self.logTool.log(service='HSS', level='debug', message="Converted IP to hex - Input: " + str(ip) + " output: " + str(ip_hex), redisClient=self.redisMessaging) return ip_hex def hex_to_int(self, hex): @@ -122,12 +122,12 @@ def Reverse(self, str): return (slicedString) def DecodePLMN(self, plmn): - self.diameterLibLogger.debug("Decoded PLMN: " + str(plmn)) + self.logTool.log(service='HSS', level='debug', message="Decoded PLMN: " + str(plmn), redisClient=self.redisMessaging) mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') - self.diameterLibLogger.debug("Decoded MCC: " + mcc) + self.logTool.log(service='HSS', level='debug', message="Decoded MCC: " + mcc, redisClient=self.redisMessaging) mnc = self.Reverse(plmn[4:6]) - self.diameterLibLogger.debug("Decoded MNC: " + mnc) + self.logTool.log(service='HSS', level='debug', message="Decoded MNC: " + mnc, redisClient=self.redisMessaging) return mcc, mnc def EncodePLMN(self, mcc, mnc): @@ -142,50 +142,50 @@ def EncodePLMN(self, mcc, mnc): plmn = '' for bits in plmn_list: plmn = plmn + bits - self.diameterLibLogger.debug("Encoded PLMN: " + str(plmn)) + self.logTool.log(service='HSS', level='debug', message="Encoded PLMN: " + str(plmn), redisClient=self.redisMessaging) return plmn def TBCD_special_chars(self, input): - self.diameterLibLogger.debug("Special character possible in " + str(input)) + self.logTool.log(service='HSS', level='debug', message="Special character possible in " + str(input), redisClient=self.redisMessaging) if input == "*": - self.diameterLibLogger.debug("Found * - Returning 1010") + self.logTool.log(service='HSS', level='debug', message="Found * - Returning 1010", redisClient=self.redisMessaging) return "1010" elif input == "#": - self.diameterLibLogger.debug("Found # - Returning 1011") + self.logTool.log(service='HSS', level='debug', message="Found # - 
Returning 1011", redisClient=self.redisMessaging) return "1011" elif input == "a": - self.diameterLibLogger.debug("Found a - Returning 1100") + self.logTool.log(service='HSS', level='debug', message="Found a - Returning 1100", redisClient=self.redisMessaging) return "1100" elif input == "b": - self.diameterLibLogger.debug("Found b - Returning 1101") + self.logTool.log(service='HSS', level='debug', message="Found b - Returning 1101", redisClient=self.redisMessaging) return "1101" elif input == "c": - self.diameterLibLogger.debug("Found c - Returning 1100") + self.logTool.log(service='HSS', level='debug', message="Found c - Returning 1100", redisClient=self.redisMessaging) return "1100" else: binform = "{:04b}".format(int(input)) - self.diameterLibLogger.debug("input " + str(input) + " is not a special char, converted to bin: " + str(binform)) + self.logTool.log(service='HSS', level='debug', message="input " + str(input) + " is not a special char, converted to bin: " + str(binform), redisClient=self.redisMessaging) return (binform) def TBCD_encode(self, input): - self.diameterLibLogger.debug("TBCD_encode input value is " + str(input)) + self.logTool.log(service='HSS', level='debug', message="TBCD_encode input value is " + str(input), redisClient=self.redisMessaging) offset = 0 output = '' matches = ['*', '#', 'a', 'b', 'c'] while offset < len(input): if len(input[offset:offset+2]) == 2: - self.diameterLibLogger.debug("processing bits " + str(input[offset:offset+2]) + " at position offset " + str(offset)) + self.logTool.log(service='HSS', level='debug', message="processing bits " + str(input[offset:offset+2]) + " at position offset " + str(offset), redisClient=self.redisMessaging) bit = input[offset:offset+2] #Get two digits at a time bit = bit[::-1] #Reverse them #Check if *, #, a, b or c if any(x in bit for x in matches): - self.diameterLibLogger.debug("Special char in bit " + str(bit)) + self.logTool.log(service='HSS', level='debug', message="Special char in bit " + str(bit), redisClient=self.redisMessaging) new_bit = '' new_bit = new_bit + str(self.TBCD_special_chars(bit[0])) new_bit = new_bit + str(self.TBCD_special_chars(bit[1])) - self.diameterLibLogger.debug("Final bin output of new_bit is " + str(new_bit)) + self.logTool.log(service='HSS', level='debug', message="Final bin output of new_bit is " + str(new_bit), redisClient=self.redisMessaging) bit = hex(int(new_bit, 2))[2:] #Get Hex value - self.diameterLibLogger.debug("Formatted as Hex this is " + str(bit)) + self.logTool.log(service='HSS', level='debug', message="Formatted as Hex this is " + str(bit), redisClient=self.redisMessaging) output = output + bit offset = offset + 2 else: @@ -193,23 +193,23 @@ def TBCD_encode(self, input): last_digit = str(input[offset:offset+2]) #Check if *, #, a, b or c if any(x in last_digit for x in matches): - self.diameterLibLogger.debug("Special char in bit " + str(bit)) + self.logTool.log(service='HSS', level='debug', message="Special char in bit " + str(bit), redisClient=self.redisMessaging) new_bit = '' new_bit = new_bit + '1111' #Add the F first #Encode the symbol into binary and append it to the new_bit var new_bit = new_bit + str(self.TBCD_special_chars(last_digit)) - self.diameterLibLogger.debug("Final bin output of new_bit is " + str(new_bit)) + self.logTool.log(service='HSS', level='debug', message="Final bin output of new_bit is " + str(new_bit), redisClient=self.redisMessaging) bit = hex(int(new_bit, 2))[2:] #Get Hex value - self.diameterLibLogger.debug("Formatted as Hex this is " + 
str(bit)) + self.logTool.log(service='HSS', level='debug', message="Formatted as Hex this is " + str(bit), redisClient=self.redisMessaging) else: bit = "f" + last_digit offset = offset + 2 output = output + bit - self.diameterLibLogger.debug("TBCD_encode final output value is " + str(output)) + self.logTool.log(service='HSS', level='debug', message="TBCD_encode final output value is " + str(output), redisClient=self.redisMessaging) return output def TBCD_decode(self, input): - self.diameterLibLogger.debug("TBCD_decode Input value is " + str(input)) + self.logTool.log(service='HSS', level='debug', message="TBCD_decode Input value is " + str(input), redisClient=self.redisMessaging) offset = 0 output = '' while offset < len(input): @@ -221,7 +221,7 @@ def TBCD_decode(self, input): else: #If f in bit strip it bit = input[offset:offset+2] output = output + bit[1] - self.diameterLibLogger.debug("TBCD_decode output value is " + str(output)) + self.logTool.log(service='HSS', level='debug', message="TBCD_decode output value is " + str(output), redisClient=self.redisMessaging) return output #Generates an AVP with inputs provided (AVP Code, AVP Flags, AVP Content, Padding) @@ -261,8 +261,8 @@ def generate_vendor_avp(self, avp_code, avp_flags, avp_vendorid, avp_content): avp_padding = '' else: #Not multiple of 4 - Padding needed rounded_value = self.myround(avp_length) - self.diameterLibLogger.debug("Rounded value is " + str(rounded_value)) - self.diameterLibLogger.debug("Has " + str( int( rounded_value - avp_length)) + " bytes of padding") + self.logTool.log(service='HSS', level='debug', message="Rounded value is " + str(rounded_value), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Has " + str( int( rounded_value - avp_length)) + " bytes of padding", redisClient=self.redisMessaging) avp_padding = format(0,"x").zfill(int( rounded_value - avp_length) * 2) @@ -363,13 +363,11 @@ def decode_avp_packet(self, data): except Exception as e: if str(e) == "invalid literal for int() with base 16: ''": - logging.debug("AVP length 0 error") pass elif str(e) == "Length of data is too short to be valid AVP": - logging.debug("AVP length 0 error v2") pass else: - self.diameterLibLogger.debug("failed to decode sub-avp - error: " + str(e)) + self.logTool.log(service='HSS', level='debug', message="failed to decode sub-avp - error: " + str(e), redisClient=self.redisMessaging) pass @@ -404,38 +402,49 @@ def getDiameterMessageType(self, binaryData: str) -> dict: assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) response['inbound'] = diameterApplication["requestAcronym"] response['outbound'] = diameterApplication["responseAcronym"] - self.diameterLibLogger.debug(f"[diameter.py] Successfully generated response: {response}") + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Successfully generated response: {response}", redisClient=self.redisMessaging) except Exception as e: continue return response def generateDiameterResponse(self, binaryData: str) -> str: - packet_vars, avps = self.decode_diameter_packet(binaryData) - origin_host = self.get_avp_data(avps, 264)[0] - origin_host = binascii.unhexlify(origin_host).decode("utf-8") - response = '' - - self.diameterLibLogger.debug(f"Generating a diameter response") - - # Drop packet if it's a response packet: - if packet_vars["flags_bin"][0:1] == "0": - self.diameterLibLogger.debug("Got a Response, not a request - dropping it.") - self.diameterLibLogger.debug(packet_vars) - return - - for 
diameterApplication in self.diameterCommandList: - try: - assert(packet_vars["command_code"] == diameterApplication["commandCode"]) - assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) - if 'flags' in diameterApplication: - assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) - response = diameterApplication["responseMethod"](packet_vars, avps) - self.diameterLibLogger.debug(f"[diameter.py] Successfully generated response: {response}") - except Exception as e: - continue - - return response + try: + packet_vars, avps = self.decode_diameter_packet(binaryData) + origin_host = self.get_avp_data(avps, 264)[0] + origin_host = binascii.unhexlify(origin_host).decode("utf-8") + response = '' + + self.logTool.log(service='HSS', level='debug', message=f"Generating a diameter response", redisClient=self.redisMessaging) + + # Drop packet if it's a response packet: + if packet_vars["flags_bin"][0:1] == "0": + self.logTool.log(service='HSS', level='debug', message="Got a Response, not a request - dropping it.", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=packet_vars, redisClient=self.redisMessaging) + return + + for diameterApplication in self.diameterCommandList: + try: + assert(packet_vars["command_code"] == diameterApplication["commandCode"]) + assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) + if 'flags' in diameterApplication: + assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) + response = diameterApplication["responseMethod"](packet_vars, avps) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Successfully generated response: {response}", redisClient=self.redisMessaging) + except Exception as e: + continue + + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_response_count_successful', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Successful Diameter Responses', + metricExpiry=60) + return response + except Exception as e: + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_response_count_fail', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Failed Diameter Responses', + metricExpiry=60) + return '' def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body for avp_dicts in avps: @@ -446,23 +455,23 @@ def AVP_278_Origin_State_Incriment(self, avps): return origin_state_incriment_hex def Charging_Rule_Generator(self, ChargingRules, ue_ip): - self.diameterLibLogger.debug("Called Charging_Rule_Generator") + self.logTool.log(service='HSS', level='debug', message="Called Charging_Rule_Generator", redisClient=self.redisMessaging) #Install Charging Rules - self.diameterLibLogger.info("Naming Charging Rule") + self.logTool.log(service='HSS', level='info', message="Naming Charging Rule", redisClient=self.redisMessaging) Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii')) - self.diameterLibLogger.info("Named Charging Rule") + self.logTool.log(service='HSS', level='info', message="Named Charging Rule", redisClient=self.redisMessaging) #Populate all Flow Information AVPs Flow_Information = '' for tft in ChargingRules['tft']: - self.diameterLibLogger.info(tft) + self.logTool.log(service='HSS', level='info', message=tft, redisClient=self.redisMessaging) #If {{ UE_IP }} in TFT splice in the real UE 
IP Value
                try:
                    tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip))
                    tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip))
-                    self.diameterLibLogger.info("Spliced in UE IP into TFT: " + str(tft['tft_string']))
+                    self.logTool.log(service='HSS', level='info', message="Spliced in UE IP into TFT: " + str(tft['tft_string']), redisClient=self.redisMessaging)
                except Exception as E:
-                    self.diameterLibLogger.error("Failed to splice in UE IP into flow description")
+                    self.logTool.log(service='HSS', level='error', message="Failed to splice in UE IP into flow description", redisClient=self.redisMessaging)
 
                #Valid Values for Flow_Direction: 0- Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional
                Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4))
@@ -470,91 +479,96 @@ def Charging_Rule_Generator(self, ChargingRules, ue_ip):
                Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description)
 
        Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4))
-        self.diameterLibLogger.info("Defined Flow_Status: " + str(Flow_Status))
+        self.logTool.log(service='HSS', level='info', message="Defined Flow_Status: " + str(Flow_Status), redisClient=self.redisMessaging)
 
-        self.diameterLibLogger.info("Defining QoS information")
+        self.logTool.log(service='HSS', level='info', message="Defining QoS information", redisClient=self.redisMessaging)
        #QCI
        QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4))
 
        #ARP
-        self.diameterLibLogger.info("Defining ARP information")
+        self.logTool.log(service='HSS', level='info', message="Defining ARP information", redisClient=self.redisMessaging)
        AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4))
        AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_capability']), 4))
        AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_vulnerability']), 4))
        ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
 
-        self.diameterLibLogger.info("Defining MBR information")
+        self.logTool.log(service='HSS', level='info', message="Defining MBR information", redisClient=self.redisMessaging)
        #Max Requested Bandwidth
        Bandwidth_info = ''
        Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4))
        Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4))
 
-        self.diameterLibLogger.info("Defining GBR information")
+        self.logTool.log(service='HSS', level='info', message="Defining GBR information", redisClient=self.redisMessaging)
        #GBR
        if int(ChargingRules['gbr_ul']) != 0:
            Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4))
        if int(ChargingRules['gbr_dl']) != 0:
            Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4))
-        self.diameterLibLogger.info("Defined Bandwith Info: " + str(Bandwidth_info))
+        self.logTool.log(service='HSS', level='info', message="Defined Bandwidth Info: " + str(Bandwidth_info), redisClient=self.redisMessaging)
 
        #Populate QoS Information
        QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info)
-
self.diameterLibLogger.info("Defined QoS_Information: " + str(QoS_Information)) + self.logTool.log(service='HSS', level='info', message="Defined QoS_Information: " + str(QoS_Information), redisClient=self.redisMessaging) #Precedence - self.diameterLibLogger.info("Defining Precedence information") + self.logTool.log(service='HSS', level='info', message="Defining Precedence information", redisClient=self.redisMessaging) Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4)) - self.diameterLibLogger.info("Defined Precedence " + str(Precedence)) + self.logTool.log(service='HSS', level='info', message="Defined Precedence " + str(Precedence), redisClient=self.redisMessaging) #Rating Group - self.diameterLibLogger.info("Defining Rating Group information") + self.logTool.log(service='HSS', level='info', message="Defining Rating Group information", redisClient=self.redisMessaging) if ChargingRules['rating_group'] != None: RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID else: RatingGroup = '' - self.diameterLibLogger.info("Defined Rating Group " + str(ChargingRules['rating_group'])) + self.logTool.log(service='HSS', level='info', message="Defined Rating Group " + str(ChargingRules['rating_group']), redisClient=self.redisMessaging) #Complete Charging Rule Defintion - self.diameterLibLogger.info("Collating ChargingRuleDef") + self.logTool.log(service='HSS', level='info', message="Collating ChargingRuleDef", redisClient=self.redisMessaging) ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef) #Charging Rule Install - self.diameterLibLogger.info("Collating ChargingRuleDef") + self.logTool.log(service='HSS', level='info', message="Collating ChargingRuleDef", redisClient=self.redisMessaging) return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef) def Get_IMS_Subscriber_Details_from_AVP(self, username): #Feed the Username AVP with Tel URI, SIP URI and either MSISDN or IMSI and this returns user data username = binascii.unhexlify(username).decode('utf-8') - self.diameterLibLogger.info("Username AVP is present, value is " + str(username)) + self.logTool.log(service='HSS', level='info', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) username = username.split('@')[0] #Strip Domain to get User part username = username[4:] #Strip tel: or sip: prefix #Determine if dealing with IMSI or MSISDN if (len(username) == 15) or (len(username) == 16): - self.diameterLibLogger.debug("We have an IMSI: " + str(username)) - ims_subscriber_details = database.Get_IMS_Subscriber(imsi=username) + self.logTool.log(service='HSS', level='debug', message="We have an IMSI: " + str(username), redisClient=self.redisMessaging) + ims_subscriber_details = self.database.Get_IMS_Subscriber(imsi=username) else: - self.diameterLibLogger.debug("We have an msisdn: " + str(username)) - ims_subscriber_details = database.Get_IMS_Subscriber(msisdn=username) - self.diameterLibLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="We have an msisdn: " + str(username), redisClient=self.redisMessaging) + ims_subscriber_details = self.database.Get_IMS_Subscriber(msisdn=username) + self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + 
str(ims_subscriber_details), redisClient=self.redisMessaging) return ims_subscriber_details def Generate_Prom_Stats(self): - self.diameterLibLogger.debug("Called Generate_Prom_Stats") - #@@ Fixme - # try: - # prom_ims_subs_value = len(database.Get_Served_IMS_Subscribers(get_local_users_only=True)) - # prom_ims_subs.set(prom_ims_subs_value) - # prom_mme_subs_value = len(database.Get_Served_Subscribers(get_local_users_only=True)) - # prom_mme_subs.set(prom_mme_subs_value) - # prom_pcrf_subs_value = len(database.Get_Served_PCRF_Subscribers(get_local_users_only=True)) - # prom_pcrf_subs.set(prom_pcrf_subs_value) - # except Exception as e: - # self.diameterLibLogger.debug("Failed to generate Prometheus Stats for IMS Subscribers") - # self.diameterLibLogger.debug(e) - # self.diameterLibLogger.debug("Generated Prometheus Stats for IMS Subscribers") + self.logTool.log(service='HSS', level='debug', message="Called Generate_Prom_Stats", redisClient=self.redisMessaging) + try: + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_ims_subs', + metricType='gauge', metricAction='set', + metricValue=len(self.database.Get_Served_IMS_Subscribers(get_local_users_only=True)), metricHelp='Number of attached IMS Subscribers', + metricExpiry=60) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_mme_subs', + metricType='gauge', metricAction='set', + metricValue=len(self.database.Get_Served_Subscribers(get_local_users_only=True)), metricHelp='Number of attached MME Subscribers', + metricExpiry=60) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_pcrf_subs', + metricType='gauge', metricAction='set', + metricValue=len(self.database.Get_Served_PCRF_Subscribers(get_local_users_only=True)), metricHelp='Number of attached PCRF Subscribers', + metricExpiry=60) + except Exception as e: + self.logTool.log(service='HSS', level='debug', message="Failed to generate Prometheus Stats for IMS Subscribers", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=e, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Generated Prometheus Stats for IMS Subscribers", redisClient=self.redisMessaging) return @@ -597,7 +611,7 @@ def Answer_257(self, packet_vars, avps): avp += self.generate_avp(265, 40, format(int(13019),"x").zfill(8)) #Supported-Vendor-ID 13019 (ETSI) response = self.generate_diameter_packet("01", "00", 257, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - self.diameterLibLogger.debug("Successfully Generated CEA") + self.logTool.log(service='HSS', level='debug', message="Successfully Generated CEA", redisClient=self.redisMessaging) return response #Device Watchdog Answer @@ -611,7 +625,7 @@ def Answer_280(self, packet_vars, avps): if avps_to_check['avp_code'] == 278: avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incrimented (Handled by AVP_278_Origin_State_Incriment)) response = self.generate_diameter_packet("01", "00", 280, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - self.diameterLibLogger.debug("Successfully Generated DWA") + self.logTool.log(service='HSS', level='debug', message="Successfully Generated DWA", redisClient=self.redisMessaging) orignHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP orignHost = binascii.unhexlify(orignHost).decode('utf-8') #Format it return response 
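 
Note on the pattern used throughout this patch: every per-module logger (dbLogger, diameterLibLogger) is replaced by a shared logTool.log(service=..., level=..., message=..., redisClient=...) call, and the direct Prometheus counters/gauges are replaced by redisMessaging.sendMetric(...), so all log and metric output is queued through Redis for the new service processes to deliver. The body of lib/messaging.py is not shown in this excerpt, so the following is only a minimal sketch of the interface these calls assume; the class names, Redis key layout and queue semantics here are illustrative assumptions, not the committed implementation:

    import json
    import time
    import redis

    class RedisMessaging:
        # Assumed wrapper around redis-py, matching the sendMetric(...)
        # keyword arguments used throughout this patch.
        def __init__(self, host: str='localhost', port: int=6379):
            self.redisClient = redis.Redis(host=host, port=port)

        def sendMetric(self, serviceName: str, metricName: str, metricType: str,
                       metricAction: str, metricValue: float, metricHelp: str='',
                       metricLabels: dict=None, metricExpiry: int=None) -> None:
            # Store the metric as a JSON blob for a separate exporter service to
            # collect; the expiry stops stale metrics from accumulating in Redis.
            metric = {'NAME': metricName, 'TYPE': metricType, 'ACTION': metricAction,
                      'VALUE': metricValue, 'HELP': metricHelp,
                      'LABELS': metricLabels if metricLabels is not None else {}}
            self.redisClient.set(f"metric-{serviceName}-{time.time_ns()}",
                                 json.dumps(metric), ex=metricExpiry)

    class LogTool:
        # Assumed helper matching logTool.log(service=..., level=..., message=...,
        # redisClient=...); it queues log entries rather than writing files directly.
        def log(self, service: str, level: str, message, redisClient: RedisMessaging) -> None:
            entry = {'timestamp': time.time(), 'level': level, 'message': str(message)}
            redisClient.redisClient.rpush(f"log-{service}", json.dumps(entry))
 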
@@ -623,7 +637,7 @@ def Answer_282(self, packet_vars, avps):
         avp += self.generate_avp(296, 40, self.OriginRealm)                 #Origin Realm
         avp += self.generate_avp(268, 40, "000007d1")                       #Result Code (DIAMETER_SUCCESS (2001))
         response = self.generate_diameter_packet("01", "00", 282, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)            #Generate Diameter packet
-        self.diameterLibLogger.debug("Successfully Generated DPA")
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated DPA", redisClient=self.redisMessaging)
         return response
 
     #3GPP S6a/S6d Update Location Answer
@@ -654,21 +668,21 @@ def Answer_16777251_316(self, packet_vars, avps):
         imsi = self.get_avp_data(avps, 1)[0]                                #Get IMSI from User-Name AVP in request
         imsi = binascii.unhexlify(imsi).decode('utf-8')                     #Convert IMSI
         try:
-            subscriber_details = database.Get_Subscriber(imsi=imsi)         #Get subscriber details
-            self.diameterLibLogger.debug("Got back subscriber_details: " + str(subscriber_details))
+            subscriber_details = self.database.Get_Subscriber(imsi=imsi)    #Get subscriber details
+            self.logTool.log(service='HSS', level='debug', message="Got back subscriber_details: " + str(subscriber_details), redisClient=self.redisMessaging)
         except ValueError as e:
-            self.diameterLibLogger.error("failed to get data backfrom database for imsi " + str(imsi))
-            self.diameterLibLogger.error("Error is " + str(e))
-            self.diameterLibLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN")
+            self.logTool.log(service='HSS', level='error', message="Failed to get data back from database for IMSI " + str(imsi), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='error', message="Error is " + str(e), redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='error', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging)
             avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4))
             response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)    #Generate Diameter packet
-            self.diameterLibLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN")
+            self.logTool.log(service='HSS', level='info', message="Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging)
             return response
         except Exception as ex:
             template = "An exception of type {0} occurred. Arguments:\n{1!r}"
             message = template.format(type(ex).__name__, ex.args)
-            self.diameterLibLogger.critical(message)
-            self.diameterLibLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi))
+            self.logTool.log(service='HSS', level='critical', message=message, redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='critical', message="Unhandled general exception when getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging)
             raise
 
         #Store MME Location into Database
@@ -676,7 +690,7 @@ def Answer_16777251_316(self, packet_vars, avps):
         OriginHost = binascii.unhexlify(OriginHost).decode('utf-8')         #Format it
         OriginRealm = self.get_avp_data(avps, 296)[0]                       #Get OriginRealm from AVP
         OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8')       #Format it
-        self.diameterLibLogger.debug("Subscriber is served by MME " + str(OriginHost) + " at realm " + str(OriginRealm))
+        self.logTool.log(service='HSS', level='debug', message="Subscriber is served by MME " + str(OriginHost) + " at realm " + str(OriginRealm), redisClient=self.redisMessaging)
 
         #Find Remote Peer we need to address CLRs through
         try:
             #Check if we have a record-route set as that's where we'll need to send the response
             remote_peer = self.get_avp_data(avps, 282)[-1]                  #Get first record-route header
             remote_peer = binascii.unhexlify(remote_peer).decode('utf-8')   #Format it
         except:
             #If we don't have a record-route set, we'll send the response to the OriginHost
             remote_peer = OriginHost
         remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost'])
-        self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer))
+        self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging)
 
-        database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm)
+        self.database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm)
 
 
         #Boilerplate AVPs
@@ -721,34 +735,34 @@ def Answer_16777251_316(self, packet_vars, avps):
         #Split the APN list into a list
         apn_list = subscriber_details['apn_list'].split(',')
-        self.diameterLibLogger.debug("Current APN List: " + str(apn_list))
+        self.logTool.log(service='HSS', level='debug', message="Current APN List: " + str(apn_list), redisClient=self.redisMessaging)
         #Remove the default APN from the list
         try:
             apn_list.remove(str(subscriber_details['default_apn']))
         except:
-            self.diameterLibLogger.debug("Failed to remove default APN (" + str(subscriber_details['default_apn']) + " from APN List")
+            self.logTool.log(service='HSS', level='debug', message="Failed to remove default APN (" + str(subscriber_details['default_apn']) + ") from APN List", redisClient=self.redisMessaging)
             pass
         #Add default APN in first position
         apn_list.insert(0, str(subscriber_details['default_apn']))
 
-        self.diameterLibLogger.debug("APN list: " + str(apn_list))
+        self.logTool.log(service='HSS', level='debug', message="APN list: " + str(apn_list), redisClient=self.redisMessaging)
         APN_context_identifer_count = 1
         for apn_id in apn_list:
             #Per APN Setup
-            self.diameterLibLogger.debug("Processing APN ID " + str(apn_id))
+            self.logTool.log(service='HSS', level='debug', message="Processing APN ID " + str(apn_id), redisClient=self.redisMessaging)
             try:
-                apn_data = database.Get_APN(apn_id)
+                apn_data = self.database.Get_APN(apn_id)
             except:
-                self.diameterLibLogger.error("Failed to get APN " + str(apn_id))
+                self.logTool.log(service='HSS', level='error', message="Failed to get APN " + str(apn_id), redisClient=self.redisMessaging)
                 continue
             APN_Service_Selection = self.generate_avp(493, "40", 
self.string_to_hex(str(apn_data['apn']))) - self.diameterLibLogger.debug("Setting APN Configuration Profile") + self.logTool.log(service='HSS', level='debug', message="Setting APN Configuration Profile", redisClient=self.redisMessaging) #Sub AVPs of APN Configuration Profile APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4)) APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(int(apn_data['ip_version']), 4)) - self.diameterLibLogger.debug("Setting APN AMBR") + self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) #AMBR AMBR = '' #Initiate empty var AVP for AMBR apn_ambr_ul = int(apn_data['apn_ambr_ul']) @@ -757,7 +771,7 @@ def Answer_16777251_316(self, packet_vars, avps): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - self.diameterLibLogger.debug("Setting APN Allocation-Retention-Priority") + self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) @@ -768,25 +782,25 @@ def Answer_16777251_316(self, packet_vars, avps): #Try static IP allocation try: - subscriber_routing_dict = database.Get_SUBSCRIBER_ROUTING(subscriber_id=subscriber_details['subscriber_id'], apn_id=apn_id) #Get subscriber details - self.diameterLibLogger.info("Got static UE IP " + str(subscriber_routing_dict)) - self.diameterLibLogger.debug("Found static IP for UE " + str(subscriber_routing_dict['ip_address'])) + subscriber_routing_dict = self.database.Get_SUBSCRIBER_ROUTING(subscriber_id=subscriber_details['subscriber_id'], apn_id=apn_id) #Get subscriber details + self.logTool.log(service='HSS', level='info', message="Got static UE IP " + str(subscriber_routing_dict), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Found static IP for UE " + str(subscriber_routing_dict['ip_address']), redisClient=self.redisMessaging) Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(subscriber_routing_dict['ip_address'])) except Exception as E: - self.diameterLibLogger.debug("Error getting static UE IP: " + str(E)) + self.logTool.log(service='HSS', level='debug', message="Error getting static UE IP: " + str(E), redisClient=self.redisMessaging) Served_Party_Address = "" #if 'PDN_GW_Allocation_Type' in apn_profile: - # self.diameterLibLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type'])) + # self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']), redisClient=self.redisMessaging) # PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4)) - # self.diameterLibLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type)) + # self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type), redisClient=self.redisMessaging) # else: # 
PDN_GW_Allocation_Type = '' # if 'VPLMN_Dynamic_Address_Allowed' in apn_profile: - # self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed'])) + # self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']), redisClient=self.redisMessaging) # VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4)) - # self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed)) + # self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed), redisClient=self.redisMessaging) # else: # VPLMN_Dynamic_Address_Allowed = '' PDN_GW_Allocation_Type = '' @@ -794,7 +808,7 @@ def Answer_16777251_316(self, packet_vars, avps): #If static SMF / PGW-C defined if apn_data['pgw_address'] is not None: - self.diameterLibLogger.info("MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address'])) + self.logTool.log(service='HSS', level='info', message="MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address']), redisClient=self.redisMessaging) MIP_Home_Agent_Address = self.generate_avp(334, '40', self.ip_to_hex(apn_data['pgw_address'])) MIP6_Agent_Info = self.generate_avp(486, '40', MIP_Home_Agent_Address) else: @@ -807,40 +821,40 @@ def Answer_16777251_316(self, packet_vars, avps): #Incriment Context Identifier Count to keep track of how many APN Profiles returned APN_context_identifer_count = APN_context_identifer_count + 1 - self.diameterLibLogger.debug("Completed processing APN ID " + str(apn_id)) + self.logTool.log(service='HSS', level='debug', message="Completed processing APN ID " + str(apn_id), redisClient=self.redisMessaging) subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_Configuration_Profile + APN_Configuration) try: - self.diameterLibLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA") + self.logTool.log(service='HSS', level='debug', message="MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA", redisClient=self.redisMessaging) msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, self.TBCD_encode(str(subscriber_details['msisdn']))) #MSISDN - self.diameterLibLogger.debug(msisdn_avp) + self.logTool.log(service='HSS', level='debug', message=msisdn_avp, redisClient=self.redisMessaging) subscription_data += msisdn_avp except Exception as E: - self.diameterLibLogger.error("Failed to populate MSISDN in ULA due to error " + str(E)) + self.logTool.log(service='HSS', level='error', message="Failed to populate MSISDN in ULA due to error " + str(E), redisClient=self.redisMessaging) if 'RAT_freq_priorityID' in subscriber_details: - self.diameterLibLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA") + self.logTool.log(service='HSS', level='debug', message="RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA", redisClient=self.redisMessaging) rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4)) #RAT-Frequency-Selection-Priority ID - self.diameterLibLogger.debug("Adding rat_freq_priorityID: " + str(rat_freq_priorityID)) + self.logTool.log(service='HSS', level='debug', 
message="Adding rat_freq_priorityID: " + str(rat_freq_priorityID), redisClient=self.redisMessaging) subscription_data += rat_freq_priorityID if 'charging_characteristics' in subscriber_details: - self.diameterLibLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['charging_characteristics']) + " - Adding in ULA") + self.logTool.log(service='HSS', level='debug', message="3gpp-charging-characteristics " + str(subscriber_details['charging_characteristics']) + " - Adding in ULA", redisClient=self.redisMessaging) _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, str(subscriber_details['charging_characteristics'])) subscription_data += _3gpp_charging_characteristics - self.diameterLibLogger.debug("Adding _3gpp_charging_characteristics: " + str(_3gpp_charging_characteristics)) + self.logTool.log(service='HSS', level='debug', message="Adding _3gpp_charging_characteristics: " + str(_3gpp_charging_characteristics), redisClient=self.redisMessaging) #ToDo - Fix this # if 'APN_OI_replacement' in subscriber_details: - # self.diameterLibLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA") + # self.logTool.log(service='HSS', level='debug', message="APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA", redisClient=self.redisMessaging) # subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement']))) avp += self.generate_vendor_avp(1400, "c0", 10415, subscription_data) #Subscription-Data response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - self.diameterLibLogger.debug("Successfully Generated ULA") + self.logTool.log(service='HSS', level='debug', message="Successfully Generated ULA", redisClient=self.redisMessaging) return response #3GPP S6a/S6d Authentication Information Answer @@ -850,20 +864,22 @@ def Answer_16777251_318(self, packet_vars, avps): plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from User-Name AVP in request try: - subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details + subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details except ValueError as e: - self.diameterLibLogger.info("Minor getting subscriber details for IMSI " + str(imsi)) - self.diameterLibLogger.info(e) + self.logTool.log(service='HSS', level='info', message="Minor getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message=e, redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777251, + "diameter_cmd_code": 318, + "event": "Unknown User", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - #@@Fixme - # prom_diam_auth_event_count.labels( - # diameter_application_id = 16777251, - # diameter_cmd_code = 318, - # event='Unknown User', - # imsi_prefix = str(imsi[0:6]), - # ).inc() - - self.diameterLibLogger.info("Subscriber " + str(imsi) + " is unknown in database") + self.logTool.log(service='HSS', level='info', message="Subscriber " + str(imsi) + " 
is unknown in database", redisClient=self.redisMessaging)
         avp = ''
         session_id = self.get_avp_data(avps, 263)[0]                                                          #Get Session-ID
         avp += self.generate_avp(263, 40, session_id)                                                         #Session-ID AVP set
@@ -883,8 +899,8 @@ def Answer_16777251_318(self, packet_vars, avps):
         except Exception as ex:
             template = "An exception of type {0} occurred. Arguments:\n{1!r}"
             message = template.format(type(ex).__name__, ex.args)
-            self.diameterLibLogger.critical(message)
-            self.diameterLibLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi))
+            self.logTool.log(service='HSS', level='critical', message=message, redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='critical', message="Unhandled general exception when getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging)
             raise
@@ -892,40 +908,43 @@ def Answer_16777251_318(self, packet_vars, avps):
         requested_vectors = 1
         for avp in avps:
             if avp['avp_code'] == 1408:
-                self.diameterLibLogger.debug("AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP")
+                self.logTool.log(service='HSS', level='debug', message="AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP", redisClient=self.redisMessaging)
                 EUTRAN_Authentication_Info = avp['misc_data']
-                self.diameterLibLogger.debug("EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info))
+                self.logTool.log(service='HSS', level='debug', message="EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info), redisClient=self.redisMessaging)
                 for sub_avp in EUTRAN_Authentication_Info:
                     #If resync request
                     if sub_avp['avp_code'] == 1411:
-                        self.diameterLibLogger.debug("Re-Synchronization required - SQN is out of sync")
-                        #@@Fixme
-                        # prom_diam_auth_event_count.labels(
-                        # diameter_application_id = 16777251,
-                        # diameter_cmd_code = 318,
-                        # event='Resync',
-                        # imsi_prefix = str(imsi[0:6]),
-                        # ).inc()
+                        self.logTool.log(service='HSS', level='debug', message="Re-Synchronization required - SQN is out of sync", redisClient=self.redisMessaging)
+                        self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',
+                                                metricType='counter', metricAction='inc',
+                                                metricValue=1.0,
+                                                metricLabels={
+                                                    "diameter_application_id": 16777251,
+                                                    "diameter_cmd_code": 318,
+                                                    "event": "Resync",
+                                                    "imsi_prefix": str(imsi[0:6])},
+                                                metricHelp='Diameter Authentication related Counters',
+                                                metricExpiry=60)
                         auts = str(sub_avp['misc_data'])[32:]
                         rand = str(sub_avp['misc_data'])[:32]
                         rand = binascii.unhexlify(rand)
                         #Calculate correct SQN
-                        database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand)
+                        self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand)
                     #Get number of requested vectors
                     if sub_avp['avp_code'] == 1410:
-                        self.diameterLibLogger.debug("Raw value of requested vectors is " + str(sub_avp['misc_data']))
+                        self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging)
                         requested_vectors = int(sub_avp['misc_data'], 16)
                         if requested_vectors >= 32:
-                            self.diameterLibLogger.info("Client has requested " + str(requested_vectors) + " vectors, limiting this to 32")
+                            self.logTool.log(service='HSS', level='info', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging)
                             requested_vectors = 32
-        self.diameterLibLogger.debug("Generating " + str(requested_vectors) + " vectors as requested")
+        self.logTool.log(service='HSS', level='debug', message="Generating " + str(requested_vectors) + " vectors as requested", redisClient=self.redisMessaging)
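
The resync branch above slices the hex payload of AVP 1411 (Re-Synchronization-Info), which carries RAND followed by AUTS per 3GPP TS 29.272. A worked example of that slicing, with a made-up sample payload:

    # RAND is 16 bytes (32 hex chars); the remainder is AUTS, left hex-encoded
    # exactly as the handler above leaves it.
    import binascii

    payload = "00112233445566778899aabbccddeeff" + "a0a1a2a3a4a5a6a7a8a9aaabacad"  # sample hex
    rand = binascii.unhexlify(payload[:32])   # 16-byte RAND as raw bytes
    auts = payload[32:]                       # AUTS stays hex-encoded
    assert len(rand) == 16
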
         eutranvector_complete = ''
         while requested_vectors != 0:
-            self.diameterLibLogger.debug("Generating vector number " + str(requested_vectors))
+            self.logTool.log(service='HSS', level='debug', message="Generating vector number " + str(requested_vectors), redisClient=self.redisMessaging)
             plmn = self.get_avp_data(avps, 1407)[0]                                                     #Get PLMN from request
-            vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn)
+            vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn)
             eutranvector = ''                                                                           #This goes into the payload of AVP 10415 (Authentication info)
             eutranvector += self.generate_vendor_avp(1419, "c0", 10415, self.int_to_hex(requested_vectors, 4))
             eutranvector += self.generate_vendor_avp(1447, "c0", 10415, vector_dict['rand'])                                #And is made up of other AVPs joined together with RAND
@@ -948,8 +967,8 @@ def Answer_16777251_318(self, packet_vars, avps):
         #avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af")    #Vendor-Specific-Application-ID (S6a)
         response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)     #Generate Diameter packet
-        self.diameterLibLogger.debug("Successfully Generated AIA")
-        self.diameterLibLogger.debug(response)
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated AIA", redisClient=self.redisMessaging)
+        self.logTool.log(service='HSS', level='debug', message=response, redisClient=self.redisMessaging)
         return response
     #Purge UE Answer (PUA)
@@ -982,8 +1001,8 @@ def Answer_16777251_321(self, packet_vars, avps):
         response = self.generate_diameter_packet("01", "40", 321, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)     #Generate Diameter packet
-        database.Update_Serving_MME(imsi, None)
-        self.diameterLibLogger.debug("Successfully Generated PUA")
+        self.database.Update_Serving_MME(imsi, None)
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated PUA", redisClient=self.redisMessaging)
         return response
     #Notify Answer (NOA)
@@ -1004,7 +1023,7 @@ def Answer_16777251_323(self, packet_vars, avps):
         SupportedFeatures += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8))              #Auth-Application-ID Relay
         avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures)                             #Supported-Features(628) l=36 f=V-- vnd=TGPP
         response = self.generate_diameter_packet("01", "40", 323, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)     #Generate Diameter packet
-        self.diameterLibLogger.debug("Successfully Generated PUA")
+        self.logTool.log(service='HSS', level='debug', message="Successfully Generated NOA", redisClient=self.redisMessaging)
         return response
     #3GPP Gx Credit Control Answer
@@ -1012,9 +1031,9 @@ def Answer_16777238_272(self, packet_vars, avps):
         CC_Request_Type = self.get_avp_data(avps, 416)[0]
         CC_Request_Number = self.get_avp_data(avps, 415)[0]
         #Called Station ID
-        self.diameterLibLogger.debug("Attempting to find APN in CCR")
+        self.logTool.log(service='HSS', level='debug', message="Attempting to find APN in CCR", redisClient=self.redisMessaging)
         apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8')
-        self.diameterLibLogger.debug("CCR for APN " + str(apn))
+        self.logTool.log(service='HSS', level='debug', message="CCR for APN " + str(apn), redisClient=self.redisMessaging)
         OriginHost = 
self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it @@ -1027,7 +1046,7 @@ def Answer_16777238_272(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) + self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) avp = '' #Initiate empty var AVP @@ -1041,37 +1060,37 @@ def Answer_16777238_272(self, packet_vars, avps): #Get Subscriber info from Subscription ID for SubscriptionIdentifier in self.get_avp_data(avps, 443): for UniqueSubscriptionIdentifier in SubscriptionIdentifier: - self.diameterLibLogger.debug("Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI") + self.logTool.log(service='HSS', level='debug', message="Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI", redisClient=self.redisMessaging) if UniqueSubscriptionIdentifier['avp_code'] == 444: imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') - self.diameterLibLogger.debug("Found IMSI " + str(imsi)) + self.logTool.log(service='HSS', level='debug', message="Found IMSI " + str(imsi), redisClient=self.redisMessaging) - self.diameterLibLogger.info("SubscriptionID: " + str(self.get_avp_data(avps, 443))) + self.logTool.log(service='HSS', level='info', message="SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) try: - self.diameterLibLogger.info("Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database") #Get subscriber details - ChargingRules = database.Get_Charging_Rules(imsi=imsi, apn=apn) - self.diameterLibLogger.info("Got Charging Rules: " + str(ChargingRules)) + self.logTool.log(service='HSS', level='info', message="Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details + ChargingRules = self.database.Get_Charging_Rules(imsi=imsi, apn=apn) + self.logTool.log(service='HSS', level='info', message="Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) except Exception as E: #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - self.diameterLibLogger.debug(E) - self.diameterLibLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists") + self.logTool.log(service='HSS', level='debug', message=E, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging) if int(CC_Request_Type) == 1: - self.diameterLibLogger.info("Request type for CCA is 1 - Initial") + self.logTool.log(service='HSS', level='info', message="Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) #Get UE IP try: ue_ip = self.get_avp_data(avps, 8)[0] ue_ip = str(self.hex_to_ip(ue_ip)) except Exception as E: - self.diameterLibLogger.error("Failed to get UE IP") - self.diameterLibLogger.error(E) + self.logTool.log(service='HSS', 
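
The remote-peer derivation above recurs in several handlers: take the peer from the Route-Record if present, otherwise fall back to the Origin-Host, then append the local Origin-Host. Condensed as a hypothetical helper (the name is invented; AVP 282 is Route-Record per RFC 6733, an assumption here since the extraction line sits outside this hunk):

    import binascii

    def resolve_remote_peer(self, avps, origin_host: str) -> str:
        try:
            remote_peer = binascii.unhexlify(self.get_avp_data(avps, 282)[0]).decode('utf-8')
        except Exception:
            remote_peer = origin_host  # no Route-Record: answer goes back to the Origin-Host
        return remote_peer + ";" + str(self.yaml_config['hss']['OriginHost'])
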
level='error', message="Failed to get UE IP", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) ue_ip = 'Failed to Decode / Get UE IP' #Store PGW location into Database remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) - database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) #Supported-Features(628) (Gx feature list) avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") @@ -1079,7 +1098,7 @@ def Answer_16777238_272(self, packet_vars, avps): #Default EPS Beaerer QoS (From database with fallback source CCR-I) try: apn_data = ChargingRules['apn_data'] - self.diameterLibLogger.debug("Setting APN AMBR") + self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) #AMBR AMBR = '' #Initiate empty var AVP for AMBR apn_ambr_ul = int(apn_data['apn_ambr_ul']) @@ -1088,7 +1107,7 @@ def Answer_16777238_272(self, packet_vars, avps): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - self.diameterLibLogger.debug("Setting APN Allocation-Retention-Priority") + self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) @@ -1097,13 +1116,13 @@ def Answer_16777238_272(self, packet_vars, avps): AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) except Exception as E: - self.diameterLibLogger.error(E) - self.diameterLibLogger.error("Failed to populate default_EPS_QoS from DB for sub " + str(imsi)) + self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="Failed to populate default_EPS_QoS from DB for sub " + str(imsi), redisClient=self.redisMessaging) default_EPS_QoS = self.get_avp_data(avps, 1049)[0][8:] avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - self.diameterLibLogger.info("Creating QoS Information") + self.logTool.log(service='HSS', level='info', message="Creating QoS Information", redisClient=self.redisMessaging) #QoS-Information try: apn_data = ChargingRules['apn_data'] @@ -1111,39 +1130,39 @@ def Answer_16777238_272(self, packet_vars, avps): apn_ambr_dl = int(apn_data['apn_ambr_dl']) QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - self.diameterLibLogger.info("Created both QoS AVPs from data from Database") - 
self.diameterLibLogger.info("Populated QoS_Information") + self.logTool.log(service='HSS', level='info', message="Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Populated QoS_Information", redisClient=self.redisMessaging) avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) except Exception as E: - self.diameterLibLogger.error("Failed to get QoS information dynamically for sub " + str(imsi)) - self.diameterLibLogger.error(E) + self.logTool.log(service='HSS', level='error', message="Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) QoS_Information = '' for AMBR_Part in self.get_avp_data(avps, 1016)[0]: - self.diameterLibLogger.debug(AMBR_Part) + self.logTool.log(service='HSS', level='debug', message=AMBR_Part, redisClient=self.redisMessaging) AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:]) QoS_Information += AMBR_AVP - self.diameterLibLogger.debug("QoS_Information added " + str(AMBR_AVP)) + self.logTool.log(service='HSS', level='debug', message="QoS_Information added " + str(AMBR_AVP), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) - self.diameterLibLogger.debug("QoS information set statically") + self.logTool.log(service='HSS', level='debug', message="QoS information set statically", redisClient=self.redisMessaging) - self.diameterLibLogger.info("Added to AVP List") - self.diameterLibLogger.debug("QoS Information: " + str(QoS_Information)) + self.logTool.log(service='HSS', level='info', message="Added to AVP List", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging) #If database returned an existing ChargingRule defintion add ChargingRule to CCA-I if ChargingRules and ChargingRules['charging_rules'] is not None: try: - self.diameterLibLogger.debug(ChargingRules) + self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging) for individual_charging_rule in ChargingRules['charging_rules']: - self.diameterLibLogger.debug("Processing Charging Rule: " + str(individual_charging_rule)) + self.logTool.log(service='HSS', level='debug', message="Processing Charging Rule: " + str(individual_charging_rule), redisClient=self.redisMessaging) avp += self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) except Exception as E: - self.diameterLibLogger.debug("Error in populating dynamic charging rules: " + str(E)) + self.logTool.log(service='HSS', level='debug', message="Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging) elif int(CC_Request_Type) == 3: - self.diameterLibLogger.info("Request type for CCA is 3 - Termination") - database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None) + self.logTool.log(service='HSS', level='info', message="Request type for CCA is 3 - Termination", redisClient=self.redisMessaging) + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None) avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += 
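
The QoS-Information block above reduces to: prefer AMBR values from the database, otherwise echo back what the PGW offered in the CCR. A condensed sketch reusing the hunk's own helpers, not a drop-in method:

    def build_qos_information(self, ChargingRules, avps):
        try:
            # Database-sourced AMBR (1041 = UL, 1040 = DL, both 3GPP vendor AVPs)
            apn_data = ChargingRules['apn_data']
            qos = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(int(apn_data['apn_ambr_ul']), 4))
            qos += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(int(apn_data['apn_ambr_dl']), 4))
        except Exception:
            # Fallback: re-encode the AMBR parts the request itself carried in AVP 1016
            qos = ''
            for ambr_part in self.get_avp_data(avps, 1016)[0]:
                qos += self.generate_vendor_avp(ambr_part['avp_code'], "80", 10415, ambr_part['misc_data'][8:])
        return self.generate_vendor_avp(1016, "80", 10415, qos)
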
self.generate_avp(296, 40, self.OriginRealm) #Origin Realm @@ -1173,27 +1192,30 @@ def Answer_16777216_300(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) + self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) try: - self.diameterLibLogger.info("Checking if username present") + self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) username = self.get_avp_data(avps, 1)[0] username = binascii.unhexlify(username).decode('utf-8') - self.diameterLibLogger.info("Username AVP is present, value is " + str(username)) + self.logTool.log(service='HSS', level='info', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) imsi = username.split('@')[0] #Strip Domain domain = username.split('@')[1] #Get Domain Part - self.diameterLibLogger.debug("Extracted imsi: " + str(imsi) + " now checking backend for this IMSI") - ims_subscriber_details = database.Get_IMS_Subscriber(imsi=imsi) + self.logTool.log(service='HSS', level='debug', message="Extracted imsi: " + str(imsi) + " now checking backend for this IMSI", redisClient=self.redisMessaging) + ims_subscriber_details = self.database.Get_IMS_Subscriber(imsi=imsi) except Exception as E: - self.diameterLibLogger.error("Threw Exception: " + str(E)) - self.diameterLibLogger.error("No known MSISDN or IMSI in Answer_16777216_300() input") - #@@Fixme - # prom_diam_auth_event_count.labels( - # diameter_application_id = 16777216, - # diameter_cmd_code = 300, - # event='Unknown User', - # imsi_prefix = str(imsi[0:6]), - # ).inc() + self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_300() input", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777216, + "diameter_cmd_code": 300, + "event": "Unknown User", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) result_code = 5001 #IMS User Unknown #Experimental Result AVP avp_experimental_result = '' @@ -1208,10 +1230,10 @@ def Answer_16777216_300(self, packet_vars, avps): if user_authorization_type_avp_data: try: User_Authorization_Type = int(user_authorization_type_avp_data[0]) - self.diameterLibLogger.debug("User_Authorization_Type is: " + str(User_Authorization_Type)) + self.logTool.log(service='HSS', level='debug', message="User_Authorization_Type is: " + str(User_Authorization_Type), redisClient=self.redisMessaging) if (User_Authorization_Type == 1): - self.diameterLibLogger.debug("This is Deregister") - database.Update_Serving_CSCF(imsi, serving_cscf=None) + self.logTool.log(service='HSS', level='debug', message="This is Deregister", redisClient=self.redisMessaging) + self.database.Update_Serving_CSCF(imsi, serving_cscf=None) #Populate S-CSCF Address avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii')) avp += 
self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) @@ -1219,28 +1241,28 @@ def Answer_16777216_300(self, packet_vars, avps): return response except Exception as E: - self.diameterLibLogger.debug("Failed to get User_Authorization_Type AVP & Update_Serving_CSCF error: " + str(E)) - self.diameterLibLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Failed to get User_Authorization_Type AVP & Update_Serving_CSCF error: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(ims_subscriber_details), redisClient=self.redisMessaging) if ims_subscriber_details['scscf'] != None: - self.diameterLibLogger.debug("Already has SCSCF Assigned from DB: " + str(ims_subscriber_details['scscf'])) + self.logTool.log(service='HSS', level='debug', message="Already has SCSCF Assigned from DB: " + str(ims_subscriber_details['scscf']), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(ims_subscriber_details['scscf'])),'ascii')) experimental_avp = '' experimental_avp += experimental_avp + self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID experimental_avp = experimental_avp + self.generate_avp(298, 40, format(int(2002),"x").zfill(8)) #DIAMETER_SUBSEQUENT_REGISTRATION (2002) avp += self.generate_avp(297, 40, experimental_avp) #Expermental-Result else: - self.diameterLibLogger.debug("No SCSCF Assigned from DB") + self.logTool.log(service='HSS', level='debug', message="No SCSCF Assigned from DB", redisClient=self.redisMessaging) if 'scscf_pool' in self.yaml_config['hss']: try: scscf = random.choice(self.yaml_config['hss']['scscf_pool']) - self.diameterLibLogger.debug("Randomly picked SCSCF address " + str(scscf) + " from pool") + self.logTool.log(service='HSS', level='debug', message="Randomly picked SCSCF address " + str(scscf) + " from pool", redisClient=self.redisMessaging) avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.diameterLibLogger.info("Using generated S-CSCF Address as failed to source from list due to " + str(E)) + self.logTool.log(service='HSS', level='info', message="Using generated S-CSCF Address as failed to source from list due to " + str(E), redisClient=self.redisMessaging) else: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.diameterLibLogger.info("Using generated S-CSCF Address as none set in scscf_pool in config") + self.logTool.log(service='HSS', level='info', message="Using generated S-CSCF Address as none set in scscf_pool in config", redisClient=self.redisMessaging) experimental_avp = '' experimental_avp += experimental_avp + self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID experimental_avp = experimental_avp + self.generate_avp(298, 40, format(int(2001),"x").zfill(8)) #DIAMETER_FIRST_REGISTRATION (2001) @@ -1273,18 +1295,18 @@ def Answer_16777216_301(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If 
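
The pool-or-generated S-CSCF choice above repeats across the Cx answers; a hypothetical consolidation of that pattern (the function name is invented for illustration):

    import random

    def pick_scscf(self) -> str:
        pool = self.yaml_config['hss'].get('scscf_pool')
        if pool:
            try:
                return random.choice(pool)
            except Exception:
                pass
        # Fall back to a generated home-network S-CSCF URI
        return "sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org"
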
we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.diameterLibLogger.debug("Remote Peer is " + str(remote_peer)) + self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) try: - self.diameterLibLogger.info("Checking if username present") + self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) username = self.get_avp_data(avps, 601)[0] ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) - self.diameterLibLogger.debug("Got subscriber details: " + str(ims_subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(ims_subscriber_details), redisClient=self.redisMessaging) imsi = ims_subscriber_details['imsi'] domain = "ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org" except Exception as E: - self.diameterLibLogger.error("Threw Exception: " + str(E)) - self.diameterLibLogger.error("No known MSISDN or IMSI in Answer_16777216_301() input") + self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_301() input", redisClient=self.redisMessaging) result_code = 5005 #Experimental Result AVP avp_experimental_result = '' @@ -1300,7 +1322,7 @@ def Answer_16777216_301(self, packet_vars, avps): #This loads a Jinja XML template as the default iFC templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) - self.diameterLibLogger.debug("Loading iFC from path " + str(ims_subscriber_details['ifc_path'])) + self.logTool.log(service='HSS', level='debug', message="Loading iFC from path " + str(ims_subscriber_details['ifc_path']), redisClient=self.redisMessaging) template = templateEnv.get_template(ims_subscriber_details['ifc_path']) #These variables are passed to the template for use @@ -1317,17 +1339,17 @@ def Answer_16777216_301(self, packet_vars, avps): #Determine SAR Type & Store Server_Assignment_Type_Hex = self.get_avp_data(avps, 614)[0] Server_Assignment_Type = self.hex_to_int(Server_Assignment_Type_Hex) - self.diameterLibLogger.debug("Server-Assignment-Type is: " + str(Server_Assignment_Type)) + self.logTool.log(service='HSS', level='debug', message="Server-Assignment-Type is: " + str(Server_Assignment_Type), redisClient=self.redisMessaging) ServingCSCF = self.get_avp_data(avps, 602)[0] #Get OriginHost from AVP ServingCSCF = binascii.unhexlify(ServingCSCF).decode('utf-8') #Format it - self.diameterLibLogger.debug("Subscriber is served by S-CSCF " + str(ServingCSCF)) + self.logTool.log(service='HSS', level='debug', message="Subscriber is served by S-CSCF " + str(ServingCSCF), redisClient=self.redisMessaging) if (Server_Assignment_Type == 1) or (Server_Assignment_Type == 2): - self.diameterLibLogger.debug("SAR is Register / Re-Restister") + self.logTool.log(service='HSS', level='debug', message="SAR is Register / Re-Restister", redisClient=self.redisMessaging) remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) - database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer) + self.database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer) else: - 
self.diameterLibLogger.debug("SAR is not Register") - database.Update_Serving_CSCF(imsi, serving_cscf=None) + self.logTool.log(service='HSS', level='debug', message="SAR is not Register", redisClient=self.redisMessaging) + self.database.Update_Serving_CSCF(imsi, serving_cscf=None) avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) @@ -1347,37 +1369,40 @@ def Answer_16777216_302(self, packet_vars, avps): try: - self.diameterLibLogger.info("Checking if username present") + self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) username = self.get_avp_data(avps, 601)[0] ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) if ims_subscriber_details['scscf'] != None: - self.diameterLibLogger.debug("Got SCSCF on record for Sub") + self.logTool.log(service='HSS', level='debug', message="Got SCSCF on record for Sub", redisClient=self.redisMessaging) #Strip double sip prefix avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(str(ims_subscriber_details['scscf']))),'ascii')) else: - self.diameterLibLogger.debug("No SCSF assigned - Using SCSCF Pool") + self.logTool.log(service='HSS', level='debug', message="No SCSF assigned - Using SCSCF Pool", redisClient=self.redisMessaging) if 'scscf_pool' in self.yaml_config['hss']: try: scscf = random.choice(self.yaml_config['hss']['scscf_pool']) - self.diameterLibLogger.debug("Randomly picked SCSCF address " + str(scscf) + " from pool") + self.logTool.log(service='HSS', level='debug', message="Randomly picked SCSCF address " + str(scscf) + " from pool", redisClient=self.redisMessaging) avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.diameterLibLogger.info("Using generated iFC as failed to source from list due to " + str(E)) + self.logTool.log(service='HSS', level='info', message="Using generated iFC as failed to source from list due to " + str(E), redisClient=self.redisMessaging) else: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.diameterLibLogger.info("Using generated iFC") + self.logTool.log(service='HSS', level='info', message="Using generated iFC", redisClient=self.redisMessaging) except Exception as E: - self.diameterLibLogger.error("Threw Exception: " + str(E)) - self.diameterLibLogger.error("No known MSISDN or IMSI in Answer_16777216_302() input") + self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_302() input", redisClient=self.redisMessaging) result_code = 5001 - #@@Fixme - # prom_diam_auth_event_count.labels( - # diameter_application_id = 16777216, - # diameter_cmd_code = 302, - # event='Unknown User', - # imsi_prefix = str(username[0:6]), - # ).inc() + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777216, + "diameter_cmd_code": 302, + 
"event": "Unknown User", + "imsi_prefix": str(username[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) #Experimental Result AVP avp_experimental_result = '' avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID @@ -1395,12 +1420,12 @@ def Answer_16777216_302(self, packet_vars, avps): def Answer_16777216_303(self, packet_vars, avps): public_identity = self.get_avp_data(avps, 601)[0] public_identity = binascii.unhexlify(public_identity).decode('utf-8') - self.diameterLibLogger.debug("Got MAR for public_identity : " + str(public_identity)) + self.logTool.log(service='HSS', level='debug', message="Got MAR for public_identity : " + str(public_identity), redisClient=self.redisMessaging) username = self.get_avp_data(avps, 1)[0] username = binascii.unhexlify(username).decode('utf-8') imsi = username.split('@')[0] #Strip Domain domain = username.split('@')[1] #Get Domain Part - self.diameterLibLogger.debug("Got MAR username: " + str(username)) + self.logTool.log(service='HSS', level='debug', message="Got MAR username: " + str(username), redisClient=self.redisMessaging) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1411,17 +1436,20 @@ def Answer_16777216_303(self, packet_vars, avps): avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm try: - subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details + subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details except: #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - self.diameterLibLogger.debug("Subscriber " + str(imsi) + " unknown in HSS for MAA") - #@@Fixme - # prom_diam_auth_event_count.labels( - # diameter_application_id = 16777216, - # diameter_cmd_code = 303, - # event='Unknown User', - # imsi_prefix = str(username[0:6]), - # ).inc() + self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for MAA", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777216, + "diameter_cmd_code": 303, + "event": "Unknown User", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) experimental_result = self.generate_avp(298, 40, self.int_to_hex(5001, 4)) #Result Code (DIAMETER ERROR - User Unknown) experimental_result = experimental_result + self.generate_vendor_avp(266, 40, 10415, "") #Experimental Result (297) @@ -1429,7 +1457,7 @@ def Answer_16777216_303(self, packet_vars, avps): response = self.generate_diameter_packet("01", "40", 303, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response - self.diameterLibLogger.debug("Got subscriber data for MAA OK") + self.logTool.log(service='HSS', level='debug', message="Got subscriber data for MAA OK", redisClient=self.redisMessaging) mcc, mnc = imsi[0:3], imsi[3:5] plmn = self.EncodePLMN(mcc, mnc) @@ -1437,33 +1465,36 @@ def Answer_16777216_303(self, packet_vars, avps): #Determine if SQN Resync is required & auth type to use for sub_avp_612 in self.get_avp_data(avps, 612)[0]: if sub_avp_612['avp_code'] == 610: - self.diameterLibLogger.info("SQN in HSS is out of sync - Performing resync") + self.logTool.log(service='HSS', level='info', 
message="SQN in HSS is out of sync - Performing resync", redisClient=self.redisMessaging) auts = str(sub_avp_612['misc_data'])[32:] rand = str(sub_avp_612['misc_data'])[:32] rand = binascii.unhexlify(rand) - database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) - self.diameterLibLogger.debug("Resynced SQN in DB") - #@@Fixme - # prom_diam_auth_event_count.labels( - # diameter_application_id = 16777216, - # diameter_cmd_code = 302, - # event='ReAuth', - # imsi_prefix = str(imsi[0:6]), - # ).inc() + self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) + self.logTool.log(service='HSS', level='debug', message="Resynced SQN in DB", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777216, + "diameter_cmd_code": 302, + "event": "ReAuth", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) if sub_avp_612['avp_code'] == 608: - self.diameterLibLogger.info("Auth mechansim requested: " + str(sub_avp_612['misc_data'])) + self.logTool.log(service='HSS', level='info', message="Auth mechansim requested: " + str(sub_avp_612['misc_data']), redisClient=self.redisMessaging) auth_scheme = binascii.unhexlify(sub_avp_612['misc_data']).decode('utf-8') - self.diameterLibLogger.info("Auth mechansim requested: " + str(auth_scheme)) + self.logTool.log(service='HSS', level='info', message="Auth mechansim requested: " + str(auth_scheme), redisClient=self.redisMessaging) - self.diameterLibLogger.debug("IMSI is " + str(imsi)) + self.logTool.log(service='HSS', level='debug', message="IMSI is " + str(imsi), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(601, "c0", 10415, str(binascii.hexlify(str.encode(public_identity)),'ascii')) #Public Identity (IMSI) avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi + "@" + domain)),'ascii')) #Username #Determine Vectors to Generate if auth_scheme == "Digest-MD5": - self.diameterLibLogger.debug("Generating MD5 Challenge") - vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "Digest-MD5", username=imsi, plmn=plmn) + self.logTool.log(service='HSS', level='debug', message="Generating MD5 Challenge", redisClient=self.redisMessaging) + vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "Digest-MD5", username=imsi, plmn=plmn) avp_SIP_Item_Number = self.generate_vendor_avp(613, "c0", 10415, format(int(0),"x").zfill(8)) avp_SIP_Authentication_Scheme = self.generate_vendor_avp(608, "c0", 10415, str(binascii.hexlify(b'Digest-MD5'),'ascii')) #Nonce @@ -1472,8 +1503,8 @@ def Answer_16777216_303(self, packet_vars, avps): avp_SIP_Authorization = self.generate_vendor_avp(610, "c0", 10415, str(binascii.hexlify(str.encode(vector_dict['SIP_Authenticate'])),'ascii')) auth_data_item = avp_SIP_Item_Number + avp_SIP_Authentication_Scheme + avp_SIP_Authenticate + avp_SIP_Authorization else: - self.diameterLibLogger.debug("Generating AKA-MD5 Auth Challenge") - vector_dict = database.Get_Vectors_AuC(subscriber_details['auc_id'], "sip_auth", plmn=plmn) + self.logTool.log(service='HSS', level='debug', message="Generating AKA-MD5 Auth Challenge", redisClient=self.redisMessaging) + vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sip_auth", plmn=plmn) #diameter.3GPP-SIP-Auth-Data-Items: @@ 
-1504,7 +1535,7 @@ def Answer_16777216_303(self, packet_vars, avps): #Generate a Generic error handler with Result Code as input def Respond_ResultCode(self, packet_vars, avps, result_code): - logging.error("Responding with result code " + str(result_code) + " to request with command code " + str(packet_vars['command_code'])) + self.logTool.log(service='HSS', level='error', message="Responding with result code " + str(result_code) + " to request with command code " + str(packet_vars['command_code']), redisClient=self.redisMessaging) avp = '' #Initiate empty var AVP avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm @@ -1512,7 +1543,7 @@ def Respond_ResultCode(self, packet_vars, avps, result_code): session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID except: - self.diameterLibLogger.info("Failed to add SessionID into error") + self.logTool.log(service='HSS', level='info', message="Failed to add SessionID into error", redisClient=self.redisMessaging) for avps_to_check in avps: #Only include AVP 260 (Vendor-Specific-Application-ID) if inital request included it if avps_to_check['avp_code'] == 260: concat_subavp = '' @@ -1536,9 +1567,9 @@ def Answer_16777216_304(self, packet_vars, avps): session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID vendor_id = self.generate_avp(266, 40, str(binascii.hexlify('10415'),'ascii')) - self.diameterLibLogger.debug("vendor_id avp: " + str(vendor_id)) + self.logTool.log(service='HSS', level='debug', message="vendor_id avp: " + str(vendor_id), redisClient=self.redisMessaging) auth_application_id = self.generate_avp(248, 40, self.int_to_hex(16777252, 8)) - self.diameterLibLogger.debug("auth_application_id: " + auth_application_id) + self.logTool.log(service='HSS', level='debug', message="auth_application_id: " + auth_application_id, redisClient=self.redisMessaging) avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx avp += self.generate_avp(268, 40, "000007d1") #Result Code - DIAMETER_SUCCESS avp += self.generate_avp(277, 40, "00000001") #Auth Session State @@ -1564,30 +1595,37 @@ def Answer_16777217_306(self, packet_vars, avps): try: user_identity_avp = self.get_avp_data(avps, 700)[0] msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request - self.diameterLibLogger.info("Got raw MSISDN with value " + str(msisdn)) + self.logTool.log(service='HSS', level='info', message="Got raw MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) msisdn = self.TBCD_decode(msisdn) - self.diameterLibLogger.info("Got MSISDN with value " + str(msisdn)) + self.logTool.log(service='HSS', level='info', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) except: - self.diameterLibLogger.error("No MSISDN") + self.logTool.log(service='HSS', level='error', message="No MSISDN", redisClient=self.redisMessaging) + try: + username = self.get_avp_data(avps, 601)[0] + except: + self.logTool.log(service='HSS', level='error', message="No Username", redisClient=self.redisMessaging) if msisdn is not None: - self.diameterLibLogger.debug("Getting susbcriber IMS info based on MSISDN") - subscriber_ims_details = database.Get_IMS_Subscriber(msisdn=msisdn) - self.diameterLibLogger.debug("Got subscriber 
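
The MSISDN pulled from the User-Identity below is TBCD-encoded: digits are packed two per octet with the nibbles swapped, and 0xF pads an odd-length number. A standalone sketch of the decode, independent of the class's own TBCD_decode helper:

    def tbcd_decode(hex_str: str) -> str:
        # Swap each octet's nibbles, then strip the 0xF filler digit.
        swapped = ''.join(hex_str[i + 1] + hex_str[i] for i in range(0, len(hex_str) - 1, 2))
        return swapped.rstrip('fF')

    assert tbcd_decode("21436587f9") == "123456789"
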
IMS details: " + str(subscriber_ims_details)) - self.diameterLibLogger.debug("Getting susbcriber info based on MSISDN") - subscriber_details = database.Get_Subscriber(msisdn=msisdn) - self.diameterLibLogger.debug("Got subscriber details: " + str(subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Getting susbcriber IMS info based on MSISDN", redisClient=self.redisMessaging) + subscriber_ims_details = self.database.Get_IMS_Subscriber(msisdn=msisdn) + self.logTool.log(service='HSS', level='debug', message="Got subscriber IMS details: " + str(subscriber_ims_details), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Getting susbcriber info based on MSISDN", redisClient=self.redisMessaging) + subscriber_details = self.database.Get_Subscriber(msisdn=msisdn) + self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging) subscriber_details = {**subscriber_details, **subscriber_ims_details} - self.diameterLibLogger.debug("Merged subscriber details: " + str(subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Merged subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging) else: - self.diameterLibLogger.error("No MSISDN or IMSI in Answer_16777217_306() input") - #@@Fixme - # prom_diam_auth_event_count.labels( - # diameter_application_id = 16777216, - # diameter_cmd_code = 306, - # event='Unknown User', - # imsi_prefix = str(username[0:6]), - # ).inc() + self.logTool.log(service='HSS', level='error', message="No MSISDN or IMSI in Answer_16777217_306() input", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777216, + "diameter_cmd_code": 306, + "event": "Unknown User", + "imsi_prefix": str(username[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) result_code = 5005 #Experimental Result AVP avp_experimental_result = '' @@ -1610,13 +1648,13 @@ def Answer_16777217_306(self, packet_vars, avps): templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) sh_userdata_template = self.yaml_config['hss']['Default_Sh_UserData'] - self.diameterLibLogger.info("Using template " + str(sh_userdata_template) + " for SH user data") + self.logTool.log(service='HSS', level='info', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) template = templateEnv.get_template(sh_userdata_template) #These variables are passed to the template for use subscriber_details['mnc'] = self.MNC.zfill(3) subscriber_details['mcc'] = self.MCC.zfill(3) - self.diameterLibLogger.debug("Rendering template with values: " + str(subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Rendering template with values: " + str(subscriber_details), redisClient=self.redisMessaging) xmlbody = template.render(Sh_template_vars=subscriber_details) # this is where to put args to the template renderer avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii')) @@ -1638,12 +1676,12 @@ def Answer_16777217_307(self, packet_vars, avps): sh_user_data = self.get_avp_data(avps, 702)[0] #Get IMSI from User-Name AVP in request sh_user_data = 
binascii.unhexlify(sh_user_data).decode('utf-8') - self.diameterLibLogger.debug("Got Sh User data: " + str(sh_user_data)) + self.logTool.log(service='HSS', level='debug', message="Got Sh User data: " + str(sh_user_data), redisClient=self.redisMessaging) #Push updated User Data into IMS Backend #Start with the Current User Data - subscriber_ims_details = database.Get_IMS_Subscriber(imsi=imsi) - database.UpdateObj(database.IMS_SUBSCRIBER, {'sh_profile': sh_user_data}, subscriber_ims_details['ims_subscriber_id']) + subscriber_ims_details = self.database.Get_IMS_Subscriber(imsi=imsi) + self.database.UpdateObj(self.database.IMS_SUBSCRIBER, {'sh_profile': sh_user_data}, subscriber_ims_details['ims_subscriber_id']) avp = '' #Initiate empty var AVP #Session-ID session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1669,17 +1707,17 @@ def Answer_16777252_324(self, packet_vars, avps): imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI #avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - self.diameterLibLogger.info("Got IMSI with value " + str(imsi)) + self.logTool.log(service='HSS', level='info', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging) except Exception as e: - self.diameterLibLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request") - self.diameterLibLogger.debug("Error was: " + str(e)) + self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from LCS-Routing-Info-Request", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) #Get IMEI for sub_avp in self.get_avp_data(avps, 1401)[0]: - self.diameterLibLogger.debug("Evaluating sub_avp AVP " + str(sub_avp) + " to find IMSI") + self.logTool.log(service='HSS', level='debug', message="Evaluating sub_avp AVP " + str(sub_avp) + " to find IMSI", redisClient=self.redisMessaging) if sub_avp['avp_code'] == 1402: imei = binascii.unhexlify(sub_avp['misc_data']).decode('utf-8') - self.diameterLibLogger.debug("Found IMEI " + str(imei)) + self.logTool.log(service='HSS', level='debug', message="Found IMEI " + str(imei), redisClient=self.redisMessaging) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1695,10 +1733,15 @@ def Answer_16777252_324(self, packet_vars, avps): avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) #Equipment-Status - EquipmentStatus = database.Check_EIR(imsi=imsi, imei=imei) + EquipmentStatus = self.database.Check_EIR(imsi=imsi, imei=imei) avp += self.generate_vendor_avp(1445, 'c0', 10415, self.int_to_hex(EquipmentStatus, 4)) - # @@Fixme - # prom_diam_eir_event_count.labels(response=EquipmentStatus).inc() + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_eir_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "response": EquipmentStatus}, + metricHelp='Diameter EIR event related Counters', + metricExpiry=60) response = self.generate_diameter_packet("01", "40", 324, 16777252, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response @@ -1728,56 +1771,56 @@ def Answer_16777291_8388622(self, packet_vars, avps): #Try and get IMSI if present if 1 in present_avps: - self.diameterLibLogger.info("IMSI AVP is present") + 
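
sendMetric belongs to the messaging layer introduced elsewhere in this patch and is not shown in this hunk; a plausible sketch, in which every name below is an assumption, is a producer that queues metric events on a Redis list for the Prometheus service to drain:

    import json, time

    class RedisMessaging:
        def __init__(self, redisClient):
            self.redisClient = redisClient

        def sendMetric(self, serviceName, metricName, metricType, metricAction,
                       metricValue, metricLabels=None, metricHelp='', metricExpiry=None):
            # Serialize the metric event and push it for a separate exporter to consume.
            event = {'name': metricName, 'type': metricType, 'action': metricAction,
                     'value': metricValue, 'labels': metricLabels or {},
                     'help': metricHelp, 'expiry': metricExpiry, 'timestamp': time.time()}
            self.redisClient.rpush('metric:' + serviceName, json.dumps(event))
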
self.logTool.log(service='HSS', level='info', message="IMSI AVP is present", redisClient=self.redisMessaging) try: imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - self.diameterLibLogger.info("Got IMSI with value " + str(imsi)) + self.logTool.log(service='HSS', level='info', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging) except Exception as e: - self.diameterLibLogger.debug("Failed to get IMSI from LCS-Routing-Info-Request") - self.diameterLibLogger.debug("Error was: " + str(e)) + self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from LCS-Routing-Info-Request", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) elif 701 in present_avps: #Try and get MSISDN if present try: msisdn = self.get_avp_data(avps, 701)[0] #Get MSISDN from AVP in request - self.diameterLibLogger.info("Got MSISDN with value " + str(msisdn)) + self.logTool.log(service='HSS', level='info', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(701, 'c0', 10415, self.get_avp_data(avps, 701)[0]) #MSISDN - self.diameterLibLogger.info("Got MSISDN with encoded value " + str(msisdn)) + self.logTool.log(service='HSS', level='info', message="Got MSISDN with encoded value " + str(msisdn), redisClient=self.redisMessaging) msisdn = self.TBCD_decode(msisdn) - self.diameterLibLogger.info("Got MSISDN with decoded value " + str(msisdn)) + self.logTool.log(service='HSS', level='info', message="Got MSISDN with decoded value " + str(msisdn), redisClient=self.redisMessaging) except Exception as e: - self.diameterLibLogger.debug("Failed to get MSISDN from LCS-Routing-Info-Request") - self.diameterLibLogger.debug("Error was: " + str(e)) + self.logTool.log(service='HSS', level='debug', message="Failed to get MSISDN from LCS-Routing-Info-Request", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) else: - self.diameterLibLogger.error("No MSISDN or IMSI") + self.logTool.log(service='HSS', level='error', message="No MSISDN or IMSI", redisClient=self.redisMessaging) try: if imsi is not None: - self.diameterLibLogger.debug("Getting susbcriber location based on IMSI") - subscriber_details = database.Get_Subscriber(imsi=imsi) - self.diameterLibLogger.debug("Got subscriber_details from IMSI: " + str(subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Getting susbcriber location based on IMSI", redisClient=self.redisMessaging) + subscriber_details = self.database.Get_Subscriber(imsi=imsi) + self.logTool.log(service='HSS', level='debug', message="Got subscriber_details from IMSI: " + str(subscriber_details), redisClient=self.redisMessaging) elif msisdn is not None: - self.diameterLibLogger.debug("Getting susbcriber location based on MSISDN") - subscriber_details = database.Get_Subscriber(msisdn=msisdn) - self.diameterLibLogger.debug("Got subscriber_details from MSISDN: " + str(subscriber_details)) + self.logTool.log(service='HSS', level='debug', message="Getting susbcriber location based on MSISDN", redisClient=self.redisMessaging) + subscriber_details = self.database.Get_Subscriber(msisdn=msisdn) + self.logTool.log(service='HSS', level='debug', message="Got 
subscriber_details from MSISDN: " + str(subscriber_details), redisClient=self.redisMessaging) except Exception as E: - self.diameterLibLogger.error("No MSISDN or IMSI returned in Answer_16777291_8388622 input") - self.diameterLibLogger.error("Error is " + str(E)) - self.diameterLibLogger.error("Responding with DIAMETER_ERROR_USER_UNKNOWN") + self.logTool.log(service='HSS', level='error', message="No MSISDN or IMSI returned in Answer_16777291_8388622 input", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="Error is " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - self.diameterLibLogger.info("Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN") + self.logTool.log(service='HSS', level='info', message="Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) return response - self.diameterLibLogger.info("Got subscriber_details for subscriber: " + str(subscriber_details)) + self.logTool.log(service='HSS', level='info', message="Got subscriber_details for subscriber: " + str(subscriber_details), redisClient=self.redisMessaging) if subscriber_details['serving_mme'] == None: #DB has no location on record for subscriber - self.diameterLibLogger.info("No location on record for Subscriber") + self.logTool.log(service='HSS', level='info', message="No location on record for Subscriber", redisClient=self.redisMessaging) result_code = 4201 #DIAMETER_ERROR_ABSENT_USER (4201) #This result code shall be sent by the HSS to indicate that the location of the targeted user is not known at this time to @@ -1957,7 +2000,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): SupportedFeatures += self.generate_vendor_avp(629, 80, 10415, self.int_to_hex(1, 4)) #Feature-List ID SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags if 'GetLocation' in kwargs: - self.diameterLibLogger.debug("Requsted Get Location ISD") + self.logTool.log(service='HSS', level='debug', message="Requsted Get Location ISD", redisClient=self.redisMessaging) #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP SupportedFeatures = '' SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID @@ -1968,23 +2011,23 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): try: user_identity_avp = self.get_avp_data(avps, 700)[0] - self.diameterLibLogger.info(user_identity_avp) + self.logTool.log(service='HSS', level='info', message=user_identity_avp, redisClient=self.redisMessaging) msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request msisdn = self.TBCD_decode(msisdn) - self.diameterLibLogger.info("Got MSISDN with value " + str(msisdn)) + self.logTool.log(service='HSS', level='info', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) except: - self.diameterLibLogger.error("No MSISDN present") + self.logTool.log(service='HSS', level='error', message="No MSISDN present", redisClient=self.redisMessaging) return #Get Subscriber Location from Database - subscriber_location = database.GetSubscriberLocation(msisdn=msisdn) - 
self.diameterLibLogger.debug("Got subscriber location: " + subscriber_location) + subscriber_location = self.database.GetSubscriberLocation(msisdn=msisdn) + self.logTool.log(service='HSS', level='debug', message="Got subscriber location: " + subscriber_location, redisClient=self.redisMessaging) - self.diameterLibLogger.info("Getting IMSI for MSISDN " + str(msisdn)) - imsi = database.Get_IMSI_from_MSISDN(msisdn) + self.logTool.log(service='HSS', level='info', message="Getting IMSI for MSISDN " + str(msisdn), redisClient=self.redisMessaging) + imsi = self.database.Get_IMSI_from_MSISDN(msisdn) avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - self.diameterLibLogger.info("Got back location data: " + str(subscriber_location)) + self.logTool.log(service='HSS', level='info', message="Got back location data: " + str(subscriber_location), redisClient=self.redisMessaging) #Populate Destination Host & Realm avp += self.generate_avp(293, 40, self.string_to_hex(subscriber_location)) #Destination Host #Destination-Host @@ -2000,26 +2043,26 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): destinationHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP destinationHost = binascii.unhexlify(destinationHost).decode('utf-8') #Format it - self.diameterLibLogger.debug("Received originHost to use as destinationHost is " + str(destinationHost)) + self.logTool.log(service='HSS', level='debug', message="Received originHost to use as destinationHost is " + str(destinationHost), redisClient=self.redisMessaging) destinationRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP destinationRealm = binascii.unhexlify(destinationRealm).decode('utf-8') #Format it - self.diameterLibLogger.debug("Received originRealm to use as destinationRealm is " + str(destinationRealm)) + self.logTool.log(service='HSS', level='debug', message="Received originRealm to use as destinationRealm is " + str(destinationRealm), redisClient=self.redisMessaging) avp += self.generate_avp(293, 40, self.string_to_hex(destinationHost)) #Destination-Host avp += self.generate_avp(283, 40, self.string_to_hex(destinationRealm)) APN_Configuration = '' try: - subscriber_details = database.Get_Subscriber(imsi=imsi) #Get subscriber details + subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details except ValueError as e: - self.diameterLibLogger.error("failed to get data backfrom database for imsi " + str(imsi)) - self.diameterLibLogger.error("Error is " + str(e)) + self.logTool.log(service='HSS', level='error', message="failed to get data backfrom database for imsi " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="Error is " + str(e), redisClient=self.redisMessaging) raise except Exception as ex: template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - self.diameterLibLogger.critical(message) - self.diameterLibLogger.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) + self.logTool.critical(message) + self.logTool.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise @@ -2056,18 +2099,18 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): apn_list = subscriber_details['pdn'] - self.diameterLibLogger.debug("APN list: " + str(apn_list)) + self.logTool.log(service='HSS', level='debug', message="APN list: " + str(apn_list), redisClient=self.redisMessaging) APN_context_identifer_count = 1 for apn_profile in apn_list: - self.diameterLibLogger.debug("Processing APN profile " + str(apn_profile)) + self.logTool.log(service='HSS', level='debug', message="Processing APN profile " + str(apn_profile), redisClient=self.redisMessaging) APN_Service_Selection = self.generate_avp(493, "40", self.string_to_hex(str(apn_profile['apn']))) - self.diameterLibLogger.debug("Setting APN Configuration Profile") + self.logTool.log(service='HSS', level='debug', message="Setting APN Configuration Profile", redisClient=self.redisMessaging) #Sub AVPs of APN Configuration Profile APN_context_identifer = self.generate_vendor_avp(1423, "c0", 10415, self.int_to_hex(APN_context_identifer_count, 4)) APN_PDN_type = self.generate_vendor_avp(1456, "c0", 10415, self.int_to_hex(0, 4)) - self.diameterLibLogger.debug("Setting APN AMBR") + self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) #AMBR AMBR = '' #Initiate empty var AVP for AMBR if 'AMBR' in apn_profile: @@ -2082,7 +2125,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(ue_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - self.diameterLibLogger.debug("Setting APN Allocation-Retention-Priority") + self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['priority_level']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_profile['qos']['arp']['pre_emption_capability']), 4)) @@ -2095,32 +2138,32 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): #If static UE IP is specified try: apn_ip = apn_profile['ue']['addr'] - self.diameterLibLogger.debug("Found static IP for UE " + str(apn_ip)) + self.logTool.log(service='HSS', level='debug', message="Found static IP for UE " + str(apn_ip), redisClient=self.redisMessaging) Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(apn_ip)) except: Served_Party_Address = "" if 'MIP6-Agent-Info' in apn_profile: - self.diameterLibLogger.info("MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info'])) + self.logTool.log(service='HSS', level='info', message="MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']), redisClient=self.redisMessaging) MIP6_Destination_Host = self.generate_avp(293, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_HOST']))) MIP6_Destination_Realm = self.generate_avp(283, '40', 
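
Context for the MIP6-Agent-Info handling just below: grouped AVPs in this codebase are built by plain hex-string concatenation, with each sub-AVP encoded first and the group AVP wrapping the combined payload. A standalone sketch of that header-plus-payload layout (avp_header_sketch is hypothetical; flags hardcoded to 0x40, length per RFC 6733 excludes trailing padding):

# Illustrative sketch only -- not part of this patch.
def avp_header_sketch(code: int, payload_hex: str) -> str:
    length = 8 + len(payload_hex) // 2                  # 8-byte header + payload bytes
    padding = '00' * ((4 - (length % 4)) % 4)           # pad to a 32-bit boundary
    return format(code, '08x') + '40' + format(length, '06x') + payload_hex + padding

inner = avp_header_sketch(293, b'host.example'.hex())  # a Destination-Host sub-AVP
outer = avp_header_sketch(348, inner)                  # the group wraps the encoded sub-AVP
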
             if 'MIP6-Agent-Info' in apn_profile:
-                self.diameterLibLogger.info("MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']))
+                self.logTool.log(service='HSS', level='info', message="MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']), redisClient=self.redisMessaging)
                 MIP6_Destination_Host = self.generate_avp(293, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_HOST'])))
                 MIP6_Destination_Realm = self.generate_avp(283, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_REALM'])))
                 MIP6_Home_Agent_Host = self.generate_avp(348, '40', MIP6_Destination_Host + MIP6_Destination_Realm)
                 MIP6_Agent_Info = self.generate_avp(486, '40', MIP6_Home_Agent_Host)
-                self.diameterLibLogger.info("MIP6 value is " + str(MIP6_Agent_Info))
+                self.logTool.log(service='HSS', level='info', message="MIP6 value is " + str(MIP6_Agent_Info), redisClient=self.redisMessaging)
             else:
                 MIP6_Agent_Info = ''
             if 'PDN_GW_Allocation_Type' in apn_profile:
-                self.diameterLibLogger.info("PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']))
+                self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']), redisClient=self.redisMessaging)
                 PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4))
-                self.diameterLibLogger.info("PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type))
+                self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type), redisClient=self.redisMessaging)
             else:
                 PDN_GW_Allocation_Type = ''
             if 'VPLMN_Dynamic_Address_Allowed' in apn_profile:
-                self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']))
+                self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']), redisClient=self.redisMessaging)
                 VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4))
-                self.diameterLibLogger.info("VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed))
+                self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed), redisClient=self.redisMessaging)
             else:
                 VPLMN_Dynamic_Address_Allowed = ''
@@ -2131,7 +2174,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs):
             #Increment Context Identifier Count to keep track of how many APN Profiles returned
             APN_context_identifer_count = APN_context_identifer_count + 1
-            self.diameterLibLogger.debug("Processed APN profile " + str(apn_profile['apn']))
+            self.logTool.log(service='HSS', level='debug', message="Processed APN profile " + str(apn_profile['apn']), redisClient=self.redisMessaging)
         subscription_data += self.generate_vendor_avp(1619, "80", 10415, self.int_to_hex(720, 4))   #Subscribed-Periodic-RAU-TAU-Timer (value 720)
         subscription_data += self.generate_vendor_avp(1429, "c0", 10415, APN_context_identifer + \
         #If MSISDN is present include it in Subscription Data
         if 'msisdn' in subscriber_details:
-            self.diameterLibLogger.debug("MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA")
+            self.logTool.log(service='HSS', level='debug', message="MSISDN is " + str(subscriber_details['msisdn']) + " - adding in ULA", redisClient=self.redisMessaging)
             msisdn_avp = self.generate_vendor_avp(701, 'c0', 10415, str(subscriber_details['msisdn']))     #MSISDN
-            self.diameterLibLogger.debug(msisdn_avp)
+            self.logTool.log(service='HSS', level='debug', message=msisdn_avp, redisClient=self.redisMessaging)
             subscription_data += msisdn_avp
         if 'RAT_freq_priorityID' in subscriber_details:
-            self.diameterLibLogger.debug("RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA")
+            self.logTool.log(service='HSS', level='debug', message="RAT_freq_priorityID is " + str(subscriber_details['RAT_freq_priorityID']) + " - Adding in ULA", redisClient=self.redisMessaging)
             rat_freq_priorityID = self.generate_vendor_avp(1440, "C0", 10415, self.int_to_hex(int(subscriber_details['RAT_freq_priorityID']), 4))     #RAT-Frequency-Selection-Priority ID
-            self.diameterLibLogger.debug(rat_freq_priorityID)
+            self.logTool.log(service='HSS', level='debug', message=rat_freq_priorityID, redisClient=self.redisMessaging)
             subscription_data += rat_freq_priorityID
         if '3gpp-charging-characteristics' in subscriber_details:
-            self.diameterLibLogger.debug("3gpp-charging-characteristics " + str(subscriber_details['3gpp-charging-characteristics']) + " - Adding in ULA")
+            self.logTool.log(service='HSS', level='debug', message="3gpp-charging-characteristics " + str(subscriber_details['3gpp-charging-characteristics']) + " - Adding in ULA", redisClient=self.redisMessaging)
             _3gpp_charging_characteristics = self.generate_vendor_avp(13, "80", 10415, self.string_to_hex(str(subscriber_details['3gpp-charging-characteristics'])))
             subscription_data += _3gpp_charging_characteristics
-            self.diameterLibLogger.debug(_3gpp_charging_characteristics)
+            self.logTool.log(service='HSS', level='debug', message=_3gpp_charging_characteristics, redisClient=self.redisMessaging)
         if 'APN_OI_replacement' in subscriber_details:
-            self.diameterLibLogger.debug("APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA")
+            self.logTool.log(service='HSS', level='debug', message="APN_OI_replacement " + str(subscriber_details['APN_OI_replacement']) + " - Adding in ULA", redisClient=self.redisMessaging)
             subscription_data += self.generate_vendor_avp(1427, "C0", 10415, self.string_to_hex(str(subscriber_details['APN_OI_replacement'])))
@@ -2440,7 +2483,7 @@ def Request_16777238_258(self, sessionid, ChargingRules, ue_ip, Serving_PGW, Ser
         avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session-Id set AVP
         #Setup Charging Rule
-        self.diameterLibLogger.debug(ChargingRules)
+        self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging)
         avp += self.Charging_Rule_Generator(ChargingRules=ChargingRules, ue_ip=ue_ip)
@@ -2534,14 +2577,14 @@ def Request_16777217_307(self, msisdn):
         avp += self.generate_avp(283, 40, self.OriginRealm)                                         #Destination Realm
         avp += self.generate_avp(296, 40, self.OriginRealm)                                         #Origin Realm
-        self.diameterLibLogger.debug("Getting susbcriber IMS info based on MSISDN")
-        subscriber_ims_details = database.Get_IMS_Subscriber(msisdn=msisdn)
-        self.diameterLibLogger.debug("Got subscriber IMS details: " + str(subscriber_ims_details))
-        self.diameterLibLogger.debug("Getting susbcriber info based on MSISDN")
-        subscriber_details = database.Get_Subscriber(msisdn=msisdn)
-        self.diameterLibLogger.debug("Got subscriber details: " + str(subscriber_details))
+        self.logTool.log(service='HSS', level='debug', message="Getting subscriber IMS info based on MSISDN", redisClient=self.redisMessaging)
+        subscriber_ims_details = self.database.Get_IMS_Subscriber(msisdn=msisdn)
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber IMS details: " + str(subscriber_ims_details), redisClient=self.redisMessaging)
+        self.logTool.log(service='HSS', level='debug', message="Getting subscriber info based on MSISDN", redisClient=self.redisMessaging)
+        subscriber_details = self.database.Get_Subscriber(msisdn=msisdn)
+        self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging)
         subscriber_details = {**subscriber_details, **subscriber_ims_details}
-        self.diameterLibLogger.debug("Merged subscriber details: " + str(subscriber_details))
+        self.logTool.log(service='HSS', level='debug', message="Merged subscriber details: " + str(subscriber_details), redisClient=self.redisMessaging)
         avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(subscriber_details['imsi'])),'ascii'))      #Username AVP
@@ -2551,13 +2594,13 @@ def Request_16777217_307(self, msisdn):
         templateLoader = jinja2.FileSystemLoader(searchpath="./")
         templateEnv = jinja2.Environment(loader=templateLoader)
         sh_userdata_template = self.yaml_config['hss']['Default_Sh_UserData']
-        self.diameterLibLogger.info("Using template " + str(sh_userdata_template) + " for SH user data")
+        self.logTool.log(service='HSS', level='info', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging)
         template = templateEnv.get_template(sh_userdata_template)
         #These variables are passed to the template for use
         subscriber_details['mnc'] = self.MNC.zfill(3)
         subscriber_details['mcc'] = self.MCC.zfill(3)
-        self.diameterLibLogger.debug("Rendering template with values: " + str(subscriber_details))
+        self.logTool.log(service='HSS', level='debug', message="Rendering template with values: " + str(subscriber_details), redisClient=self.redisMessaging)
         xmlbody = template.render(Sh_template_vars=subscriber_details)  # this is where to put args to the template renderer
         avp += self.generate_vendor_avp(702, "c0", 10415, str(binascii.hexlify(str.encode(xmlbody)),'ascii'))
diff --git a/lib/messaging.py b/lib/messaging.py
index 3adb7a5..127bdaf 100644
--- a/lib/messaging.py
+++ b/lib/messaging.py
@@ -1,5 +1,5 @@
 from redis import Redis
-import time, json
+import time, json, uuid
 
 class RedisMessaging:
     """
@@ -27,7 +27,7 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA
         """
         Stores a prometheus metric in a format readable by the metric service.
         """
-        if not metricValue.isdigit():
+        if not isinstance(metricValue, (int, float)):
             return 'Invalid Argument: metricValue must be a digit'
         metricValue = float(metricValue)
         prometheusMetricBody = json.dumps([{
@@ -40,7 +40,7 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA
         }
         ])
-        metricQueueName = f"metric-{serviceName}-{metricTimestamp}"
+        metricQueueName = f"metric-{serviceName}-{metricTimestamp}-{uuid.uuid4()}"
         try:
             self.redisClient.rpush(metricQueueName, prometheusMetricBody)
""" try: - logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}" + logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}-{uuid.uuid4()}" logMessage = json.dumps({"message": message}) self.redisClient.rpush(logQueueName, logMessage) if logExpiry is not None: diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 843fb87..85ab690 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -1,6 +1,6 @@ import asyncio import redis.asyncio as redis -import time, json +import time, json, uuid class RedisMessagingAsync: """ @@ -42,7 +42,7 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m } ]) - metricQueueName = f"metric-{serviceName}-{metricTimestamp}" + metricQueueName = f"metric-{serviceName}-{metricTimestamp}-{uuid.uuid4()}" try: async with self.redisClient.pipeline(transaction=True) as redisPipe: @@ -59,7 +59,7 @@ async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: in Stores a log message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. """ try: - logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}" + logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}-{uuid.uuid4()}" logMessage = json.dumps({"message": message}) async with self.redisClient.pipeline(transaction=True) as redisPipe: await redisPipe.rpush(logQueueName, logMessage) diff --git a/services/diameterService.py b/services/diameterService.py index 8ddcdda..3e40295 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -1,8 +1,6 @@ import asyncio import sys, os, json import time, yaml, uuid -import concurrent.futures -import logging sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from diameterAsync import DiameterAsync diff --git a/services/georedService.py b/services/georedService.py index 9217a1f..e894845 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -14,10 +14,8 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): except: print(f"[Geored] Fatal Error - config.yaml not found, exiting.") quit() - self.logTool = LogTool() + self.logTool = LogTool(self.config) self.banners = Banners() - self.georedLogger = self.logTool.setupLogger(loggerName='Geored', config=self.config) - self.georedLogger.info(self.banners.georedService()) self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) self.remotePeers = self.config.get('geored', {}).get('sync_endpoints', []) if not self.config.get('geored', {}).get('enabled'): @@ -26,6 +24,8 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): if not (len(self.remotePeers) > 0): self.logger.error("[Geored] Fatal Error - no peers defined under geored.sync_endpoints, exiting.") quit() + self.logTool.log(service='Geored', level='info', message=f"{self.banners.georedService()}", redisClient=self.redisMessaging) + def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: operation = operation.upper() @@ -46,7 +46,7 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui else: response = requestOperations[operation](url, headers=headers) if 200 <= response.status_code <= 299: - self.georedLogger.debug(f"[Geored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. 
Response code: {response.status_code}") + self.logTool.log(service='Geored', level='debug', message=f"[Geored] [sendGeored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}", redisClient=self.redisMessaging) self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', @@ -70,7 +70,7 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui metricExpiry=60) except requests.exceptions.ConnectionError as e: error_message = str(e) - self.georedLogger.warning(f"[Geored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}") + self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) if "Name or service not known" in error_message: self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', @@ -92,7 +92,7 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui "error": "Connection Refused"}, metricExpiry=60) except requests.exceptions.Timeout: - self.georedLogger.warning(f"[Geored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}") + self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', @@ -103,7 +103,7 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui "error": "Timeout"}, metricExpiry=60) except Exception as e: - self.georedLogger.error(f"[Geored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}") + self.logTool.log(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. 
Error Message: {e}", redisClient=self.redisMessaging) self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', @@ -120,8 +120,8 @@ def handleGeoredQueue(self): georedQueue = self.redisMessaging.getNextQueue(pattern='geored-*') georedMessage = self.redisMessaging.getMessage(queue=georedQueue) assert(len(georedMessage)) - self.georedLogger.debug(f"[Geored] Queue: {georedQueue}") - self.georedLogger.debug(f"[Geored] Message: {georedMessage}") + self.logTool.log(service='Geored', level='debug', message=f"[Geored] Queue: {georedQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='Geored', level='debug', message=f"[Geored] Message: {georedMessage}", redisClient=self.redisMessaging) georedDict = json.loads(georedMessage) georedOperation = georedDict['operation'] diff --git a/services/hssService.py b/services/hssService.py index beaedb9..3557dab 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -23,8 +23,8 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.originRealm = self.config.get('hss', {}).get('OriginRealm', f'mnc{self.mnc}.mcc{self.mcc}.3gppnetwork.org') self.originHost = self.config.get('hss', {}).get('OriginHost', f'hss01') self.productName = self.config.get('hss', {}).get('ProductName', f'PyHSS') - self.diameterLibrary = Diameter(originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) + self.diameterLibrary = Diameter(redisMessaging=self.redisMessaging, logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) def handleQueue(self): """ @@ -49,11 +49,11 @@ def handleQueue(self): diameterMessageTypeInbound = diameterMessageTypeDict.get('inbound', '') diameterMessageTypeOutbound = diameterMessageTypeDict.get('outbound', '') except Exception as e: - self.hssLogger.warn(f"[HSS] [handleInboundQueue] Failed to generate diameter outbound: {e}") + self.logTool.log(service='HSS', level='warning', message=f"[HSS] [handleQueue] Failed to generate diameter outbound: {e}", redisClient=self.redisMessaging) continue - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound Queue: {inboundQueue}", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound Queue: {inboundQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) if not len(diameterOutbound) > 0: continue @@ -61,9 +61,9 @@ def handleQueue(self): outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}-{inboundTimestamp}" outboundMessage = json.dumps({"diameter-outbound": diameterOutbound}) - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: 
{diameterOutbound}", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleInboundQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) diff --git a/services/metricService.py b/services/metricService.py index f0ea88f..be0d0a7 100644 --- a/services/metricService.py +++ b/services/metricService.py @@ -17,16 +17,14 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): with open("../config.yaml", "r") as self.configFile: self.config = yaml.safe_load(self.configFile) except: - print(f"[HSS] Fatal Error - config.yaml not found, exiting.") + print(f"[Metric] Fatal Error - config.yaml not found, exiting.") quit() self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) self.banners = Banners() - self.logTool = LogTool() + self.logTool = LogTool(config=self.config) self.registry = CollectorRegistry(auto_describe=True) - self.metricLogger = self.logTool.setupLogger(loggerName='Metric', config=self.config) - self.metricLogger.info(self.banners.metricService()) - + self.logTool.log(service='Metric', level='info', message=f"{self.banners.metricService()}", redisClient=self.redisMessaging) def handleMetrics(self): try: @@ -37,10 +35,10 @@ def handleMetrics(self): metric = self.redisMessaging.getMessage(queue=metricQueue) if not (len(metric) > 0): return - self.metricLogger.info(f"Received Metric: {metric}") + self.logTool.log(service='Metric', level='debug', message=f"Received Metric: {metric}", redisClient=self.redisMessaging) prometheusJsonList = json.loads(metric) for prometheusJson in prometheusJsonList: - self.metricLogger.debug(prometheusJson) + self.logTool.log(service='Metric', level='debug', message=f"{prometheusJson}", redisClient=self.redisMessaging) if not all(key in prometheusJson for key in ('NAME', 'TYPE', 'ACTION', 'VALUE')): raise ValueError('All fields are not available for parsing') counterName = prometheusJson['NAME'] @@ -68,14 +66,14 @@ def handleMetrics(self): prometheusMethod = getattr(counterRecord, action) prometheusMethod(counterValue) else: - self.metricLogger.debug(f"Invalid action `{counterAction}` in message: {metric}, skipping.") + self.logTool.log(service='Metric', level='warn', message=f"Invalid action '{counterAction}' in message: {metric}, skipping.", redisClient=self.redisMessaging) continue else: - self.metricLogger.debug(f"Invalid type `{counterType}` in message: {metric}, skipping.") + self.logTool.log(service='Metric', level='warn', message=f"Invalid type '{counterType}' in message: {metric}, skipping.", redisClient=self.redisMessaging) continue except 
Exception as e: - self.metricLogger.error(f"Unable to parse message: {metric}, due to {e}. Skipping.") + self.logTool.log(service='Metric', level='error', message=f"Unable to parse message: {metric}, due to {e}. Skipping.", redisClient=self.redisMessaging) return From 5349db63e2bd07e669bd02656d768d6dafb871df Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 29 Aug 2023 20:08:14 +1000 Subject: [PATCH 07/43] Add systemctl services --- PyHSS_API.py | 70 ++++++++++------------- lib/database.py | 62 ++++++++++---------- lib/diameter.py | 30 ++++++++-- lib/messaging.py | 8 +++ services/diameterService.py | 61 ++++++++++++++++---- services/georedService.py | 101 +++++++++++++++++++++++++++++++-- systemd/pyhss.service | 17 ++++++ systemd/pyhss_diameter.service | 13 +++++ systemd/pyhss_geored.service | 13 +++++ systemd/pyhss_hss.service | 13 +++++ systemd/pyhss_log.service | 13 +++++ systemd/pyhss_metric.service | 13 +++++ 12 files changed, 321 insertions(+), 93 deletions(-) create mode 100644 systemd/pyhss.service create mode 100644 systemd/pyhss_diameter.service create mode 100644 systemd/pyhss_geored.service create mode 100644 systemd/pyhss_hss.service create mode 100644 systemd/pyhss_log.service create mode 100644 systemd/pyhss_metric.service diff --git a/PyHSS_API.py b/PyHSS_API.py index a2a7d63..12e7e23 100644 --- a/PyHSS_API.py +++ b/PyHSS_API.py @@ -4,33 +4,22 @@ from flask_restx import Api, Resource, fields, reqparse, abort from werkzeug.middleware.proxy_fix import ProxyFix from functools import wraps +sys.path.append(os.path.realpath('lib')) import datetime import traceback import sqlalchemy import socket - - +import logtool +from diameter import Diameter +import database import logging import yaml +import os with open("config.yaml", 'r') as stream: yaml_config = (yaml.safe_load(stream)) -import os -import sys -sys.path.append(os.path.realpath('lib')) - -#Setup Logging -import logtool - - -import database - -from prometheus_flask_exporter import PrometheusMetrics app = Flask(__name__) -metrics = PrometheusMetrics.for_app_factory() -metrics.init_app(app) -from logtool import prom_flask_http_geored_endpoints APN = database.APN Serving_APN = database.SERVING_APN @@ -49,6 +38,14 @@ site_name = yaml_config.get("hss", {}).get("site_name", "") origin_host_name = socket.gethostname() +diameterClient = Diameter( + OriginHost=yaml_config['hss']['OriginHost'], + OriginRealm=yaml_config['hss']['OriginRealm'], + MNC=yaml_config['hss']['MNC'], + MCC=yaml_config['hss']['MCC'], + ProductName='PyHSS-client-API' + ) + app.wsgi_app = ProxyFix(app.wsgi_app) api = Api(app, version='1.0', title=f'{site_name + " - " if site_name else ""}{origin_host_name} - PyHSS OAM API', description='Restful API for working with PyHSS', @@ -1037,9 +1034,11 @@ class PyHSS_OAM_Peers(Resource): def get(self): '''Get all Diameter Peers''' try: - logObj = logtool.LogTool() - DiameterPeers = logObj.GetDiameterPeers() - return DiameterPeers, 200 + #@@Fixme + # logObj = logtool.LogTool() + # DiameterPeers = logObj.GetDiameterPeers() + # return DiameterPeers, 200 + return '' except Exception as E: print(E) return handle_exception(E) @@ -1279,12 +1278,11 @@ def put(self): DestinationRealm = OriginRealm mcc = yaml_config['hss']['MCC'] #Mobile Country Code mnc = yaml_config['hss']['MNC'] #Mobile Network Code - import diameter - diameter = diameter.Diameter(diameter_host, DestinationRealm, 'PyHSS-client-API', str(mcc), str(mnc)) - diam_hex = diameter.Request_16777238_258(pcrf_session_data['pcrf_session_id'], ChargingRule, 
pcrf_session_data['subscriber_routing'], pcrf_session_data['serving_pgw'], 'ServingRealm.com') + diam_hex = diameterClient.Request_16777238_258(pcrf_session_data['pcrf_session_id'], ChargingRule, pcrf_session_data['subscriber_routing'], pcrf_session_data['serving_pgw'], 'ServingRealm.com') import time - logObj = logtool.LogTool() - logObj.Async_SendRequest(diam_hex, str(pcrf_session_data['serving_pgw'])) + # @@Fixme + # logObj = logtool.LogTool() + # logObj.Async_SendRequest(diam_hex, str(pcrf_session_data['serving_pgw'])) return diam_hex, 200 @ns_pcrf.route('/') @@ -1327,7 +1325,8 @@ def patch(self): if 'serving_mme' in json_data: print("Updating serving MME") response_data.append(database.Update_Serving_MME(imsi=str(json_data['imsi']), serving_mme=json_data['serving_mme'], serving_mme_realm=json_data['serving_mme_realm'], serving_mme_peer=json_data['serving_mme_peer'], propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='HSS', geored_host=request.remote_addr).inc() + #@@Fixme + # prom_flask_http_geored_endpoints.labels(endpoint='HSS', geored_host=request.remote_addr).inc() if 'serving_apn' in json_data: print("Updating serving APN") if 'serving_pgw_realm' not in json_data: @@ -1343,7 +1342,8 @@ def patch(self): serving_pgw_realm=json_data['serving_pgw_realm'], serving_pgw_peer=json_data['serving_pgw_peer'], propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='PCRF', geored_host=request.remote_addr).inc() + #@@Fixme + # prom_flask_http_geored_endpoints.labels(endpoint='PCRF', geored_host=request.remote_addr).inc() if 'scscf' in json_data: print("Updating serving SCSCF") if 'scscf_realm' not in json_data: @@ -1351,11 +1351,13 @@ def patch(self): if 'scscf_peer' not in json_data: json_data['scscf_peer'] = None response_data.append(database.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='IMS', geored_host=request.remote_addr).inc() + #@@Fixme + # prom_flask_http_geored_endpoints.labels(endpoint='IMS', geored_host=request.remote_addr).inc() if 'imei' in json_data: print("Updating EIR") response_data.append(database.Store_IMSI_IMEI_Binding(str(json_data['imsi']), str(json_data['imei']), str(json_data['match_response_code']), propagate=False)) - prom_flask_http_geored_endpoints.labels(endpoint='EIR', geored_host=request.remote_addr).inc() + #@@Fixme + # prom_flask_http_geored_endpoints.labels(endpoint='EIR', geored_host=request.remote_addr).inc() return response_data, 200 except Exception as E: print("Exception when updating: " + str(E)) @@ -1384,22 +1386,12 @@ def put(self, imsi): print("JSON Data sent: " + str(json_data)) if 'DestinationHost' not in json_data: json_data['DestinationHost'] = None - import diameter - diameter = diameter.Diameter( - OriginHost=yaml_config['hss']['OriginHost'], - OriginRealm=yaml_config['hss']['OriginRealm'], - MNC=yaml_config['hss']['MNC'], - MCC=yaml_config['hss']['MCC'], - ProductName='PyHSS-client-API' - ) - diam_hex = diameter.Request_16777251_317( + diam_hex = diameterClient.sendDiameterRequest( imsi=imsi, DestinationHost=json_data['DestinationHost'], DestinationRealm=json_data['DestinationRealm'], CancellationType=json_data['cancellationType'] ) - logObj = logtool.LogTool() - logObj.Async_SendRequest(diam_hex, str(json_data['diameterPeer'])) return diam_hex, 200 if __name__ == '__main__': diff --git a/lib/database.py b/lib/database.py 
index d7cfb02..960a715 100755 --- a/lib/database.py +++ b/lib/database.py @@ -18,6 +18,7 @@ import threading from messaging import RedisMessaging import yaml +import json Base = declarative_base() @@ -831,12 +832,12 @@ def get_last_operation_log(self, existingSession=None): def handleGeored(self, jsonData): try: + georedDict = {} if self.config.get('geored', {}).get('enabled', False): if self.config.get('geored', {}).get('sync_endpoints', []) is not None and len(self.config.get('geored', {}).get('sync_endpoints', [])) > 0: - transaction_id = str(uuid.uuid4()) - self.logTool.log(service='Database', level='info', message="[Database] Break 1", redisClient=self.redisMessaging) - self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=jsonData, queueExpiry=120) - self.logTool.log(service='Database', level='info', message="[Database] Break 1", redisClient=self.redisMessaging) + georedDict['body'] = jsonData + georedDict['operation'] = 'PATCH' + self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) except Exception as E: self.logTool.log(service='Database', level='warning', message="Failed to send Geored message due to error: " + str(E), redisClient=self.redisMessaging) @@ -852,7 +853,7 @@ def handleWebhook(self, objectData, operation): externalNotification = self.Sanitize_Datetime(objectData) externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} externalNotification['headers'] = externalNotificationHeaders - self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=externalNotification, queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(externalNotification), queueExpiry=120) return True def Sanitize_Datetime(self, result): @@ -1841,19 +1842,22 @@ def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=Tru try: device_info = self.get_device_info_from_TAC(imei=str(imei)) self.logTool.log(service='Database', level='debug', message="Got Device Info: " + str(device_info), redisClient=self.redisMessaging) - #@@Fixme - # prom_eir_devices.labels( - # imei_prefix=device_info['tac_prefix'], - # device_type=device_info['name'], - # device_name=device_info['model'] - # ).inc() + self.redisMessaging.sendMetric(serviceName='database', metricName='prom_eir_devices', + metricType='counter', metricAction='inc', + metricValue=1, metricHelp='Profile of attached devices', + metricLabels={'imei_prefix': device_info['tac_prefix'], + 'device_type': device_info['name'], + 'device_name': device_info['model']}, + metricExpiry=60) except Exception as E: self.logTool.log(service='Database', level='debug', message="Failed to get device info from TAC", redisClient=self.redisMessaging) - # prom_eir_devices.labels( - # imei_prefix=str(imei)[0:8], - # device_type='Unknown', - # device_name='Unknown' - # ).inc() + self.redisMessaging.sendMetric(serviceName='database', metricName='prom_eir_devices', + metricType='counter', metricAction='inc', + metricValue=1, metricHelp='Profile of attached devices', + metricLabels={'imei_prefix': str(imei)[0:8], + 'device_type': 'Unknown', + 'device_name': 'Unknown'}, + metricExpiry=60) else: self.logTool.log(service='Database', level='debug', message="No TAC database configured, skipping device info lookup", redisClient=self.redisMessaging) @@ -1995,27 +1999,21 @@ def get_device_info_from_TAC(self, imei): #Try 8 digit TAC 
try: self.logTool.log(service='Database', level='debug', message="Trying to match on 8 Digit IMEI", redisClient=self.redisMessaging) - #@@Fixme - # imei_result = logtool.RedisHMGET(str(imei[0:8])) - # print("Got back: " + str(imei_result)) - # imei_result = dict_bytes_to_dict_string(imei_result) - # assert(len(imei_result) != 0) - # self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) - # return imei_result - return "0" + imei_result = self.redisMessaging.RedisHGetAll(str(imei[0:8])) + imei_result = self.dict_bytes_to_dict_string(imei_result) + assert(len(imei_result) != 0) + self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) + return imei_result except: self.logTool.log(service='Database', level='debug', message="Failed to match on 8 digit IMEI", redisClient=self.redisMessaging) try: self.logTool.log(service='Database', level='debug', message="Trying to match on 6 Digit IMEI", redisClient=self.redisMessaging) - #@@Fixme - # imei_result = logtool.RedisHMGET(str(imei[0:6])) - # print("Got back: " + str(imei_result)) - # imei_result = dict_bytes_to_dict_string(imei_result) - # assert(len(imei_result) != 0) - # self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) - # return imei_result - return "0" + imei_result = self.redisMessaging.RedisHGetAll(str(imei[0:6])) + imei_result = self.dict_bytes_to_dict_string(imei_result) + assert(len(imei_result) != 0) + self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) + return imei_result except: self.logTool.log(service='Database', level='debug', message="Failed to match on 6 digit IMEI", redisClient=self.redisMessaging) diff --git a/lib/diameter.py b/lib/diameter.py index 65c1936..0c425f1 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -31,7 +31,7 @@ def __init__(self, redisMessaging, logTool, originHost: str="hss01", originRealm self.logTool.log(service='HSS', level='info', message=f"Product Name: {str(productName)}", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='info', message=f"PLMN: {str(self.MCC)}/{str(self.MNC)}", redisClient=self.redisMessaging) - self.diameterCommandList = [ + self.diameterResponseList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, @@ -50,6 +50,11 @@ def __init__(self, redisMessaging, logTool, originHost: str="hss01", originRealm {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, 
"failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, ] + self.diameterRequestList = [ + {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, + {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, + ] + #Generates rounding for calculating padding def myround(self, n, base=4): if(n > 0): @@ -72,7 +77,6 @@ def ip_to_hex(self, ip): else: ip_hex = "0002" #IPv6 ip_hex += format(ipaddress.IPv6Address(ip), 'X') - #self.logTool.log(service='HSS', level='debug', message="Converted IP to hex - Input: " + str(ip) + " output: " + str(ip_hex), redisClient=self.redisMessaging) return ip_hex def hex_to_int(self, hex): @@ -381,7 +385,6 @@ def get_avp_data(self, avps, avp_code): #Loops through list of dic misc_data.append(keys['misc_data']) return misc_data - def decode_diameter_packet_length(self, data): packet_vars = {} data = data.hex() @@ -396,7 +399,7 @@ def getDiameterMessageType(self, binaryData: str) -> dict: packet_vars, avps = self.decode_diameter_packet(binaryData) response = {} - for diameterApplication in self.diameterCommandList: + for diameterApplication in self.diameterResponseList: try: assert(packet_vars["command_code"] == diameterApplication["commandCode"]) assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) @@ -405,9 +408,24 @@ def getDiameterMessageType(self, binaryData: str) -> dict: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Successfully generated response: {response}", redisClient=self.redisMessaging) except Exception as e: continue - return response + def generateDiameterRequest(self, requestType: str, **kwargs) -> str: + try: + request = '' + self.logTool.log(service='HSS', level='debug', message=f"Generating a diameter outbound request", redisClient=self.redisMessaging) + + for diameterApplication in self.diameterRequestList: + try: + assert(requestType == diameterApplication["requestAcronym"]) + request = diameterApplication["requestMethod"](kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] Successfully generated request: {request}", redisClient=self.redisMessaging) + except Exception as e: + continue + return request + except Exception as e: + return '' + def generateDiameterResponse(self, binaryData: str) -> str: try: packet_vars, avps = self.decode_diameter_packet(binaryData) @@ -423,7 +441,7 @@ def generateDiameterResponse(self, binaryData: str) -> str: self.logTool.log(service='HSS', level='debug', message=packet_vars, redisClient=self.redisMessaging) return - for diameterApplication in self.diameterCommandList: + for diameterApplication in self.diameterResponseList: try: assert(packet_vars["command_code"] == diameterApplication["commandCode"]) assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) diff --git a/lib/messaging.py b/lib/messaging.py index 127bdaf..6e6cdcc 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -111,6 +111,14 @@ def deleteQueue(self, queue: str) -> bool: except Exception as e: return False + def 
diff --git a/lib/messaging.py b/lib/messaging.py
index 127bdaf..6e6cdcc 100644
--- a/lib/messaging.py
+++ b/lib/messaging.py
@@ -111,6 +111,14 @@ def deleteQueue(self, queue: str) -> bool:
         except Exception as e:
             return False
 
+    def RedisHGetAll(self, key: str):
+        """
+        Wrapper for Redis HGETALL"""
+        try:
+            data = self.redisClient.hgetall(key)
+            return data
+        except Exception as e:
+            return ''
 
 if __name__ == '__main__':
     redisMessaging = RedisMessaging()
diff --git a/services/diameterService.py b/services/diameterService.py
index 3e40295..25d847d 100644
--- a/services/diameterService.py
+++ b/services/diameterService.py
@@ -1,11 +1,13 @@
 import asyncio
 import sys, os, json
 import time, yaml, uuid
+from datetime import datetime
 sys.path.append(os.path.realpath('../lib'))
 from messagingAsync import RedisMessagingAsync
 from diameterAsync import DiameterAsync
 from banners import Banners
 from logtool import LogTool
+import traceback
 
 class DiameterService:
     """
@@ -26,16 +28,20 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379):
         self.banners = Banners()
         self.logTool = LogTool(config=self.config)
         self.diameterLibrary = DiameterAsync()
-        self.activeConnections = set()
+        self.activeConnections = {}
 
-    async def validateDiameterInbound(self, inboundData) -> bool:
+    async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool:
         """
         Asynchronously validates a given diameter inbound, and increments the 'Number of Diameter Inbounds' metric.
         """
         try:
             packetVars, avps = await(self.diameterLibrary.decodeDiameterPacketAsync(inboundData))
+            messageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(inboundData))
             originHost = (await self.diameterLibrary.getAvpDataAsync(avps, 264))[0]
             originHost = bytes.fromhex(originHost).decode("utf-8")
+            self.activeConnections[f"{clientAddress}-{clientPort}"].update({'last_dwr_timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S") if messageType['inbound'] == 'DWR' else self.activeConnections[f"{clientAddress}-{clientPort}"].get('last_dwr_timestamp', ''),
+                                                                            'DiameterHostname': originHost,
+                                                                            })
             asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_inbound_count',
                                 metricType='counter', metricAction='inc',
                                 metricValue=1.0, metricHelp='Number of Diameter Inbounds',
@@ -49,7 +55,32 @@ async def validateDiameterInbound(self, inboundData) -> bool:
                                 metricExpiry=60))
         except Exception as e:
             print(e)
             return False
         return True
-    
+
+    async def handleActiveDiameterPeers(self):
+        """
+        Prunes stale connection entries from self.activeConnections.
+        """
+        while True:
+            try:
+                if not len(self.activeConnections) > 0:
+                    await(asyncio.sleep(1))
+                    continue
+
+                activeDiameterPeersTimeout = self.config.get('hss', {}).get('active_diameter_peers_timeout', 86400)
+
+                for key, connection in list(self.activeConnections.items()):
+                    if connection.get('connection_status', '') == 'disconnected':
+                        if (datetime.now() - datetime.strptime(connection['connect_timestamp'], "%Y-%m-%d %H:%M:%S")).total_seconds() > activeDiameterPeersTimeout:
+                            del self.activeConnections[key]
+
+                await(self.redisMessaging.sendMessage(queue='ActiveDiameterPeers', message=json.dumps(self.activeConnections)))
+
+                await(asyncio.sleep(1))
+            except Exception as e:
+                print(e)
+                await(asyncio.sleep(1))
+                continue
+
     async def logActiveConnections(self):
         """
         Logs the number of active connections on a rolling basis.
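
For the peer-tracking loop above: entries are keyed by "{ip}-{port}", flagged disconnected when the socket closes, and pruned once disconnected for longer than active_diameter_peers_timeout. Deleting during iteration requires snapshotting the items first, and elapsed time should use total_seconds() rather than the day-capped .seconds attribute (both corrected above). A compact standalone sketch of the prune step, under those assumptions:

# Illustrative sketch only -- not part of this patch.
from datetime import datetime

activeConnections = {
    "10.0.0.1-3868": {"connection_status": "disconnected",
                      "connect_timestamp": "2023-08-29 10:00:00"},
}

def prune_sketch(timeoutSeconds: int = 86400) -> None:
    for key, connection in list(activeConnections.items()):   # snapshot before deleting
        if connection.get("connection_status") != "disconnected":
            continue
        connected = datetime.strptime(connection["connect_timestamp"], "%Y-%m-%d %H:%M:%S")
        if (datetime.now() - connected).total_seconds() > timeoutSeconds:
            del activeConnections[key]

prune_sketch()
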
@@ -58,7 +89,7 @@ async def logActiveConnections(self):
         if not len(activeConnections) > 0:
             activeConnections = ''
         await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActiveConnections] {len(self.activeConnections)} Active Connections {activeConnections}", redisClient=self.redisMessaging))
-    
+
     async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool:
         """
         Reads and parses incoming data from a connected client. Validated diameter messages are sent to the redis queue for processing.
@@ -77,7 +108,7 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc
 
                 if len(inboundData) > 0:
                     await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}", redisClient=self.redisMessaging))
-                    if not await(self.validateDiameterInbound(inboundData)):
+                    if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)):
                         await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, terminating connection.", redisClient=self.redisMessaging))
                         return False
@@ -139,7 +170,14 @@ async def handleConnection(self, reader, writer):
             coroutineUuid = str(uuid.uuid4())
             (clientAddress, clientPort) = writer.get_extra_info('peername')
             await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}", redisClient=self.redisMessaging))
-            self.activeConnections.add((clientAddress, clientPort, coroutineUuid))
+            if f"{clientAddress}-{clientPort}" not in self.activeConnections:
+                self.activeConnections[f"{clientAddress}-{clientPort}"] = {}
+            self.activeConnections[f"{clientAddress}-{clientPort}"].update({
+                "connect_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                "recv_ip_address": clientAddress,
+                "recv_ip_port": clientPort,
+                "connection_status": 'connected',
+            })
             await(self.logActiveConnections())
 
             readTask = asyncio.create_task(self.readInboundData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid))
@@ -156,21 +194,21 @@ async def handleConnection(self, reader, writer):
 
             writer.close()
             await(writer.wait_closed())
-            self.activeConnections.discard((clientAddress, clientPort, coroutineUuid))
+            self.activeConnections[f"{clientAddress}-{clientPort}"].update({
+                "connection_status": 'disconnected',
+            })
             await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}.", redisClient=self.redisMessaging))
             await(self.logActiveConnections())
-
-            return
         except Exception as e:
-            await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}", redisClient=self.redisMessaging))
+            await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}\n{traceback.format_exc()}", redisClient=self.redisMessaging))
             return
 
     async def startServer(self, host: str=None, port: int=None, type: str=None):
         """
         Start a server with the given parameters and handle new clients with self.handleConnection. 
- Also create a single instance of self.logActiveConnections. + Also create a single instance of self.handleActiveDiameterPeers. """ if host is None: @@ -190,6 +228,7 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) await(self.logTool.logAsync(service='Diameter', level='info', message=f"{self.banners.diameterService()}\n[Diameter] Serving on {servingAddresses}", redisClient=self.redisMessaging)) + handleActiveDiameterPeerTask = asyncio.create_task(self.handleActiveDiameterPeers()) async with server: await(server.serve_forever()) diff --git a/services/georedService.py b/services/georedService.py index e894845..7194c11 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -26,7 +26,6 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): quit() self.logTool.log(service='Geored', level='info', message=f"{self.banners.georedService()}", redisClient=self.redisMessaging) - def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: operation = operation.upper() requestOperations = {'GET': requests.get, 'PUT': requests.put, 'POST': requests.post, 'PATCH':requests.patch, 'DELETE': requests.delete} @@ -115,7 +114,96 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui metricExpiry=60) return True - def handleGeoredQueue(self): + def sendWebhook(self, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + operation = operation.upper() + requestOperations = {'GET': requests.get, 'PUT': requests.put, 'POST': requests.post, 'PATCH':requests.patch, 'DELETE': requests.delete} + + if not url or not operation or not body: + return False + + if operation not in requestOperations: + return False + + headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)} + + for attempt in range(retryCount): + try: + if operation in ['PUT', 'POST', 'PATCH']: + response = requestOperations[operation](url, json=body, headers=headers) + else: + response = requestOperations[operation](url, headers=headers) + if 200 <= response.status_code <= 299: + self.logTool.log(service='Geored', level='debug', message=f"[Geored] [sendWebhook] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. 
Response code: {response.status_code}", redisClient=self.redisMessaging) + + self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": str(response.status_code), + "error": ""}, + metricExpiry=60) + break + else: + self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": str(response.status_code), + "error": str(response.reason)}, + metricExpiry=60) + except requests.exceptions.ConnectionError as e: + error_message = str(e) + self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) + if "Name or service not known" in error_message: + self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": "No matching DNS entry found"}, + metricExpiry=60) + else: + self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": "Connection Refused"}, + metricExpiry=60) + except requests.exceptions.Timeout: + self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": "Timeout"}, + metricExpiry=60) + except Exception as e: + self.logTool.log(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. 
Error Message: {e}", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": "000", + "error": e}, + metricExpiry=60) + return True + + + def handleQueue(self): try: georedQueue = self.redisMessaging.getNextQueue(pattern='geored-*') georedMessage = self.redisMessaging.getMessage(queue=georedQueue) @@ -127,8 +215,11 @@ def handleGeoredQueue(self): georedOperation = georedDict['operation'] georedBody = georedDict['body'] - for remotePeer in self.remotePeers: - self.sendGeored(url=remotePeer+'/geored/', operation=georedOperation, body=georedBody) + try: + for remotePeer in self.remotePeers: + self.sendGeored(url=remotePeer+'/geored/', operation=georedOperation, body=georedBody) + except Exception as e: + self.logTool.log(service='Geored', level='debug', message=f"[Geored] Error sending geored message: {e}", redisClient=self.redisMessaging) except Exception as e: return False @@ -136,4 +227,4 @@ def handleGeoredQueue(self): if __name__ == '__main__': georedService = GeoredService() while True: - georedService.handleGeoredQueue() \ No newline at end of file + georedService.handleQueue() \ No newline at end of file diff --git a/systemd/pyhss.service b/systemd/pyhss.service new file mode 100644 index 0000000..39ac30a --- /dev/null +++ b/systemd/pyhss.service @@ -0,0 +1,17 @@ +[Unit] +Description=PyHSS +After=network-online.target mysql.service +Wants=pyhss_diameter.service +Wants=pyhss_geored.service +Wants=pyhss_hss.service +Wants=pyhss_log.service +Wants=pyhss_metric.service + + +[Service] +Type=oneshot +ExecStart=/bin/true +RemainAfterExit=yes + +[Install] +WantedBy=multi-user.target \ No newline at end of file diff --git a/systemd/pyhss_diameter.service b/systemd/pyhss_diameter.service new file mode 100644 index 0000000..02ceaa4 --- /dev/null +++ b/systemd/pyhss_diameter.service @@ -0,0 +1,13 @@ +[Unit] +Description=PyHSS Diameter Service +PartOf=pyhss.service + + +[Service] +User=root +WorkingDirectory=/etc/pyhss/services/ +ExecStart=python3 diameterService.py +Restart=always + +[Install] +WantedBy=pyhss.service \ No newline at end of file diff --git a/systemd/pyhss_geored.service b/systemd/pyhss_geored.service new file mode 100644 index 0000000..7f2da02 --- /dev/null +++ b/systemd/pyhss_geored.service @@ -0,0 +1,13 @@ +[Unit] +Description=PyHSS Geored Service +PartOf=pyhss.service + + +[Service] +User=root +WorkingDirectory=/etc/pyhss/services/ +ExecStart=python3 georedService.py +Restart=always + +[Install] +WantedBy=pyhss.service \ No newline at end of file diff --git a/systemd/pyhss_hss.service b/systemd/pyhss_hss.service new file mode 100644 index 0000000..5d5994c --- /dev/null +++ b/systemd/pyhss_hss.service @@ -0,0 +1,13 @@ +[Unit] +Description=PyHSS HSS Service +PartOf=pyhss.service + + +[Service] +User=root +WorkingDirectory=/etc/pyhss/services/ +ExecStart=python3 hssService.py +Restart=always + +[Install] +WantedBy=pyhss.service \ No newline at end of file diff --git a/systemd/pyhss_log.service b/systemd/pyhss_log.service new file mode 100644 index 0000000..11a7e15 --- /dev/null +++ b/systemd/pyhss_log.service @@ -0,0 +1,13 @@ +[Unit] +Description=PyHSS Log Service +PartOf=pyhss.service + + +[Service] +User=root +WorkingDirectory=/etc/pyhss/services/ +ExecStart=python3 
logService.py +Restart=always + +[Install] +WantedBy=pyhss.service \ No newline at end of file diff --git a/systemd/pyhss_metric.service b/systemd/pyhss_metric.service new file mode 100644 index 0000000..4c3995a --- /dev/null +++ b/systemd/pyhss_metric.service @@ -0,0 +1,13 @@ +[Unit] +Description=PyHSS Metric Service +PartOf=pyhss.service + + +[Service] +User=root +WorkingDirectory=/etc/pyhss/services/ +ExecStart=python3 metricService.py +Restart=always + +[Install] +WantedBy=pyhss.service \ No newline at end of file From ebad1feb3d3639fb12c21218ead9220db380625f Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 30 Aug 2023 17:25:26 +1000 Subject: [PATCH 08/43] Nearing completion, api refactored --- hss.py | 8 - lib/database.py | 41 ++- lib/diameter.py | 31 ++ lib/messaging.py | 41 ++- lib/messagingAsync.py | 30 +- lib/old.logtool.py | 243 -------------- PyHSS_API.py => services/apiService.py | 365 +++++++++++---------- services/diameterService.py | 117 ++++--- services/georedService.py | 342 ++++++++++++------- services/hssService.py | 9 +- services/metricService.py | 17 +- services/webhookService.py | 0 test_Diameter.py => tests/test_Diameter.py | 0 tests_API.py => tests/tests_API.py | 0 14 files changed, 643 insertions(+), 601 deletions(-) delete mode 100644 hss.py delete mode 100644 lib/old.logtool.py rename PyHSS_API.py => services/apiService.py (75%) delete mode 100644 services/webhookService.py rename test_Diameter.py => tests/test_Diameter.py (100%) rename tests_API.py => tests/tests_API.py (100%) diff --git a/hss.py b/hss.py deleted file mode 100644 index 72d741f..0000000 --- a/hss.py +++ /dev/null @@ -1,8 +0,0 @@ -import os, sys, json, yaml - -class PyHSS: - - def __init__(self): - pass - - \ No newline at end of file diff --git a/lib/database.py b/lib/database.py index 960a715..9381f81 100755 --- a/lib/database.py +++ b/lib/database.py @@ -282,7 +282,7 @@ def __init__(self, logTool, redisMessaging): self.logTool.log(service='Database', level='debug', message="Database already created", redisClient=self.redisMessaging) #Load IMEI TAC database into Redis if enabled - if ('tac_database_csv' in self.config['eir']) and (self.config['redis']['enabled'] == True): + if ('tac_database_csv' in self.config['eir']): self.load_IMEI_database_into_Redis() else: self.logTool.log(service='Database', level='info', message="Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config", redisClient=self.redisMessaging) @@ -834,26 +834,31 @@ def handleGeored(self, jsonData): try: georedDict = {} if self.config.get('geored', {}).get('enabled', False): - if self.config.get('geored', {}).get('sync_endpoints', []) is not None and len(self.config.get('geored', {}).get('sync_endpoints', [])) > 0: + if self.config.get('geored', {}).get('endpoints', []) is not None and len(self.config.get('geored', {}).get('endpoints', [])) > 0: georedDict['body'] = jsonData georedDict['operation'] = 'PATCH' self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) except Exception as E: self.logTool.log(service='Database', level='warning', message="Failed to send Geored message due to error: " + str(E), redisClient=self.redisMessaging) - def handleWebhook(self, objectData, operation): - external_webhook_notification_enabled = self.config.get('external', {}).get('external_webhook_notification_enabled', False) - external_webhook_notification_url = self.config.get('external', {}).get('external_webhook_notification_url', 
'') + webhooksEnabled = self.config.get('webhooks', {}).get('enabled', False) + endpointList = self.config.get('webhooks', {}).get('endpoints', []) + webhook = {} - if not external_webhook_notification_enabled: + if not webhooksEnabled: return False - if not external_webhook_notification_url: - self.logTool.log(service='Database', level='error', message="External webhook notification enabled, but external_webhook_notification_url is not defined.", redisClient=self.redisMessaging) + + if len(endpointList) == 0: + self.logTool.log(service='Database', level='error', message="Webhooks enabled, but endpoints are missing.", redisClient=self.redisMessaging) + return False + + webhookHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} - externalNotification = self.Sanitize_Datetime(objectData) - externalNotificationHeaders = {'Content-Type': 'application/json', 'Referer': socket.gethostname()} - externalNotification['headers'] = externalNotificationHeaders - self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(externalNotification), queueExpiry=120) + webhook['body'] = self.Sanitize_Datetime(objectData) + webhook['headers'] = webhookHeaders + webhook['operation'] = "POST" + self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(webhook), queueExpiry=120) return True def Sanitize_Datetime(self, result): @@ -1035,7 +1040,7 @@ def UpdateObj(self, obj_type, json_data, obj_id, disable_logging=False, operatio self.log_changes_before_commit(session) objectData = self.GetObj(obj_type, obj_id) session.commit() - self.handleWebhook(objectData, 'UPDATE') + self.handleWebhook(objectData, 'PATCH') except Exception as E: self.logTool.log(service='Database', level='error', message=f"Failed to commit session, error: {E}", redisClient=self.redisMessaging) self.safe_rollback(session) @@ -1102,7 +1107,7 @@ def CreateObj(self, obj_type, json_data, disable_logging=False, operation_id=Non session.refresh(newObj) result = newObj.__dict__ result.pop('_sa_instance_state') - self.handleWebhook(result, 'CREATE') + self.handleWebhook(result, 'PUT') return result except Exception as E: self.logTool.log(service='Database', level='error', message=f"Exception in CreateObj, error: {E}", redisClient=self.redisMessaging) @@ -1537,7 +1542,7 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ session.commit() objectData = self.GetObj(SUBSCRIBER, result.subscriber_id) - self.handleWebhook(objectData, 'UPDATE') + self.handleWebhook(objectData, 'PATCH') #Sync state change with geored if propagate == True: @@ -1585,7 +1590,7 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N session.commit() objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) - self.handleWebhook(objectData, 'UPDATE') + self.handleWebhook(objectData, 'PATCH') #Sync state change with geored if propagate == True: @@ -1657,7 +1662,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber self.UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True) objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - self.handleWebhook(objectData, 'UPDATE') + self.handleWebhook(objectData, 'PATCH') except: self.logTool.log(service='Database', level='debug', message="Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']), redisClient=self.redisMessaging) objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) @@ -1668,7 +1673,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber #Create if does not exist self.CreateObj(SERVING_APN, json_data, True) objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - self.handleWebhook(objectData, 'CREATE') + self.handleWebhook(objectData, 'PUT') #Sync state change with geored if propagate == True:
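For reference, a database change now lands on a webhook-{uuid}-{timestamp} queue as a JSON message of roughly the following shape, as produced by handleWebhook() above (field values illustrative):

    {"body": {"subscriber_id": 1, "imsi": "001001000000001"}, "headers": {"Content-Type": "application/json", "Referer": "hss01"}, "operation": "POST"}

The webhook worker in georedService is then expected to deliver the body to each URL configured under webhooks.endpoints.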
diff --git a/lib/diameter.py b/lib/diameter.py index 0c425f1..52f462f 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -53,6 +53,7 @@ def __init__(self, redisMessaging, logTool, originHost: str="hss01", originRealm self.diameterRequestList = [ {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, + {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, ] #Generates rounding for calculating padding @@ -395,6 +396,38 @@ def decode_diameter_packet_length(self, data): else: return False + def getPeerType(self, originHost: str) -> str: + try: + peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + + for peer in peerTypes: + if peer in originHost.lower(): + return peer + return '' + + except Exception as e: + return '' + def getConnectedPeersByType(self, peerType: str) -> list: + try: + peerType = peerType.lower() + peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + + if peerType not in peerTypes: + return [] + filteredConnectedPeers = [] + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers") or '{}') + + for key, value in activePeers.items(): + if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': + filteredConnectedPeers.append(activePeers.get(key, {})) + + return filteredConnectedPeers + + except Exception as e: + return [] + def getDiameterMessageType(self, binaryData: str) -> dict: packet_vars, avps = self.decode_diameter_packet(binaryData) response = {}
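For reference, getConnectedPeersByType() above assumes the ActiveDiameterPeers Redis key holds a JSON object keyed per connection, along these lines (field names inferred from the producers and consumers in this patch; values illustrative):

    {"10.0.0.10-3868": {"diameterHostname": "mme01.epc.mnc001.mcc001.3gppnetwork.org", "peerType": "mme", "connectionStatus": "connected", "ipAddress": "10.0.0.10", "port": "3868", "lastDwrTimestamp": "2023-08-30 17:25:26"}}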
+ """ + try: + message = self.redisClient.get(key) + if message is None: + message = '' + else: + return message + except Exception as e: + return '' + def RedisHGetAll(self, key: str): - """ - Wrapper for Redis HGETALL""" - try: - data = self.redisClient.hgetall(key) - return data - except Exception as e: - return '' + """ + Wrapper for Redis HGETALL + *Deprecated: will be removed upon completed database cleanup. + """ + try: + data = self.redisClient.hgetall(key) + return data + except Exception as e: + return '' if __name__ == '__main__': redisMessaging = RedisMessaging() diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 85ab690..bc9c76d 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -88,7 +88,6 @@ async def getMessage(self, queue: str) -> str: pass return message except Exception as e: - print(e) return '' async def getQueues(self, pattern: str='*') -> list: @@ -121,7 +120,34 @@ async def deleteQueue(self, queue: str) -> bool: return True except Exception as e: return False - + + async def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + """ + Stores a value under a given key asynchronously and sets an expiry (in seconds) if provided. + """ + try: + async with self.redisClient.pipeline(transaction=True) as redisPipe: + await redisPipe.set(key, value) + if keyExpiry is not None: + await redisPipe.expire(key, value) + setValueResult, expireValueResult = await redisPipe.execute() + return f'{value} stored in {key} successfully.' + except Exception as e: + return '' + + async def getValue(self, key: str) -> str: + """ + Gets the value stored under a given key asynchronously. + """ + try: + message = await(self.redisClient.get(key)) + if message is None: + message = '' + else: + return message + except Exception as e: + return '' + async def closeConnection(self) -> bool: await self.redisClient.close() return True diff --git a/lib/old.logtool.py b/lib/old.logtool.py deleted file mode 100644 index ac341a2..0000000 --- a/lib/old.logtool.py +++ /dev/null @@ -1,243 +0,0 @@ -import logging -import logging.handlers as handlers -import os -import sys -import inspect -sys.path.append(os.path.realpath('../')) -import yaml -from datetime import datetime as log_dt -with open("config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) - -import json -import pickle - -from prometheus_client import Counter, Gauge, Histogram, Summary - -from prometheus_client import start_http_server - -if yaml_config['prometheus']['enabled'] == True: - #Check if this is the HSS service, and if it's not increment the port before starting - print(sys.argv[0]) - if 'hss.py' in str(sys.argv[0]): - print("Starting Prometheus on port from config " + str(yaml_config['prometheus']['port'])) - else: - print("This is not the HSS stack so offsetting Prometheus port") - yaml_config['prometheus']['port'] += 1 - try: - start_http_server(yaml_config['prometheus']['port']) - print("Started Prometheus on port " + str(yaml_config['prometheus']['port'])) - except Exception as E: - print("Error loading Prometheus") - print(E) - - -tags = ['diameter_application_id', 'diameter_cmd_code', 'endpoint', 'type'] -prom_diam_request_count = Counter('prom_diam_request_count', 'Number of Diameter Requests', tags) -prom_diam_response_count_successful = Counter('prom_diam_response_count_successful', 'Number of Successful Diameter Responses', tags) -prom_diam_response_count_fail = Counter('prom_diam_response_count_fail', 'Number of Failed Diameter Responses', tags) 
-prom_diam_connected_peers = Gauge('prom_diam_connected_peers', 'Connected Diameter Peer Count', ['endpoint']) -prom_diam_connected_peers._metrics.clear() -prom_diam_response_time_diam = Histogram('prom_diam_response_time_diam', 'Diameter Response Times') -prom_diam_response_time_method = Histogram('prom_diam_response_time_method', 'Diameter Response Times', tags) -prom_diam_response_time_db = Summary('prom_diam_response_time_db', 'Diameter Response Times from Database') -prom_diam_response_time_h = Histogram('request_latency_seconds', 'Diameter Response Time Histogram') -prom_diam_auth_event_count = Counter('prom_diam_auth_event_count', 'Diameter Authentication related Counters', ['diameter_application_id', 'diameter_cmd_code', 'event', 'imsi_prefix']) -prom_diam_eir_event_count = Counter('prom_diam_eir_event_count', 'Diameter EIR event related Counters', ['response']) - -prom_eir_devices = Counter('prom_eir_devices', 'Profile of attached devices', ['imei_prefix', 'device_type', 'device_name']) - -prom_http_geored = Counter('prom_http_geored', 'Number of Geored Pushes', ['geored_host', 'endpoint', 'http_response_code', 'error']) -prom_flask_http_geored_endpoints = Counter('prom_flask_http_geored_endpoints', 'Number of Geored Pushes Received', ['geored_host', 'endpoint']) - - -prom_pcrf_subs = Gauge('prom_pcrf_subs', 'Number of attached PCRF Subscribers') -prom_mme_subs = Gauge('prom_mme_subs', 'Number of attached MME Subscribers') -prom_ims_subs = Gauge('prom_ims_subs', 'Number of attached IMS Subscribers') - -class LogTool: - def __init__(self, **kwargs): - print("Instantiating LogTool with Kwargs " + str(kwargs.items())) - if yaml_config['redis']['enabled'] == True: - print("Redis support enabled") - import redis - redis_store = redis.Redis(host=str(yaml_config['redis']['host']), port=str(yaml_config['redis']['port']), db=0) - self.redis_store = redis_store - try: - if "HSS_Init" in kwargs: - print("Called Init for HSS_Init") - redis_store.incr('restart_count') - if yaml_config['redis']['clear_stats_on_boot'] == True: - logging.debug("Clearing ActivePeerDict") - redis_store.delete('ActivePeerDict') - else: - logging.debug("Leaving prexisting Redis keys") - #Clear ActivePeerDict - redis_store.delete('ActivePeerDict') - - #Clear Async Keys - for key in redis_store.scan_iter("*_request_queue"): - print("Deleting Key: " + str(key)) - redis_store.delete(key) - logging.info("Connected to Redis server") - else: - logging.info("Init of Logtool but not from HSS_Init") - except: - logging.error("Failed to connect to Redis server - Disabling") - yaml_config['redis']['enabled'] == False - - #function for handling incrimenting Redis counters with error handling - def RedisIncrimenter(self, name): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.incr(name) - except: - logging.error("failed to incriment " + str(name)) - - def RedisStore(self, key, value): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.set(key, value) - except: - logging.error("failed to set Redis key " + str(key) + " to value " + str(value)) - - def RedisGet(self, key): - if yaml_config['redis']['enabled'] == True: - try: - return self.redis_store.get(key) - except: - logging.error("failed to set Redis key " + str(key)) - - def RedisHMSET(self, key, value_dict): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.hmset(key, value_dict) - except: - logging.error("failed to set hm Redis key " + str(key) + " to value " + str(value_dict)) - - def Async_SendRequest(self, 
request, DiameterHostname): - if yaml_config['redis']['enabled'] == True: - try: - import time - print("Writing request to Queue '" + str(DiameterHostname) + "_request_queue'") - self.redis_store.hset(str(DiameterHostname) + "_request_queue", "hss_Async_client_" + str(int(time.time())), request) - print("Written to Queue to send.") - except Exception as E: - logging.error("failed to run Async_SendRequest to " + str(DiameterHostname)) - - def RedisHMGET(self, key): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Getting HM Get from " + str(key)) - data = self.redis_store.hgetall(key) - logging.debug("Result: " + str(data)) - return data - except: - logging.error("failed to get hm Redis key " + str(key)) - - def RedisHDEL(self, key, item): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Removing item " + str(item) + " from key " + str(key)) - self.redis_store.hdel(key, item) - except: - logging.error("failed to hdel Redis key " + str(key) + " item " + str(item)) - - def RedisStoreDict(self, key, value): - if yaml_config['redis']['enabled'] == True: - try: - self.redis_store.set(str(key), pickle.dumps(value)) - except: - logging.error("failed to set Redis dict " + str(key) + " to value " + str(value)) - - def RedisGetDict(self, key): - if yaml_config['redis']['enabled'] == True: - try: - read_dict = self.redis_store.get(key) - return pickle.loads(read_dict) - except: - logging.error("failed to hmget Redis key " + str(key)) - - def GetDiameterPeers(self): - if yaml_config['redis']['enabled'] == True: - try: - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - return ActivePeerDict - except: - logging.error("Failed to get ActivePeerDict") - - - def Manage_Diameter_Peer(self, peername, ip, action): - if yaml_config['redis']['enabled'] == True: - try: - logging.debug("Managing Diameter peer to Redis with hostname" + str(peername) + " and IP " + str(ip)) - now = log_dt.now() - timestamp = str(now.strftime("%Y-%m-%d %H:%M:%S")) - - #Try and get IP and Port seperately - try: - ip = ip[0] - port = ip[1] - except: - pass - - if self.redis_store.exists('ActivePeerDict') == False: - #Initialise empty active peer dict in Redis - logging.debug("Populated new empty ActivePeerDict Redis key") - ActivePeerDict = {} - ActivePeerDict['internal_connection'] = {"connect_timestamp" : timestamp} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "add": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict) + " to add peer " + str(peername) + " with ip " + str(ip)) - - - #If key has already existed in dict due to disconnect / reconnect, get reconnection count - try: - reconnection_count = ActivePeerDict[str(ip)]['reconnection_count'] + 1 - except: - reconnection_count = 0 - - ActivePeerDict[str(ip)] = {"connect_timestamp" : timestamp, \ - "recv_ip_address" : str(ip), "DiameterHostname" : "Unknown - Socket connection only", \ - "reconnection_count" : reconnection_count, - "connection_status" : "Pending"} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "remove": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - logging.debug("ActivePeerDict back from Redis" + str(ActivePeerDict)) - ActivePeerDict[str(ip)] = {"disconnect_timestamp" : str(timestamp), \ - "DiameterHostname" : str(ActivePeerDict[str(ip)]['DiameterHostname']), \ - "reconnection_count" : 
ActivePeerDict[str(ip)]['reconnection_count'], - "connection_status" : "Disconnected"} - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - - if action == "update": - data = self.RedisGet('ActivePeerDict') - ActivePeerDict = json.loads(data) - ActivePeerDict[str(ip)]['DiameterHostname'] = str(peername) - ActivePeerDict[str(ip)]['last_dwr_timestamp'] = str(timestamp) - ActivePeerDict[str(ip)]['connection_status'] = "Connected" - self.RedisStore('ActivePeerDict', json.dumps(ActivePeerDict)) - except Exception as E: - logging.error("failed to add/update/remove Diameter peer from Redis") - logging.error(E) - - - def setup_logger(self, logger_name, log_file, level=logging.DEBUG): - l = logging.getLogger(logger_name) - formatter = logging.Formatter('%(asctime)s \t %(levelname)s \t {%(pathname)s:%(lineno)d} \t %(message)s') - fileHandler = logging.FileHandler(log_file, mode='a+') - fileHandler.setFormatter(formatter) - streamHandler = logging.StreamHandler() - streamHandler.setFormatter(formatter) - rolloverHandler = handlers.RotatingFileHandler(log_file, maxBytes=50000000, backupCount=5) - l.setLevel(level) - l.addHandler(fileHandler) - l.addHandler(streamHandler) - l.addHandler(rolloverHandler) diff --git a/PyHSS_API.py b/services/apiService.py similarity index 75% rename from PyHSS_API.py rename to services/apiService.py index 12e7e23..00f9892 100644 --- a/PyHSS_API.py +++ b/services/apiService.py @@ -4,22 +4,51 @@ from flask_restx import Api, Resource, fields, reqparse, abort from werkzeug.middleware.proxy_fix import ProxyFix from functools import wraps -sys.path.append(os.path.realpath('lib')) -import datetime +import os +sys.path.append(os.path.realpath('../lib')) +import time +import requests import traceback import sqlalchemy import socket -import logtool +from logtool import LogTool from diameter import Diameter +from messaging import RedisMessaging import database -import logging import yaml -import os -with open("config.yaml", 'r') as stream: - yaml_config = (yaml.safe_load(stream)) +with open("../config.yaml", 'r') as stream: + config = (yaml.safe_load(stream)) + +siteName = config.get("hss", {}).get("site_name", "") +originHostname = socket.gethostname() +lockProvisioning = config.get('hss', {}).get('lock_provisioning', False) +provisioningKey = config.get('hss', {}).get('provisioning_key', '') +mnc = config.get('hss', {}).get('MNC', '999') +mcc = config.get('hss', {}).get('MCC', '999') +originRealm = config.get('hss', {}).get('OriginRealm', f'mnc{mnc}.mcc{mcc}.3gppnetwork.org') +originHost = config.get('hss', {}).get('OriginHost', f'hss01') +productName = config.get('hss', {}).get('ProductName', f'PyHSS') + +redisHost = config.get("redis", {}).get("host", "127.0.0.1") +redisPort = int(config.get("redis", {}).get("port", 6379)) +redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + +logTool = LogTool(config) -app = Flask(__name__) +diameterClient = Diameter( + redisMessaging=redisMessaging, + logTool=logTool, + originHost=originHost, + originRealm=originRealm, + mnc=mnc, + mcc=mcc, + productName='PyHSS-client-API' + ) + +databaseClient = database.Database(logTool=logTool, redisMessaging=redisMessaging) + +apiService = Flask(__name__) APN = database.APN Serving_APN = database.SERVING_APN @@ -34,20 +63,8 @@ OPERATION_LOG = database.OPERATION_LOG_BASE SUBSCRIBER_ROUTING = database.SUBSCRIBER_ROUTING - -site_name = yaml_config.get("hss", {}).get("site_name", "") -origin_host_name = socket.gethostname() - -diameterClient = Diameter( - 
OriginHost=yaml_config['hss']['OriginHost'], - OriginRealm=yaml_config['hss']['OriginRealm'], - MNC=yaml_config['hss']['MNC'], - MCC=yaml_config['hss']['MCC'], - ProductName='PyHSS-client-API' - ) - -app.wsgi_app = ProxyFix(app.wsgi_app) -api = Api(app, version='1.0', title=f'{site_name + " - " if site_name else ""}{origin_host_name} - PyHSS OAM API', +apiService.wsgi_app = ProxyFix(apiService.wsgi_app) +api = Api(apiService, version='1.0', title=f'{siteName + " - " if siteName else ""}{originHostname} - PyHSS OAM API', description='Restful API for working with PyHSS', doc='/docs/' ) @@ -73,41 +90,40 @@ paginatorParser = reqparse.RequestParser() paginatorParser.add_argument('page', type=int, required=False, default=0, help='Page number for pagination') -paginatorParser.add_argument('page_size', type=int, required=False, default=yaml_config['api'].get('page_size', 100), help='Number of items per page for pagination') - +paginatorParser.add_argument('page_size', type=int, required=False, default=config['api'].get('page_size', 100), help='Number of items per page for pagination') APN_model = api.schema_model('APN JSON', - database.Generate_JSON_Model_for_Flask(APN) + databaseClient.Generate_JSON_Model_for_Flask(APN) ) Serving_APN_model = api.schema_model('Serving APN JSON', - database.Generate_JSON_Model_for_Flask(Serving_APN) + databaseClient.Generate_JSON_Model_for_Flask(Serving_APN) ) AUC_model = api.schema_model('AUC JSON', - database.Generate_JSON_Model_for_Flask(AUC) + databaseClient.Generate_JSON_Model_for_Flask(AUC) ) SUBSCRIBER_model = api.schema_model('SUBSCRIBER JSON', - database.Generate_JSON_Model_for_Flask(SUBSCRIBER) + databaseClient.Generate_JSON_Model_for_Flask(SUBSCRIBER) ) SUBSCRIBER_ROUTING_model = api.schema_model('SUBSCRIBER_ROUTING JSON', - database.Generate_JSON_Model_for_Flask(SUBSCRIBER_ROUTING) + databaseClient.Generate_JSON_Model_for_Flask(SUBSCRIBER_ROUTING) ) IMS_SUBSCRIBER_model = api.schema_model('IMS_SUBSCRIBER JSON', - database.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER) + databaseClient.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER) ) TFT_model = api.schema_model('TFT JSON', - database.Generate_JSON_Model_for_Flask(TFT) + databaseClient.Generate_JSON_Model_for_Flask(TFT) ) CHARGING_RULE_model = api.schema_model('CHARGING_RULE JSON', - database.Generate_JSON_Model_for_Flask(CHARGING_RULE) + databaseClient.Generate_JSON_Model_for_Flask(CHARGING_RULE) ) EIR_model = api.schema_model('EIR JSON', - database.Generate_JSON_Model_for_Flask(EIR) + databaseClient.Generate_JSON_Model_for_Flask(EIR) ) IMSI_IMEI_HISTORY_model = api.schema_model('IMSI_IMEI_HISTORY JSON', - database.Generate_JSON_Model_for_Flask(IMSI_IMEI_HISTORY) + databaseClient.Generate_JSON_Model_for_Flask(IMSI_IMEI_HISTORY) ) SUBSCRIBER_ATTRIBUTES_model = api.schema_model('SUBSCRIBER_ATTRIBUTES JSON', - database.Generate_JSON_Model_for_Flask(SUBSCRIBER_ATTRIBUTES) + databaseClient.Generate_JSON_Model_for_Flask(SUBSCRIBER_ATTRIBUTES) ) PCRF_Push_model = api.model('PCRF_Rule', { @@ -160,9 +176,6 @@ } -lock_provisioning = yaml_config.get('hss', {}).get('lock_provisioning', False) -provisioning_key = yaml_config.get('hss', {}).get('provisioning_key', '') - def no_auth_required(f): f.no_auth_required = True return f @@ -170,9 +183,9 @@ def no_auth_required(f): def auth_required(f): @wraps(f) def decorated_function(*args, **kwargs): - if getattr(f, 'no_auth_required', False) or (lock_provisioning == False): + if getattr(f, 'no_auth_required', False) or (lockProvisioning == False): return f(*args, **kwargs) 
- if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != yaml_config['hss']['provisioning_key']: + if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != config['hss']['provisioning_key']: return {'Result': 'Unauthorized - Provisioning-Key Invalid'}, 401 return f(*args, **kwargs) return decorated_function @@ -181,12 +194,12 @@ def auth_before_request(): if request.path.startswith('/docs') or request.path.startswith('/swagger') or request.path.startswith('/metrics'): return None if request.endpoint and 'static' not in request.endpoint: - view_function = app.view_functions[request.endpoint] + view_function = apiService.view_functions[request.endpoint] if hasattr(view_function, 'view_class'): view_class = view_function.view_class view_method = getattr(view_class, request.method.lower(), None) if view_method: - if(lock_provisioning == False): + if(lockProvisioning == False): return None if request.method == 'GET' and not getattr(view_method, 'auth_required', False): return None @@ -195,12 +208,13 @@ def auth_before_request(): else: return None - if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != yaml_config['hss']['provisioning_key']: + if 'Provisioning-Key' not in request.headers or request.headers['Provisioning-Key'] != config['hss']['provisioning_key']: return {'Result': 'Unauthorized - Provisioning-Key Invalid'}, 401 return None def handle_exception(e): - logging.error(f"An error occurred: {e}") + + logTool.log(service='API', level='error', message=f"[API] An error occurred: {e}", redisClient=redisMessaging) response_json = {'result': 'Failed'} if isinstance(e, sqlalchemy.exc.SQLAlchemyError): @@ -219,19 +233,18 @@ def handle_exception(e): return response_json, 410 else: response_json['reason'] = f'An internal server error occurred: {e}' - logging.error(f'{traceback.format_exc()}') - logging.error(f'{sys.exc_info()[2]}') + logTool.log(service='API', level='error', message=f"[API] Additional Error Information: {traceback.format_exc()}\n{sys.exc_info()[2]}", redisClient=redisMessaging) return response_json, 500 -app.before_request(auth_before_request) +apiService.before_request(auth_before_request) -@app.errorhandler(404) +@apiService.errorhandler(404) def page_not_found(e): return {"Result": "Not Found"}, 404 -@app.after_request +@apiService.after_request def apply_caching(response): - response.headers["HSS"] = str(yaml_config['hss']['OriginHost']) + response.headers["HSS"] = str(config['hss']['OriginHost']) return response @ns_apn.route('/') @@ -239,7 +252,7 @@ class PyHSS_APN_Get(Resource): def get(self, apn_id): '''Get all APN data for specified APN ID''' try: - apn_data = database.GetObj(APN, apn_id) + apn_data = databaseClient.GetObj(APN, apn_id) return apn_data, 200 except Exception as E: print(E) @@ -250,7 +263,7 @@ def delete(self, apn_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(APN, apn_id, False, operation_id) + data = databaseClient.DeleteObj(APN, apn_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -265,7 +278,7 @@ def patch(self, apn_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - apn_data = database.UpdateObj(APN, json_data, apn_id, False, operation_id) + apn_data = databaseClient.UpdateObj(APN, json_data, apn_id, False, operation_id) print("Updated object") print(apn_data) @@ -285,7 +298,7 @@ def put(self): 
print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - apn_id = database.CreateObj(APN, json_data, False, operation_id) + apn_id = databaseClient.CreateObj(APN, json_data, False, operation_id) return apn_id, 200 except Exception as E: @@ -299,7 +312,7 @@ def get(self): '''Get all APNs''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(APN, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(APN, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -310,8 +323,8 @@ class PyHSS_AUC_Get(Resource): def get(self, auc_id): '''Get all AuC data for specified AuC ID''' try: - auc_data = database.GetObj(AUC, auc_id) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.GetObj(AUC, auc_id) + auc_data = databaseClient.Sanitize_Keys(auc_data) return auc_data, 200 except Exception as E: print(E) @@ -322,7 +335,7 @@ def delete(self, auc_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(AUC, auc_id, False, operation_id) + data = databaseClient.DeleteObj(AUC, auc_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -337,8 +350,8 @@ def patch(self, auc_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - auc_data = database.UpdateObj(AUC, json_data, auc_id, False, operation_id) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.UpdateObj(AUC, json_data, auc_id, False, operation_id) + auc_data = databaseClient.Sanitize_Keys(auc_data) print("Updated object") print(auc_data) @@ -352,8 +365,8 @@ class PyHSS_AUC_Get_ICCID(Resource): def get(self, iccid): '''Get all AuC data for specified ICCID''' try: - auc_data = database.Get_AuC(iccid=iccid) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.Get_AuC(iccid=iccid) + auc_data = databaseClient.Sanitize_Keys(auc_data) return auc_data, 200 except Exception as E: print(E) @@ -364,8 +377,8 @@ class PyHSS_AUC_Get_IMSI(Resource): def get(self, imsi): '''Get all AuC data for specified IMSI''' try: - auc_data = database.Get_AuC(imsi=imsi) - auc_data = database.Sanitize_Keys(auc_data) + auc_data = databaseClient.Get_AuC(imsi=imsi) + auc_data = databaseClient.Sanitize_Keys(auc_data) return auc_data, 200 except Exception as E: print(E) @@ -382,7 +395,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(AUC, json_data, False, operation_id) + data = databaseClient.CreateObj(AUC, json_data, False, operation_id) return data, 200 except Exception as E: @@ -396,7 +409,7 @@ def get(self): '''Get all AuC Data (except keys)''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(AUC, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(AUC, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -407,7 +420,7 @@ class PyHSS_SUBSCRIBER_Get(Resource): def get(self, subscriber_id): '''Get all SUBSCRIBER data for specified subscriber_id''' try: - apn_data = database.GetObj(SUBSCRIBER, subscriber_id) + apn_data = databaseClient.GetObj(SUBSCRIBER, subscriber_id) return apn_data, 200 except Exception as E: print(E) @@ -418,7 +431,7 @@ def delete(self, subscriber_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data 
= database.DeleteObj(SUBSCRIBER, subscriber_id, False, operation_id) + data = databaseClient.DeleteObj(SUBSCRIBER, subscriber_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -433,7 +446,7 @@ def patch(self, subscriber_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(SUBSCRIBER, json_data, subscriber_id, False, operation_id) + data = databaseClient.UpdateObj(SUBSCRIBER, json_data, subscriber_id, False, operation_id) print("Updated object") print(data) @@ -453,7 +466,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(SUBSCRIBER, json_data, False, operation_id) + data = databaseClient.CreateObj(SUBSCRIBER, json_data, False, operation_id) return data, 200 except Exception as E: @@ -465,7 +478,7 @@ class PyHSS_SUBSCRIBER_IMSI(Resource): def get(self, imsi): '''Get data for IMSI''' try: - data = database.Get_Subscriber(imsi=imsi, get_attributes=True) + data = databaseClient.Get_Subscriber(imsi=imsi, get_attributes=True) return data, 200 except Exception as E: print(E) @@ -476,7 +489,7 @@ class PyHSS_SUBSCRIBER_MSISDN(Resource): def get(self, msisdn): '''Get data for MSISDN''' try: - data = database.Get_Subscriber(msisdn=msisdn, get_attributes=True) + data = databaseClient.Get_Subscriber(msisdn=msisdn, get_attributes=True) return data, 200 except Exception as E: print(E) @@ -489,7 +502,7 @@ def get(self): '''Get all Subscribers''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(SUBSCRIBER, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(SUBSCRIBER, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -506,7 +519,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(SUBSCRIBER_ROUTING, json_data, False, operation_id) + data = databaseClient.CreateObj(SUBSCRIBER_ROUTING, json_data, False, operation_id) return data, 200 except Exception as E: @@ -518,7 +531,7 @@ class PyHSS_SUBSCRIBER_SUBSCRIBER_ROUTING(Resource): def get(self, subscriber_id, apn_id): '''Get Subscriber Routing for specified subscriber_id & apn_id''' try: - apn_data = database.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) + apn_data = databaseClient.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) return apn_data, 200 except Exception as E: print(E) @@ -529,8 +542,8 @@ def delete(self, subscriber_id, apn_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - apn_data = database.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) - data = database.DeleteObj(SUBSCRIBER_ROUTING, apn_data['subscriber_routing_id'], False, operation_id) + apn_data = databaseClient.Get_SUBSCRIBER_ROUTING(subscriber_id, apn_id) + data = databaseClient.DeleteObj(SUBSCRIBER_ROUTING, apn_data['subscriber_routing_id'], False, operation_id) return data, 200 except Exception as E: print(E) @@ -547,7 +560,7 @@ def patch(self, subscriber_routing_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(SUBSCRIBER_ROUTING, json_data, subscriber_routing_id, False, operation_id) + data = databaseClient.UpdateObj(SUBSCRIBER_ROUTING, json_data, subscriber_routing_id, False, operation_id) print("Updated object") print(data) @@ 
-561,7 +574,7 @@ class PyHSS_IMS_SUBSCRIBER_Get(Resource): def get(self, ims_subscriber_id): '''Get all SUBSCRIBER data for specified ims_subscriber_id''' try: - apn_data = database.GetObj(IMS_SUBSCRIBER, ims_subscriber_id) + apn_data = databaseClient.GetObj(IMS_SUBSCRIBER, ims_subscriber_id) return apn_data, 200 except Exception as E: print(E) @@ -572,7 +585,7 @@ def delete(self, ims_subscriber_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id, False, operation_id) + data = databaseClient.DeleteObj(IMS_SUBSCRIBER, ims_subscriber_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -587,7 +600,7 @@ def patch(self, ims_subscriber_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(IMS_SUBSCRIBER, json_data, ims_subscriber_id, False, operation_id) + data = databaseClient.UpdateObj(IMS_SUBSCRIBER, json_data, ims_subscriber_id, False, operation_id) print("Updated object") print(data) @@ -607,7 +620,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(IMS_SUBSCRIBER, json_data, False, operation_id) + data = databaseClient.CreateObj(IMS_SUBSCRIBER, json_data, False, operation_id) return data, 200 except Exception as E: @@ -619,7 +632,7 @@ class PyHSS_IMS_SUBSCRIBER_MSISDN(Resource): def get(self, msisdn): '''Get IMS data for MSISDN''' try: - data = database.Get_IMS_Subscriber(msisdn=msisdn) + data = databaseClient.Get_IMS_Subscriber(msisdn=msisdn) print("Got back: " + str(data)) return data, 200 except Exception as E: @@ -631,7 +644,7 @@ class PyHSS_IMS_SUBSCRIBER_IMSI(Resource): def get(self, imsi): '''Get IMS data for imsi''' try: - data = database.Get_IMS_Subscriber(imsi=imsi) + data = databaseClient.Get_IMS_Subscriber(imsi=imsi) print("Got back: " + str(data)) return data, 200 except Exception as E: @@ -645,7 +658,7 @@ def get(self): '''Get all IMS Subscribers''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(IMS_SUBSCRIBER, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(IMS_SUBSCRIBER, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -656,7 +669,7 @@ class PyHSS_TFT_Get(Resource): def get(self, tft_id): '''Get all TFT data for specified tft_id''' try: - apn_data = database.GetObj(TFT, tft_id) + apn_data = databaseClient.GetObj(TFT, tft_id) return apn_data, 200 except Exception as E: print(E) @@ -667,7 +680,7 @@ def delete(self, tft_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(TFT, tft_id, False, operation_id) + data = databaseClient.DeleteObj(TFT, tft_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -682,7 +695,7 @@ def patch(self, tft_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(TFT, json_data, tft_id, False, operation_id) + data = databaseClient.UpdateObj(TFT, json_data, tft_id, False, operation_id) print("Updated object") print(data) @@ -702,7 +715,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(TFT, json_data, False, operation_id) + data = 
databaseClient.CreateObj(TFT, json_data, False, operation_id) return data, 200 except Exception as E: @@ -716,7 +729,7 @@ def get(self): '''Get all TFTs''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(TFT, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(TFT, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -727,7 +740,7 @@ class PyHSS_Charging_Rule_Get(Resource): def get(self, charging_rule_id): '''Get all Charging Rule data for specified charging_rule_id''' try: - apn_data = database.GetObj(CHARGING_RULE, charging_rule_id) + apn_data = databaseClient.GetObj(CHARGING_RULE, charging_rule_id) return apn_data, 200 except Exception as E: print(E) @@ -738,7 +751,7 @@ def delete(self, charging_rule_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(CHARGING_RULE, charging_rule_id, False, operation_id) + data = databaseClient.DeleteObj(CHARGING_RULE, charging_rule_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -753,7 +766,7 @@ def patch(self, charging_rule_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(CHARGING_RULE, json_data, charging_rule_id, False, operation_id) + data = databaseClient.UpdateObj(CHARGING_RULE, json_data, charging_rule_id, False, operation_id) print("Updated object") print(data) @@ -773,7 +786,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(CHARGING_RULE, json_data, False, operation_id) + data = databaseClient.CreateObj(CHARGING_RULE, json_data, False, operation_id) return data, 200 except Exception as E: @@ -787,7 +800,7 @@ def get(self): '''Get all Charging Rules''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(CHARGING_RULE, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(CHARGING_RULE, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -798,7 +811,7 @@ class PyHSS_EIR_Get(Resource): def get(self, eir_id): '''Get all EIR data for specified eir_id''' try: - eir_data = database.GetObj(EIR, eir_id) + eir_data = databaseClient.GetObj(EIR, eir_id) return eir_data, 200 except Exception as E: print(E) @@ -809,7 +822,7 @@ def delete(self, eir_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(EIR, eir_id, False, operation_id) + data = databaseClient.DeleteObj(EIR, eir_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -824,7 +837,7 @@ def patch(self, eir_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(EIR, json_data, eir_id, False, operation_id) + data = databaseClient.UpdateObj(EIR, json_data, eir_id, False, operation_id) print("Updated object") print(data) @@ -844,7 +857,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.CreateObj(EIR, json_data, False, operation_id) + data = databaseClient.CreateObj(EIR, json_data, False, operation_id) return data, 200 except Exception as E: @@ -856,11 +869,11 @@ class PyHSS_EIR_HISTORY(Resource): def get(self, attribute): '''Get history for IMSI or IMEI''' try: - data = 
database.Get_IMEI_IMSI_History(attribute=attribute) + data = databaseClient.Get_IMEI_IMSI_History(attribute=attribute) #Add device info for each entry data_w_device_info = [] for record in data: - record['imei_result'] = database.get_device_info_from_TAC(imei=str(record['imei'])) + record['imei_result'] = databaseClient.get_device_info_from_TAC(imei=str(record['imei'])) data_w_device_info.append(record) return data_w_device_info, 200 except Exception as E: @@ -870,9 +883,9 @@ def get(self, attribute): def delete(self, attribute): '''Get Delete for IMSI or IMEI''' try: - data = database.Get_IMEI_IMSI_History(attribute=attribute) + data = databaseClient.Get_IMEI_IMSI_History(attribute=attribute) for record in data: - database.DeleteObj(IMSI_IMEI_HISTORY, record['imsi_imei_history_id']) + databaseClient.DeleteObj(IMSI_IMEI_HISTORY, record['imsi_imei_history_id']) return data, 200 except Exception as E: print(E) @@ -885,7 +898,7 @@ def get(self): '''Get EIR history for all subscribers''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(IMSI_IMEI_HISTORY, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(IMSI_IMEI_HISTORY, args['page'], args['page_size']) for record in data: record['imsi'] = record['imsi_imei'].split(',')[0] record['imei'] = record['imsi_imei'].split(',')[1] @@ -901,7 +914,7 @@ def get(self): '''Get all EIR Rules''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(EIR, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(EIR, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -912,7 +925,7 @@ class PyHSS_EIR_TAC(Resource): def get(self, imei): '''Get Device Info from IMEI''' try: - data = database.get_device_info_from_TAC(imei=imei) + data = databaseClient.get_device_info_from_TAC(imei=imei) return (data), 200 except Exception as E: print(E) @@ -926,7 +939,7 @@ def get(self): '''Get all Subscriber Attributes''' try: args = paginatorParser.parse_args() - data = database.getAllPaginated(SUBSCRIBER_ATTRIBUTES, args['page'], args['page_size']) + data = databaseClient.getAllPaginated(SUBSCRIBER_ATTRIBUTES, args['page'], args['page_size']) return (data), 200 except Exception as E: print(E) @@ -937,7 +950,7 @@ class PyHSS_Attributes_Get(Resource): def get(self, subscriber_id): '''Get all attributes / values for specified Subscriber ID''' try: - apn_data = database.Get_Subscriber_Attributes(subscriber_id) + apn_data = databaseClient.Get_Subscriber_Attributes(subscriber_id) return apn_data, 200 except Exception as E: print(E) @@ -950,7 +963,7 @@ def delete(self, subscriber_attributes_id): try: args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.DeleteObj(SUBSCRIBER_ATTRIBUTES, subscriber_attributes_id, False, operation_id) + data = databaseClient.DeleteObj(SUBSCRIBER_ATTRIBUTES, subscriber_attributes_id, False, operation_id) return data, 200 except Exception as E: print(E) @@ -965,7 +978,7 @@ def patch(self, subscriber_attributes_id): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id = args.get('operation_id', None) - data = database.UpdateObj(SUBSCRIBER_ATTRIBUTES, json_data, subscriber_attributes_id, False, operation_id) + data = databaseClient.UpdateObj(SUBSCRIBER_ATTRIBUTES, json_data, subscriber_attributes_id, False, operation_id) print("Updated object") print(data) @@ -985,7 +998,7 @@ def put(self): print("JSON Data sent: " + str(json_data)) args = parser.parse_args() operation_id 
= args.get('operation_id', None) - data = database.CreateObj(SUBSCRIBER_ATTRIBUTES, json_data, False, operation_id) + data = databaseClient.CreateObj(SUBSCRIBER_ATTRIBUTES, json_data, False, operation_id) return data, 200 except Exception as E: @@ -999,7 +1012,7 @@ def get(self): '''Get all Operation Logs''' try: args = paginatorParser.parse_args() - OperationLogs = database.get_all_operation_logs(args['page'], args['page_size']) + OperationLogs = databaseClient.get_all_operation_logs(args['page'], args['page_size']) return OperationLogs, 200 except Exception as E: print(E) @@ -1010,7 +1023,7 @@ class PyHSS_Operation_Log_Last(Resource): def get(self): '''Get the most recent Operation Log''' try: - OperationLogs = database.get_last_operation_log() + OperationLogs = databaseClient.get_last_operation_log() return OperationLogs, 200 except Exception as E: print(E) @@ -1023,7 +1036,7 @@ def get(self, table_name): '''Get all Operation Logs for a given table''' try: args = paginatorParser.parse_args() - OperationLogs = database.get_all_operation_logs_by_table(table_name, args['page'], args['page_size']) + OperationLogs = databaseClient.get_all_operation_logs_by_table(table_name, args['page'], args['page_size']) return OperationLogs, 200 except Exception as E: print(E) @@ -1034,11 +1047,8 @@ class PyHSS_OAM_Peers(Resource): def get(self): '''Get all Diameter Peers''' try: - #@@Fixme - # logObj = logtool.LogTool() - # DiameterPeers = logObj.GetDiameterPeers() - # return DiameterPeers, 200 - return '' + diameterPeers = redisMessaging.getValue("ActiveDiameterPeers") + return diameterPeers, 200 except Exception as E: print(E) return handle_exception(E) @@ -1060,7 +1070,7 @@ class PyHSS_OAM_Rollback_Last(Resource): def get(self): '''Undo the last Insert/Update/Delete operation''' try: - RollbackResponse = database.rollback_last_change() + RollbackResponse = databaseClient.rollback_last_change() return RollbackResponse, 200 except Exception as E: print(E) @@ -1072,7 +1082,7 @@ class PyHSS_OAM_Rollback_Last_Table(Resource): def get(self, operation_id): '''Undo the last Insert/Update/Delete operation for a given operation id''' try: - RollbackResponse = database.rollback_change_by_operation_id(operation_id) + RollbackResponse = databaseClient.rollback_change_by_operation_id(operation_id) return RollbackResponse, 200 except Exception as E: print(E) @@ -1083,7 +1093,7 @@ class PyHSS_OAM_Serving_Subs(Resource): def get(self): '''Get all Subscribers served by HSS''' try: - data = database.Get_Served_Subscribers() + data = databaseClient.Get_Served_Subscribers() print("Got back served Subs: " + str(data)) return data, 200 except Exception as E: @@ -1095,7 +1105,7 @@ class PyHSS_OAM_Serving_Subs_PCRF(Resource): def get(self): '''Get all Subscribers served by PCRF''' try: - data = database.Get_Served_PCRF_Subscribers() + data = databaseClient.Get_Served_PCRF_Subscribers() print("Got back served Subs: " + str(data)) return data, 200 except Exception as E: @@ -1107,7 +1117,7 @@ class PyHSS_OAM_Serving_Subs_IMS(Resource): def get(self): '''Get all Subscribers served by IMS''' try: - data = database.Get_Served_IMS_Subscribers() + data = databaseClient.Get_Served_IMS_Subscribers() print("Got back served Subs: " + str(data)) return data, 200 except Exception as E: @@ -1119,16 +1129,15 @@ class PyHSS_OAM_Reconcile_IMS(Resource): def get(self, imsi): '''Get current location of IMS Subscriber from all linked HSS nodes''' response_dict = {} - import requests try: #Get local database result - local_result = 
database.Get_IMS_Subscriber(imsi=imsi) + local_result = databaseClient.Get_IMS_Subscriber(imsi=imsi) response_dict['localhost'] = {} for keys in local_result: if 'cscf' in keys: response_dict['localhost'][keys] = local_result[keys] - for remote_HSS in yaml_config['geored']['sync_endpoints']: + for remote_HSS in config.get('geored', {}).get('endpoints', []): print("Pulling data from remote HSS: " + str(remote_HSS)) try: response = requests.get(remote_HSS + '/ims_subscriber/ims_subscriber_imsi/' + str(imsi)) @@ -1171,9 +1180,9 @@ def get(self, imsi): serving_sub_final['apns'] = {} #Resolve Subscriber ID - subscriber_data = database.Get_Subscriber(imsi=str(imsi)) + subscriber_data = databaseClient.Get_Subscriber(imsi=str(imsi)) print("subscriber_data: " + str(subscriber_data)) - serving_sub_final['subscriber_data'] = database.Sanitize_Datetime(subscriber_data) + serving_sub_final['subscriber_data'] = databaseClient.Sanitize_Datetime(subscriber_data) #Split the APN list into a list apn_list = subscriber_data['apn_list'].split(',') @@ -1191,11 +1200,11 @@ def get(self, imsi): #Get APN ID from APN for list_apn_id in apn_list: print("Getting APN ID " + str(list_apn_id)) - apn_data = database.Get_APN(list_apn_id) + apn_data = databaseClient.Get_APN(list_apn_id) print(apn_data) try: serving_sub_final['apns'][str(apn_data['apn'])] = {} - serving_sub_final['apns'][str(apn_data['apn'])] = database.Sanitize_Datetime(database.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=list_apn_id)) + serving_sub_final['apns'][str(apn_data['apn'])] = databaseClient.Sanitize_Datetime(databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=list_apn_id)) except: serving_sub_final['apns'][str(apn_data['apn'])] = {} print("Failed to get Serving APN for APN ID " + str(list_apn_id)) @@ -1217,7 +1226,7 @@ def get(self, imsi, apn_id): apn_id_final = None #Resolve Subscriber ID - subscriber_data = database.Get_Subscriber(imsi=str(imsi)) + subscriber_data = databaseClient.Get_Subscriber(imsi=str(imsi)) print("subscriber_data: " + str(subscriber_data)) #Split the APN list into a list @@ -1236,14 +1245,14 @@ def get(self, imsi, apn_id): for list_apn_id in apn_list: print("Getting APN ID " + str(list_apn_id) + " to see if it matches APN " + str(apn_id)) #Get each APN in List - apn_data = database.Get_APN(list_apn_id) + apn_data = databaseClient.Get_APN(list_apn_id) print(apn_data) if str(apn_data['apn_id']).lower() == str(apn_id).lower(): print("Matched named APN with APN ID") apn_id_final = apn_data['apn_id'] - data = database.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=apn_id_final) - data = database.Sanitize_Datetime(data) + data = databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=apn_id_final) + data = databaseClient.Sanitize_Datetime(data) print("Got back: " + str(data)) return data, 200 except Exception as E: @@ -1261,36 +1270,34 @@ def put(self): json_data = request.get_json(force=True) print("JSON Data sent: " + str(json_data)) #Get IMSI - subscriber_data = database.Get_Subscriber(imsi=str(json_data['imsi'])) + subscriber_data = databaseClient.Get_Subscriber(imsi=str(json_data['imsi'])) print("subscriber_data: " + str(subscriber_data)) #Get PCRF Session - pcrf_session_data = database.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=json_data['apn_id']) + pcrf_session_data = databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=json_data['apn_id'])
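+ #The PCRF session row fetched above supplies the pcrf_session_id, subscriber_routing and serving_pgw used to build the RAR below; + #the finished request hex is then fanned out to each connected PGW via a per-peer diameter-outbound-<ip>-<port>-<timestamp> Redis queue.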
print("pcrf_session_data: " + str(pcrf_session_data)) #Get Charging Rules - ChargingRule = database.Get_Charging_Rule(json_data['charging_rule_id']) - ChargingRule['apn_data'] = database.Get_APN(json_data['apn_id']) + ChargingRule = databaseClient.Get_Charging_Rule(json_data['charging_rule_id']) + ChargingRule['apn_data'] = databaseClient.Get_APN(json_data['apn_id']) print("Got ChargingRule: " + str(ChargingRule)) - diameter_host = yaml_config['hss']['OriginHost'] #Diameter Host of this Machine - OriginRealm = yaml_config['hss']['OriginRealm'] - DestinationRealm = OriginRealm - mcc = yaml_config['hss']['MCC'] #Mobile Country Code - mnc = yaml_config['hss']['MNC'] #Mobile Network Code - diam_hex = diameterClient.Request_16777238_258(pcrf_session_data['pcrf_session_id'], ChargingRule, pcrf_session_data['subscriber_routing'], pcrf_session_data['serving_pgw'], 'ServingRealm.com') - import time - # @@Fixme - # logObj = logtool.LogTool() - # logObj.Async_SendRequest(diam_hex, str(pcrf_session_data['serving_pgw'])) - return diam_hex, 200 + diameterRequest = diameterClient.Request_16777238_258(pcrf_session_data['pcrf_session_id'], ChargingRule, pcrf_session_data['subscriber_routing'], pcrf_session_data['serving_pgw'], 'ServingRealm.com') + connectedPgws = diameterClient.getConnectedPeersByType('pgw') + for connectedPgw in connectedPgws: + outboundQueue = f"diameter-outbound-{connectedPgw.get('ipAddress')}-{connectedPgw.get('port')}-{time.time_ns()}" + outboundMessage = json.dumps({"diameter-outbound": diameterRequest}) + redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) + + result = {"request": diameterRequest, "destinationClients": connectedPgws} + return result, 200 @ns_pcrf.route('/') class PyHSS_PCRF_Complete(Resource): def get(self, charging_rule_id): '''Get full Charging Rule + TFTs''' try: - data = database.Get_Charging_Rule(charging_rule_id) + data = databaseClient.Get_Charging_Rule(charging_rule_id) return data, 200 except Exception as E: print(E) @@ -1301,39 +1308,41 @@ class PyHSS_PCRF_SUBSCRIBER_ROUTING(Resource): def get(self, subscriber_routing): '''Get Subscriber info from Subscriber Routing''' try: - data = database.Get_UE_by_IP(subscriber_routing) + data = databaseClient.Get_UE_by_IP(subscriber_routing) return data, 200 except Exception as E: print(E) return handle_exception(E) @ns_geored.route('/') - class PyHSS_Geored(Resource): @ns_geored.doc('Create ChargingRule Object') @ns_geored.expect(GeoRed_model) - # @metrics.counter('flask_http_geored_pushes', 'Count of GeoRed Pushes to this API', - # labels={'status': lambda r: r.status_code, 'source_endpoint': lambda r: r.remote_addr}) @no_auth_required def patch(self): '''Get Geored data Pushed''' try: json_data = request.get_json(force=True) print("JSON Data sent in Geored request: " + str(json_data)) - #Determine what actions to take / update based on keys returned response_data = [] if 'serving_mme' in json_data: print("Updating serving MME") - response_data.append(database.Update_Serving_MME(imsi=str(json_data['imsi']), serving_mme=json_data['serving_mme'], serving_mme_realm=json_data['serving_mme_realm'], serving_mme_peer=json_data['serving_mme_peer'], propagate=False)) - #@@Fixme - # prom_flask_http_geored_endpoints.labels(endpoint='HSS', geored_host=request.remote_addr).inc() + response_data.append(databaseClient.Update_Serving_MME(imsi=str(json_data['imsi']), serving_mme=json_data['serving_mme'], serving_mme_realm=json_data['serving_mme_realm'], 
serving_mme_peer=json_data['serving_mme_peer'], propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "HSS", + "geored_host": request.remote_addr, + }, + metricExpiry=60) if 'serving_apn' in json_data: print("Updating serving APN") if 'serving_pgw_realm' not in json_data: json_data['serving_pgw_realm'] = None if 'serving_pgw_peer' not in json_data: json_data['serving_pgw_peer'] = None - response_data.append(database.Update_Serving_APN( + response_data.append(databaseClient.Update_Serving_APN( imsi=str(json_data['imsi']), apn=json_data['serving_apn'], pcrf_session_id=json_data['pcrf_session_id'], @@ -1342,22 +1351,40 @@ def patch(self): serving_pgw_realm=json_data['serving_pgw_realm'], serving_pgw_peer=json_data['serving_pgw_peer'], propagate=False)) - #@@Fixme - # prom_flask_http_geored_endpoints.labels(endpoint='PCRF', geored_host=request.remote_addr).inc() + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "PCRF", + "geored_host": request.remote_addr, + }, + metricExpiry=60) if 'scscf' in json_data: print("Updating serving SCSCF") if 'scscf_realm' not in json_data: json_data['scscf_realm'] = None if 'scscf_peer' not in json_data: json_data['scscf_peer'] = None - response_data.append(database.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), propagate=False)) - #@@Fixme - # prom_flask_http_geored_endpoints.labels(endpoint='IMS', geored_host=request.remote_addr).inc() + response_data.append(databaseClient.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "IMS", + "geored_host": request.remote_addr, + }, + metricExpiry=60) if 'imei' in json_data: print("Updating EIR") - response_data.append(database.Store_IMSI_IMEI_Binding(str(json_data['imsi']), str(json_data['imei']), str(json_data['match_response_code']), propagate=False)) - #@@Fixme - # prom_flask_http_geored_endpoints.labels(endpoint='EIR', geored_host=request.remote_addr).inc() + response_data.append(databaseClient.Store_IMSI_IMEI_Binding(str(json_data['imsi']), str(json_data['imei']), str(json_data['match_response_code']), propagate=False)) + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "IMEI", + "geored_host": request.remote_addr, + }, + metricExpiry=60) return response_data, 200 except Exception as E: print("Exception when updating: " + str(E)) @@ -1395,5 +1422,5 @@ def put(self, imsi): return diam_hex, 200 if __name__ == '__main__': - app.run(debug=False) + apiService.run(debug=False) diff --git a/services/diameterService.py b/services/diameterService.py index 25d847d..9d49d75 100644 --- 
a/services/diameterService.py +++ b/services/diameterService.py @@ -7,7 +7,6 @@ from diameterAsync import DiameterAsync from banners import Banners from logtool import LogTool -import traceback class DiameterService: """ @@ -28,7 +27,9 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.banners = Banners() self.logTool = LogTool(config=self.config) self.diameterLibrary = DiameterAsync() - self.activeConnections = {} + self.activePeers = {} + self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) + self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool: """ @@ -39,9 +40,11 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb messageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(inboundData)) originHost = (await self.diameterLibrary.getAvpDataAsync(avps, 264))[0] originHost = bytes.fromhex(originHost).decode("utf-8") - self.activeConnections[f"{clientAddress}-{clientPort}"].update({'last_dwr_timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S") if messageType['inbound'] == 'DWR' else self.activeConnections[f"{clientAddress}:{clientPort}"]['last_dwr_timestamp'], - 'DiameterHostname': originHost, - }) + peerType = await(self.diameterLibrary.getPeerType(originHost)) + self.activePeers[f"{clientAddress}-{clientPort}"].update({'lastDwrTimestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S") if messageType['inbound'] == 'DWR' else self.activePeers[f"{clientAddress}-{clientPort}"]['lastDwrTimestamp'], + 'diameterHostname': originHost, + 'peerType': peerType, + }) asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_inbound_count', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Diameter Inbounds', @@ -52,28 +55,37 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb "type": "inbound"}, metricExpiry=60)) except Exception as e: - print(e) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}", redisClient=self.redisMessaging)) return False - return True + return True async def handleActiveDiameterPeers(self): """ - Prunes stale connection entries from self.activeConnections. + Prunes stale entries from self.activePeers, and + keeps the ActiveDiameterPeers key in Redis current.
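+        A peer entry mirrors what handleConnection() stores, e.g. (illustrative values):
+        "10.0.0.1-3868": {"connectionStatus": "disconnected", "disconnectTimestamp": "2023-08-30 17:25:57", ...}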
""" while True: try: - if not len(self.activeConnections) > 0: - await(asyncio.sleep(1)) + if not len(self.activePeers) > 0: + await(asyncio.sleep(0)) continue - activeDiameterPeersTimeout = self.config.get('hss', {}).get('active_diameter_peers_timeout', 86400) + activeDiameterPeersTimeout = self.config.get('hss', {}).get('active_diameter_peers_timeout', 3600) + + stalePeers = [] - for key, connection in self.activeConnections.items(): - if connection.get('connection_status', '') == 'disconnected': - if (datetime.now() - datetime.strptime(connection['connect_timestamp'], "%Y-%m-%d %H:%M:%S")).seconds > activeDiameterPeersTimeout: - del self.activeConnections[key] + for key, connection in self.activePeers.items(): + if connection.get('connectionStatus', '') == 'disconnected': + if (datetime.now() - datetime.strptime(connection['disconnectTimestamp'], "%Y-%m-%d %H:%M:%S")).seconds > activeDiameterPeersTimeout: + stalePeers.append(key) - await(self.redisMessaging.sendMessage(queue='ActiveDiameterPeers', message=json.dumps(self.activeConnections))) + if len(stalePeers) > 0: + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [handleActiveDiameterPeers] Pruning disconnected peers: {stalePeers}", redisClient=self.redisMessaging)) + for key in stalePeers: + del self.activePeers[key] + await(self.logActivePeers()) + + await(self.redisMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400)) await(asyncio.sleep(1)) except Exception as e: @@ -81,14 +93,14 @@ async def handleActiveDiameterPeers(self): await(asyncio.sleep(1)) continue - async def logActiveConnections(self): + async def logActivePeers(self): """ Logs the number of active connections on a rolling basis. """ - activeConnections = self.activeConnections - if not len(activeConnections) > 0: - activeConnections = '' - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActiveConnections] {len(self.activeConnections)} Active Connections {activeConnections}", redisClient=self.redisMessaging)) + activePeers = self.activePeers + if not len(activePeers) > 0: + activePeers = '' + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActivePeers] {len(self.activePeers)} Active Peers {activePeers}", redisClient=self.redisMessaging)) async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ @@ -99,7 +111,10 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc while True: try: - inboundData = await(asyncio.wait_for(reader.read(1024), timeout=socketTimeout)) + inboundData = await(asyncio.wait_for(reader.read(8192), timeout=socketTimeout)) + + if self.benchmarking: + startTime = time.perf_counter() if reader.at_eof(): await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.", redisClient=self.redisMessaging)) @@ -118,7 +133,10 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{time.time_ns()}" inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()}) await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}", redisClient=self.redisMessaging)) - 
asyncio.ensure_future(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=60)) + asyncio.ensure_future(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) + if self.benchmarking: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + except Exception as e: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}", redisClient=self.redisMessaging)) @@ -131,6 +149,8 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) while True: try: + if self.benchmarking: + startTime = time.perf_counter() if writer.transport.is_closing(): return False @@ -139,6 +159,7 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s if not len(pendingOutboundQueues) > 0: await(asyncio.sleep(0)) continue + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queues: {pendingOutboundQueues}", redisClient=self.redisMessaging)) for outboundQueue in pendingOutboundQueues: outboundQueueSplit = str(outboundQueue).split('-') @@ -156,6 +177,9 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s writer.write(diameterOutboundBinary) await(writer.drain()) await(asyncio.sleep(0)) + if self.benchmarking: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + except Exception: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.", redisClient=self.redisMessaging)) @@ -164,21 +188,39 @@ async def handleConnection(self, reader, writer): """ - For each new connection on port 3868, create an asynchronous reader and writer. If a reader or writer returns false, ensure that the connection is torn down entirely. + For each new connection on port 3868, create an asynchronous reader and writer, and handle adding and updating self.activePeers. + If a reader or writer returns false, ensure that the connection is torn down entirely.
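+        Peers are keyed in self.activePeers as f"{clientAddress}-{clientPort}", e.g. "10.0.0.1-3868" (illustrative).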
""" try: coroutineUuid = str(uuid.uuid4()) (clientAddress, clientPort) = writer.get_extra_info('peername') await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) - if f"{clientAddress}-{clientPort}" not in self.activeConnections: - self.activeConnections[f"{clientAddress}-{clientPort}"] = {} - self.activeConnections[f"{clientAddress}-{clientPort}"].update({ - "connect_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "recv_ip_address":clientAddress, - "recv_ip_port":clientAddress, - "connection_status": 'connected', + if f"{clientAddress}-{clientPort}" not in self.activePeers: + self.activePeers[f"{clientAddress}-{clientPort}"] = { + "connectTimestamp": '', + "disconnectTimestamp": '', + "reconnectionCount": 0, + "ipAddress":'', + "port":'', + "connectionStatus": '', + "lastDwrTimestamp": '', + "diameterHostname": '', + "peerType": '', + } + else: + reconnectionCount = self.activePeers.get(f"{clientAddress}-{clientPort}", {}).get('reconnectionCount', 0) + reconnectionCount += 1 + self.activePeers[f"{clientAddress}-{clientPort}"].update({ + "reconnectionCount": reconnectionCount + }) + + self.activePeers[f"{clientAddress}-{clientPort}"].update({ + "connectTimestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "ipAddress":clientAddress, + "port": clientPort, + "connectionStatus": 'connected', }) - await(self.logActiveConnections()) + await(self.logActivePeers()) readTask = asyncio.create_task(self.readInboundData(reader=reader, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) writeTask = asyncio.create_task(self.writeOutboundData(writer=writer, clientAddress=clientAddress, clientPort=clientPort, socketTimeout=self.socketTimeout, coroutineUuid=coroutineUuid)) @@ -194,15 +236,16 @@ async def handleConnection(self, reader, writer): writer.close() await(writer.wait_closed()) - self.activeConnections[f"{clientAddress}-{clientPort}"].update({ - "connection_status": 'disconnected', + self.activePeers[f"{clientAddress}-{clientPort}"].update({ + "connectionStatus": 'disconnected', + "disconnectTimestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), }) await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}.", redisClient=self.redisMessaging)) - await(self.logActiveConnections()) + await(self.logActivePeers()) return except Exception as e: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}\n{traceback.format_exc()}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}", redisClient=self.redisMessaging)) return async def startServer(self, host: str=None, port: int=None, type: str=None): @@ -236,4 +279,4 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): if __name__ == '__main__': diameterService = DiameterService() - asyncio.run(diameterService.startServer(), debug=True) \ No newline at end of file + asyncio.run(diameterService.startServer()) \ No newline at end of file diff --git a/services/georedService.py b/services/georedService.py index 7194c11..39db3e4 
100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -1,11 +1,16 @@ import os, sys, json, yaml -import requests, uuid +import uuid, time +import asyncio, aiohttp sys.path.append(os.path.realpath('../lib')) -from messaging import RedisMessaging +from messagingAsync import RedisMessagingAsync from banners import Banners from logtool import LogTool class GeoredService: + """ + PyHSS Geored Service + Handles updating and sending webhooks to remote endpoints. + """ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): try: @@ -16,19 +21,26 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): quit() self.logTool = LogTool(self.config) self.banners = Banners() - self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) - self.remotePeers = self.config.get('geored', {}).get('sync_endpoints', []) + self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.georedPeers = self.config.get('geored', {}).get('endpoints', []) + self.webhookPeers = self.config.get('webhooks', {}).get('endpoints', []) + self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) + if not self.config.get('geored', {}).get('enabled'): self.logger.error("[Geored] Fatal Error - geored not enabled under geored.enabled, exiting.") quit() - if not (len(self.remotePeers) > 0): + if not (len(self.georedPeers) > 0): self.logger.error("[Geored] Fatal Error - no peers defined under geored.sync_endpoints, exiting.") quit() - self.logTool.log(service='Geored', level='info', message=f"{self.banners.georedService()}", redisClient=self.redisMessaging) - def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + async def sendGeored(self, asyncSession, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + """ + Sends a Geored HTTP request to a given endpoint. + """ + if self.benchmarking: + startTime = time.perf_counter() operation = operation.upper() - requestOperations = {'GET': requests.get, 'PUT': requests.put, 'POST': requests.post, 'PATCH':requests.patch, 'DELETE': requests.delete} + requestOperations = {'GET': asyncSession.get, 'PUT': asyncSession.put, 'POST': asyncSession.post, 'PATCH':asyncSession.patch, 'DELETE': asyncSession.delete} if not url or not operation or not body: return False @@ -37,41 +49,49 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui return False headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)} - + for attempt in range(retryCount): try: + responseStatusCode = None + responseBody = None + if operation in ['PUT', 'POST', 'PATCH']: - response = requestOperations[operation](url, json=body, headers=headers) + async with requestOperations[operation](url, json=body, headers=headers) as response: + responseBody = await(response.text()) + responseStatusCode = response.status + else: + async with requestOperations[operation](url, headers=headers) as response: + responseBody = await(response.text()) + responseStatusCode = response.status + + if 200 <= responseStatusCode <= 299: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendGeored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. 
Response code: {responseStatusCode}", redisClient=self.redisMessaging)) + + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": str(responseStatusCode), + "error": ""}, + metricExpiry=60)) + break else: - response = requestOperations[operation](url, headers=headers) - if 200 <= response.status_code <= 299: - self.logTool.log(service='Geored', level='debug', message=f"[Geored] [sendGeored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}", redisClient=self.redisMessaging) - - self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', - metricType='counter', metricAction='inc', - metricValue=1.0, metricHelp='Number of Geored Pushes', - metricLabels={ - "geored_host": str(url.replace('https://', '').replace('http://', '')), - "endpoint": "geored", - "http_response_code": str(response.status_code), - "error": ""}, - metricExpiry=60) - break - else: - self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', - metricType='counter', metricAction='inc', - metricValue=1.0, metricHelp='Number of Geored Pushes', - metricLabels={ - "geored_host": str(url.replace('https://', '').replace('http://', '')), - "endpoint": "geored", - "http_response_code": str(response.status_code), - "error": str(response.reason)}, - metricExpiry=60) - except requests.exceptions.ConnectionError as e: + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes', + metricLabels={ + "geored_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "geored", + "http_response_code": str(responseStatusCode), + "error": str(response.reason)}, + metricExpiry=60)) + except aiohttp.ClientConnectionError as e: error_message = str(e) - self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}", redisClient=self.redisMessaging)) if "Name or service not known" in error_message: - self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -79,9 +99,9 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui "endpoint": "geored", "http_response_code": "000", "error": "No matching DNS entry found"}, - metricExpiry=60) + metricExpiry=60)) else: - self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -89,10 +109,10 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui "endpoint": "geored", "http_response_code": "000", "error": "Connection Refused"}, - metricExpiry=60) - except requests.exceptions.Timeout: - self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) - self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + metricExpiry=60)) + except aiohttp.ServerTimeoutError as e: + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -100,10 +120,10 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui "endpoint": "geored", "http_response_code": "000", "error": "Timeout"}, - metricExpiry=60) + metricExpiry=60)) except Exception as e: - self.logTool.log(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) - self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}", redisClient=self.redisMessaging)) + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -111,55 +131,69 @@ def sendGeored(self, url: str, operation: str, body: str, transactionId: str=uui "endpoint": "geored", "http_response_code": "000", "error": e}, - metricExpiry=60) + metricExpiry=60)) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendGeored] Time taken to send individual geored request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + return True - def sendWebhook(self, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, headers: dict, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + """ + Sends a Webhook HTTP request to a given endpoint. + """ + if self.benchmarking: + startTime = time.perf_counter() operation = operation.upper() - requestOperations = {'GET': requests.get, 'PUT': requests.put, 'POST': requests.post, 'PATCH':requests.patch, 'DELETE': requests.delete} + requestOperations = {'GET': asyncSession.get, 'PUT': asyncSession.put, 'POST': asyncSession.post, 'PATCH':asyncSession.patch, 'DELETE': asyncSession.delete} - if not url or not operation or not body: + if not url or not operation or not body or not headers: return False if operation not in requestOperations: return False - - headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)} - + for attempt in range(retryCount): try: + responseStatusCode = None + responseBody = None + if operation in ['PUT', 'POST', 'PATCH']: - response = requestOperations[operation](url, json=body, headers=headers) + async with requestOperations[operation](url, json=body, headers=headers) as response: + responseBody = await(response.text()) + responseStatusCode = response.status else: - response = requestOperations[operation](url, headers=headers) - if 200 <= response.status_code <= 299: - self.logTool.log(service='Geored', level='debug', message=f"[Geored] [sendWebhook] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. 
Response code: {response.status_code}", redisClient=self.redisMessaging) - - self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', - metricType='counter', metricAction='inc', - metricValue=1.0, metricHelp='Number of Webhook Pushes', - metricLabels={ - "webhook_host": str(url.replace('https://', '').replace('http://', '')), - "endpoint": "webhook", - "http_response_code": str(response.status_code), - "error": ""}, - metricExpiry=60) - break - else: - self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', - metricType='counter', metricAction='inc', - metricValue=1.0, metricHelp='Number of Webhook Pushes', - metricLabels={ - "webhook_host": str(url.replace('https://', '').replace('http://', '')), - "endpoint": "webhook", - "http_response_code": str(response.status_code), - "error": str(response.reason)}, - metricExpiry=60) - except requests.exceptions.ConnectionError as e: + async with requestOperations[operation](url, headers=headers) as response: + responseBody = await(response.text()) + responseStatusCode = response.status + + if 200 <= responseStatusCode <= 299: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendWebhook] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}", redisClient=self.redisMessaging)) + + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": str(responseStatusCode), + "error": ""}, + metricExpiry=60)) + break + else: + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Webhook Pushes', + metricLabels={ + "webhook_host": str(url.replace('https://', '').replace('http://', '')), + "endpoint": "webhook", + "http_response_code": str(responseStatusCode), + "error": str(response.reason)}, + metricExpiry=60)) + except aiohttp.ClientConnectionError as e: error_message = str(e) - self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}", redisClient=self.redisMessaging)) if "Name or service not known" in error_message: - self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -167,9 +201,9 @@ def sendWebhook(self, url: str, operation: str, body: str, transactionId: str=uu "endpoint": "webhook", "http_response_code": "000", "error": "No matching DNS entry found"}, - metricExpiry=60) + metricExpiry=60)) else: - self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webook', + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -177,10 +211,10 @@ def sendWebhook(self, url: str, operation: str, body: str, transactionId: str=uu "endpoint": "webhook", "http_response_code": "000", "error": "Connection Refused"}, - metricExpiry=60) - except requests.exceptions.Timeout: - self.logTool.log(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) - self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + metricExpiry=60)) + except aiohttp.ServerTimeoutError as e: + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -188,10 +222,10 @@ def sendWebhook(self, url: str, operation: str, body: str, transactionId: str=uu "endpoint": "webhook", "http_response_code": "000", "error": "Timeout"}, - metricExpiry=60) + metricExpiry=60)) except Exception as e: - self.logTool.log(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {response.status_code}. Error Message: {e}", redisClient=self.redisMessaging) - self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendWebhook] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}", redisClient=self.redisMessaging)) + asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -199,32 +233,120 @@ def sendWebhook(self, url: str, operation: str, body: str, transactionId: str=uu "endpoint": "webhook", "http_response_code": "000", "error": e}, - metricExpiry=60) + metricExpiry=60)) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendWebhook] Time taken to send individual webhook request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + return True + + async def handleGeoredQueue(self): + """ + Collects and processes queued geored messages. + """ + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + georedQueue = await(self.redisMessaging.getNextQueue(pattern='geored-*')) + georedMessage = await(self.redisMessaging.getMessage(queue=georedQueue)) + if len(georedMessage) > 0: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Queue: {georedQueue}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}", redisClient=self.redisMessaging)) + georedDict = json.loads(georedMessage) + georedOperation = georedDict['operation'] + georedBody = georedDict['body'] + georedTasks = [] - def handleQueue(self): - try: - georedQueue = self.redisMessaging.getNextQueue(pattern='geored-*') - georedMessage = self.redisMessaging.getMessage(queue=georedQueue) - assert(len(georedMessage)) - self.logTool.log(service='Geored', level='debug', message=f"[Geored] Queue: {georedQueue}", redisClient=self.redisMessaging) - self.logTool.log(service='Geored', level='debug', message=f"[Geored] Message: {georedMessage}", redisClient=self.redisMessaging) + for remotePeer in self.georedPeers: + georedTasks.append(self.sendGeored(asyncSession=session, url=remotePeer+'/geored/', operation=georedOperation, body=georedBody)) + await asyncio.gather(*georedTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + + await(asyncio.sleep(0)) + + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}", redisClient=self.redisMessaging)) + await(asyncio.sleep(0)) + continue + + async def handleWebhookQueue(self): + """ + Collects and processes queued webhook messages. 
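+        Expected queued message shape (illustrative; keys are those read below):
+        {"operation": "POST", "headers": {"Content-Type": "application/json"}, "body": {...}}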
+ """ + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + webhookQueue = await(self.redisMessaging.getNextQueue(pattern='webhook-*')) + webhookMessage = await(self.redisMessaging.getMessage(queue=webhookQueue)) + if len(webhookMessage) > 0: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Queue: {webhookQueue}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}", redisClient=self.redisMessaging)) + + webhookDict = json.loads(webhookMessage) + webhookHeaders = webhookDict['headers'] + webhookOperation = webhookDict['operation'] + webhookBody = webhookDict['body'] + webhookTasks = [] + + for remotePeer in self.webhookPeers: + webhookTasks.append(self.sendWebhook(asyncSession=session, url=remotePeer, operation=webhookOperation, body=webhookBody, headers=webhookHeaders)) + await asyncio.gather(*webhookTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all webhook peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + + await(asyncio.sleep(0)) + + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Error handling webhook queue: {e}", redisClient=self.redisMessaging)) + await(asyncio.sleep(0)) + continue + + async def startService(self): + """ + Performs sanity checks on configuration and starts the geored and webhook tasks, when enabled. + """ + await(self.logTool.logAsync(service='Geored', level='info', message=f"{self.banners.georedService()}", redisClient=self.redisMessaging)) + while True: + + georedEnabled = self.config.get('geored', {}).get('enabled', False) + webhooksEnabled = self.config.get('webhooks', {}).get('enabled', False) + + if not len(self.georedPeers) > 0: + georedEnabled = False + + if not len(self.webhookPeers) > 0: + webhooksEnabled = False + + if not georedEnabled and not webhooksEnabled: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [startService] Geored and Webhook services both disabled or missing peers, exiting.", redisClient=self.redisMessaging)) + sys.exit() + + activeTasks = [] + + if georedEnabled: + georedTask = asyncio.create_task(self.handleGeoredQueue()) + activeTasks.append(georedTask) + + if webhooksEnabled: + webhookTask = asyncio.create_task(self.handleWebhookQueue()) + activeTasks.append(webhookTask) - georedDict = json.loads(georedMessage) - georedOperation = georedDict['operation'] - georedBody = georedDict['body'] + completeTasks, pendingTasks = await(asyncio.wait(activeTasks, return_when=asyncio.FIRST_COMPLETED)) - try: - for remotePeer in self.remotePeers: - self.sendGeored(url=remotePeer+'/geored/', operation=georedOperation, body=georedBody) - except Exception as e: - self.logTool.log(service='Geored', level='debug', message=f"[Geored] Error sending geored message: {e}", redisClient=self.redisMessaging) + if len(pendingTasks) > 0: + for pendingTask in pendingTasks: + try: + pendingTask.cancel() + await(asyncio.sleep(0)) + except asyncio.CancelledError: + pass - except Exception as e: - return False if __name__ == '__main__': georedService = GeoredService() - while True: - georedService.handleQueue() \ No 
newline at end of file + asyncio.run(georedService.startService()) \ No newline at end of file diff --git a/services/hssService.py b/services/hssService.py index 3557dab..cc4c0b7 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -1,4 +1,4 @@ -import os, sys, json, yaml +import os, sys, json, yaml, time sys.path.append(os.path.realpath('../lib')) from messaging import RedisMessaging from diameter import Diameter @@ -25,6 +25,8 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.productName = self.config.get('hss', {}).get('ProductName', f'PyHSS') self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) self.diameterLibrary = Diameter(redisMessaging=self.redisMessaging, logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) + self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) + def handleQueue(self): """ @@ -32,6 +34,8 @@ def handleQueue(self): """ while True: try: + if self.benchmarking: + startTime = time.perf_counter() inboundQueue = self.redisMessaging.getNextQueue(pattern='diameter-inbound*') inboundMessage = self.redisMessaging.getMessage(queue=inboundQueue) assert(len(inboundMessage)) @@ -66,6 +70,9 @@ def handleQueue(self): self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) + if self.benchmarking: + self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) + except Exception as e: continue diff --git a/services/metricService.py b/services/metricService.py index be0d0a7..09c6be2 100644 --- a/services/metricService.py +++ b/services/metricService.py @@ -27,18 +27,24 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.logTool.log(service='Metric', level='info', message=f"{self.banners.metricService()}", redisClient=self.redisMessaging) def handleMetrics(self): + """ + Collects queued metrics from redis, and exposes them using prometheus_client. 
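+        Expected queued metric shape (illustrative; required keys are validated below):
+        [{"NAME": "prom_diam_inbound_count", "TYPE": "counter", "ACTION": "inc", "VALUE": 1.0}]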
+ """ try: actions = {'inc': 'inc', 'dec': 'dec', 'set':'set'} prometheusTypes = {'counter': Counter, 'gauge': Gauge, 'histogram': Histogram, 'summary': Summary} metricQueue = self.redisMessaging.getNextQueue(pattern='metric-*') metric = self.redisMessaging.getMessage(queue=metricQueue) + if not (len(metric) > 0): return - self.logTool.log(service='Metric', level='debug', message=f"Received Metric: {metric}", redisClient=self.redisMessaging) + + self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] Received Metric: {metric}", redisClient=self.redisMessaging) prometheusJsonList = json.loads(metric) + for prometheusJson in prometheusJsonList: - self.logTool.log(service='Metric', level='debug', message=f"{prometheusJson}", redisClient=self.redisMessaging) + self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] {prometheusJson}", redisClient=self.redisMessaging) if not all(key in prometheusJson for key in ('NAME', 'TYPE', 'ACTION', 'VALUE')): raise ValueError('All fields are not available for parsing') counterName = prometheusJson['NAME'] @@ -62,18 +68,17 @@ def handleMetrics(self): counterRecord = counterRecord.labels(*counterLabels.values()) action = actions.get(counterAction) if action is not None: - # Here we dynamically lookup the class from prometheus_client, and grab the matched function name called 'action'. prometheusMethod = getattr(counterRecord, action) prometheusMethod(counterValue) else: - self.logTool.log(service='Metric', level='warn', message=f"Invalid action '{counterAction}' in message: {metric}, skipping.", redisClient=self.redisMessaging) + self.logTool.log(service='Metric', level='warn', message=f"[Metric] [handleMetrics] Invalid action '{counterAction}' in message: {metric}, skipping.", redisClient=self.redisMessaging) continue else: - self.logTool.log(service='Metric', level='warn', message=f"Invalid type '{counterType}' in message: {metric}, skipping.", redisClient=self.redisMessaging) + self.logTool.log(service='Metric', level='warn', message=f"[Metric] [handleMetrics] Invalid type '{counterType}' in message: {metric}, skipping.", redisClient=self.redisMessaging) continue except Exception as e: - self.logTool.log(service='Metric', level='error', message=f"Unable to parse message: {metric}, due to {e}. Skipping.", redisClient=self.redisMessaging) + self.logTool.log(service='Metric', level='error', message=f"[Metric] [handleMetrics] Unable to parse message: {metric}, due to {e}. 
Skipping.", redisClient=self.redisMessaging) return diff --git a/services/webhookService.py b/services/webhookService.py deleted file mode 100644 index e69de29..0000000 diff --git a/test_Diameter.py b/tests/test_Diameter.py similarity index 100% rename from test_Diameter.py rename to tests/test_Diameter.py diff --git a/tests_API.py b/tests/tests_API.py similarity index 100% rename from tests_API.py rename to tests/tests_API.py From f65841894c1d2c0ffbf1ed3561ba999a8fba327e Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 30 Aug 2023 17:25:57 +1000 Subject: [PATCH 09/43] Remove old.hss.py --- old.hss.py | 1012 ---------------------------------------------------- 1 file changed, 1012 deletions(-) delete mode 100644 old.hss.py diff --git a/old.hss.py b/old.hss.py deleted file mode 100644 index f32c53e..0000000 --- a/old.hss.py +++ /dev/null @@ -1,1012 +0,0 @@ -# PyHSS -# This serves as a basic 3GPP Home Subscriber Server implimenting a EIR & IMS HSS functionality -import logging -import yaml -import os -import sys -import socket -import socketserver -import binascii -import time -import _thread -import threading -import sctp -import traceback -import pprint -import diameter as DiameterLib -import systemd.daemon -from threading import Thread, Lock -from logtool import * -import contextlib -import queue - - -class ThreadJoiner: - def __init__(self, threads, thread_event): - self.threads = threads - self.thread_event = thread_event - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is not None: - self.thread_event.set() - for thread in self.threads: - while thread.is_alive(): - try: - thread.join(timeout=1) - except Exception as e: - print( - f"ThreadJoiner Exception: failed to join thread {thread}: {e}" - ) - break - - -class PyHSS: - def __init__(self): - # Load config from yaml file - try: - with open("config.yaml", "r") as config_stream: - self.yaml_config = yaml.safe_load(config_stream) - except: - print(f"config.yaml not found, exiting PyHSS.") - quit() - - # Setup logging - self.logtool = LogTool(HSS_Init=True) - self.logtool.setup_logger( - "HSS_Logger", - self.yaml_config["logging"]["logfiles"]["hss_logging_file"], - level=self.yaml_config["logging"]["level"], - ) - self.logger = logging.getLogger("HSS_Logger") - if self.yaml_config["logging"]["log_to_terminal"]: - logging.getLogger().addHandler(logging.StreamHandler()) - - # Setup Diameter - self.diameter_instance = DiameterLib.Diameter( - str(self.yaml_config["hss"].get("OriginHost", "")), - str(self.yaml_config["hss"].get("OriginRealm", "")), - str(self.yaml_config["hss"].get("ProductName", "")), - str(self.yaml_config["hss"].get("MNC", "")), - str(self.yaml_config["hss"].get("MCC", "")), - ) - - self.max_diameter_retries = int( - self.yaml_config["hss"].get("diameter_max_retries", 1) - ) - - - - try: - assert(self.yaml_config['prometheus']['enabled'] == True) - assert(self.yaml_config['prometheus']['async_subscriber_count'] == True) - - self.logger.info("Enabling Prometheus Async Sub thread") - #Add Prometheus Async Calls - prom_async_thread = threading.Thread( - target=self.prom_async_function, - name=f"prom_async_function", - args=(), - ) - prom_async_thread.start() - except: - self.logger.info("Prometheus Async Sub Count thread disabled") - - - - def terminate_connection(self, clientsocket, client_address, thread_event): - thread_event.set() - clientsocket.close() - self.logtool.Manage_Diameter_Peer(client_address, client_address, "remove") - - def 
handle_new_connection(self, clientsocket, client_address): - # Create our threading event, accessible by sibling threads in this connection. - socket_close_event = threading.Event() - try: - send_queue = queue.Queue() - self.logger.debug(f"New connection from {client_address}") - if ( - "client_socket_timeout" not in self.yaml_config["hss"] - or self.yaml_config["hss"]["client_socket_timeout"] == 0 - ): - self.yaml_config["hss"]["client_socket_timeout"] = 120 - clientsocket.settimeout( - self.yaml_config["hss"].get("client_socket_timeout", 120) - ) - - send_data_thread = threading.Thread( - target=self.send_data, - name=f"send_data_thread", - args=(clientsocket, send_queue, socket_close_event), - ) - self.logger.debug("handle_new_connection: Starting send_data thread") - send_data_thread.start() - - self.logtool.Manage_Diameter_Peer(client_address, client_address, "add") - manage_client_thread = threading.Thread( - target=self.manage_client, - name=f"manage_client_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug("handle_new_connection: Starting manage_client thread") - manage_client_thread.start() - - threads_to_join = [manage_client_thread] - threads_to_join.append(send_data_thread) - - # If Redis is enabled, start manage_client_async and manage_client_dwr threads. - if self.yaml_config["redis"]["enabled"]: - if ( - "async_check_interval" not in self.yaml_config["hss"] - or self.yaml_config["hss"]["async_check_interval"] == 0 - ): - self.yaml_config["hss"]["async_check_interval"] = 10 - manage_client_async_thread = threading.Thread( - target=self.manage_client_async, - name=f"manage_client_async_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug( - "handle_new_connection: Starting manage_client_async thread" - ) - manage_client_async_thread.start() - - manage_client_dwr_thread = threading.Thread( - target=self.manage_client_dwr, - name=f"manage_client_dwr_thread: client_address: {client_address}", - args=( - clientsocket, - client_address, - self.diameter_instance, - socket_close_event, - send_queue, - ), - ) - self.logger.debug( - "handle_new_connection: Starting manage_client_dwr thread" - ) - manage_client_dwr_thread.start() - - threads_to_join.append(manage_client_async_thread) - threads_to_join.append(manage_client_dwr_thread) - - self.logger.debug( - f"handle_new_connection: Total PyHSS Active Threads: {threading.active_count()}" - ) - for thread in threading.enumerate(): - if "dummy" not in thread.name.lower(): - self.logger.debug(f"Active Thread name: {thread.name}") - - with ThreadJoiner(threads_to_join, socket_close_event): - socket_close_event.wait() - self.terminate_connection( - clientsocket, client_address, socket_close_event - ) - self.logger.debug(f"Closing thread for client; {client_address}") - return - - except Exception as e: - self.logger.error(f"Exception for client {client_address}: {e}") - self.logger.error(f"Closing connection for {client_address}") - self.terminate_connection(clientsocket, client_address, socket_close_event) - return - - @prom_diam_response_time_diam.time() - def process_Diameter_request( - self, clientsocket, client_address, diameter, data, thread_event, send_queue - ): - packet_length = diameter.decode_diameter_packet_length( - data - ) # Calculate length of packet from start of packet - if 
packet_length <= 32: - self.logger.error("Received an invalid packet with length <= 32") - self.terminate_connection(clientsocket, client_address, thread_event) - return - - data_sum = data + clientsocket.recv( - packet_length - 32 - ) # Recieve remainder of packet from buffer - packet_vars, avps = diameter.decode_diameter_packet( - data_sum - ) # Decode packet into array of AVPs and Dict of Packet Variables (packet_vars) - try: - packet_vars["Source_IP"] = client_address[0] - except: - self.logger.debug("Failed to add Source_IP to packet_vars") - - start_time = time.time() - origin_host = diameter.get_avp_data(avps, 264)[0] # Get OriginHost from AVP - origin_host = binascii.unhexlify(origin_host).decode("utf-8") # Format it - - # label_values = str(packet_vars['ApplicationId']), str(packet_vars['command_code']), origin_host, 'request' - prom_diam_request_count.labels( - str(packet_vars["ApplicationId"]), - str(packet_vars["command_code"]), - origin_host, - "request", - ).inc() - - - self.logger.info( - "\n\nNew request with Command Code: " - + str(packet_vars["command_code"]) - + ", ApplicationID: " - + str(packet_vars["ApplicationId"]) - + ", flags " - + str(packet_vars["flags"]) - + ", e2e ID: " - + str(packet_vars["end-to-end-identifier"]) - ) - - # Gobble up any Response traffic that is sent to us: - if packet_vars["flags_bin"][0:1] == "0": - self.logger.info("Got a Response, not a request - dropping it.") - self.logger.info(packet_vars) - return - - # Send Capabilities Exchange Answer (CEA) response to Capabilites Exchange Request (CER) - elif ( - packet_vars["command_code"] == 257 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 257 (CER) from {origin_host}" - + "\n\tSending response (CEA)" - ) - try: - response = diameter.Answer_257( - packet_vars, avps, str(self.yaml_config["hss"]["bind_ip"][0]) - ) # Generate Diameter packet - # prom_diam_response_count_successful.inc() - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - # prom_diam_response_count_fail.inc() - self.logger.info("Generated CEA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") - prom_diam_connected_peers.labels(origin_host).set(1) - - # Send Credit Control Answer (CCA) response to Credit Control Request (CCR) - elif ( - packet_vars["command_code"] == 272 - and packet_vars["ApplicationId"] == 16777238 - ): - self.logger.info( - f"Received 3GPP Credit-Control-Request from {origin_host}" - + "\n\tGenerating (CCA)" - ) - try: - response = diameter.Answer_16777238_272( - packet_vars, avps - ) # Generate Diameter packet - except Exception as E: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error(f"Failed to generate response {str(E)}") - self.logger.info("Generated CCA") - - # Send Device Watchdog Answer (DWA) response to Device Watchdog Requests (DWR) - elif ( - packet_vars["command_code"] == 280 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 280 (DWR) from {origin_host}" - + "\n\tSending response (DWA)" - ) - self.logger.debug(f"Total PyHSS Active Threads: {threading.active_count()}") - try: - response = diameter.Answer_280( - packet_vars, avps - ) # Generate Diameter packet - except: - response = 
diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.info("Generated DWA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "update") - - # Send Disconnect Peer Answer (DPA) response to Disconnect Peer Request (DPR) - elif ( - packet_vars["command_code"] == 282 - and packet_vars["ApplicationId"] == 0 - and packet_vars["flags"] == "80" - ): - self.logger.info( - f"Received Request with command code 282 (DPR) from {origin_host}" - + "\n\tForwarding request..." - ) - response = diameter.Answer_282( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated DPA") - self.logtool.Manage_Diameter_Peer(origin_host, client_address, "remove") - prom_diam_connected_peers.labels(origin_host).set(0) - - # S6a Authentication Information Answer (AIA) response to Authentication Information Request (AIR) - elif ( - packet_vars["command_code"] == 318 - and packet_vars["ApplicationId"] == 16777251 - and packet_vars["flags"] == "c0" - ): - self.logger.info( - f"Received Request with command code 318 (3GPP Authentication-Information-Request) from {origin_host}" - + "\n\tGenerating (AIA)" - ) - try: - response = diameter.Answer_16777251_318( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated AIR") - except Exception as e: - self.logger.info("Failed to generate Diameter Response for AIR") - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated DIAMETER_USER_DATA_NOT_AVAILABLE AIR") - - # S6a Update Location Answer (ULA) response to Update Location Request (ULR) - elif ( - packet_vars["command_code"] == 316 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 316 (3GPP Update Location-Request) from {origin_host}" - + "\n\tGenerating (ULA)" - ) - try: - response = diameter.Answer_16777251_316( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated ULA") - except Exception as e: - self.logger.info("Failed to generate Diameter Response for ULR") - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated error DIAMETER_USER_DATA_NOT_AVAILABLE ULA") - - # Send ULA data & clear tx buffer - clientsocket.sendall(bytes.fromhex(response)) - response = "" - if "Insert_Subscriber_Data_Force" in yaml_config["hss"]: - if yaml_config["hss"]["Insert_Subscriber_Data_Force"] == True: - self.logger.debug("ISD triggered after ULA") - # Generate Insert Subscriber Data Request - response = diameter.Request_16777251_319( - packet_vars, avps - ) # Generate Diameter packet - self.logger.info("Generated IDR") - # Send ISD data - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent IDR") - return - # S6a inbound Insert-Data-Answer in response to our IDR - elif ( - packet_vars["command_code"] == 319 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received response with command code 319 (3GPP Insert-Subscriber-Answer) from {origin_host}" - ) - return - # S6a Purge UE Answer (PUA) response to Purge UE Request (PUR) - elif ( - packet_vars["command_code"] == 321 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 
321 (3GPP Purge UE Request) from {origin_host}" - + "\n\tGenerating (PUA)" - ) - try: - response = diameter.Answer_16777251_321( - packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error("Failed to generate PUA") - self.logger.info("Generated PUA") - # S6a Notify Answer (NOA) response to Notify Request (NOR) - elif ( - packet_vars["command_code"] == 323 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info( - f"Received Request with command code 323 (3GPP Notify Request) from {origin_host}" - + "\n\tGenerating (NOA)" - ) - try: - response = diameter.Answer_16777251_323( - packet_vars, avps - ) # Generate Diameter packet - except: - response = diameter.Respond_ResultCode( - packet_vars, avps, 5012 - ) # Generate Diameter response with "DIAMETER_UNABLE_TO_COMPLY" (5012) - self.logger.error("Failed to generate NOA") - self.logger.info("Generated NOA") - # S6a Cancel Location Answer eater - elif ( - packet_vars["command_code"] == 317 - and packet_vars["ApplicationId"] == 16777251 - ): - self.logger.info("Received Response with command code 317 (3GPP Cancel Location Request) from " + str(origin_host)) - - # Cx Authentication Answer - elif ( - packet_vars["command_code"] == 300 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 300 (3GPP Cx User Authentication Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_300( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Auth Answer" - ) - self.logger.info(e) - self.logger.info(traceback.print_exc()) - self.logger.info( - type(e).__name__, # TypeError - __file__, # /tmp/example.py - e.__traceback__.tb_lineno # 2 - ) - - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Auth Answer") - - # Cx Server Assignment Answer - elif ( - packet_vars["command_code"] == 301 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 301 (3GPP Cx Server Assignemnt Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_301( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Server Assignment Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Server Assignment Answer") - - # Cx Location Information Answer - elif ( - packet_vars["command_code"] == 302 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 302 (3GPP Cx Location Information Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_302( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Location Information Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Location 
Information Answer") - - # Cx Multimedia Authentication Answer - elif ( - packet_vars["command_code"] == 303 - and packet_vars["ApplicationId"] == 16777216 - ): - self.logger.info( - f"Received Request with command code 303 (3GPP Cx Multimedia Authentication Request) from {origin_host}" - + "\n\tGenerating (MAA)" - ) - try: - response = diameter.Answer_16777216_303( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Cx Multimedia Authentication Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated Cx Multimedia Authentication Answer") - - # Sh User-Data-Answer - elif ( - packet_vars["command_code"] == 306 - and packet_vars["ApplicationId"] == 16777217 - ): - self.logger.info( - f"Received Request with command code 306 (3GPP Sh User-Data Request) from {origin_host}" - ) - try: - response = diameter.Answer_16777217_306( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Sh User-Data Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 5001 - ) # DIAMETER_ERROR_USER_UNKNOWN - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent negative response") - return - self.logger.info("Generated Sh User-Data Answer") - - # Sh Profile-Update-Answer - elif ( - packet_vars["command_code"] == 307 - and packet_vars["ApplicationId"] == 16777217 - ): - self.logger.info( - f"Received Request with command code 307 (3GPP Sh Profile-Update Request) from {origin_host}" - ) - try: - response = diameter.Answer_16777217_307( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for Sh User-Data Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 5001 - ) # DIAMETER_ERROR_USER_UNKNOWN - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) - self.logger.info("Sent negative response") - return - self.logger.info("Generated Sh Profile-Update Answer") - - # S13 ME-Identity-Check Answer - elif ( - packet_vars["command_code"] == 324 - and packet_vars["ApplicationId"] == 16777252 - ): - self.logger.info( - f"Received Request with command code 324 (3GPP S13 ME-Identity-Check Request) from {origin_host}" - + "\n\tGenerating (MICA)" - ) - try: - response = diameter.Answer_16777252_324( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - "Failed to generate Diameter Response for S13 ME-Identity Check Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated S13 ME-Identity Check Answer") - - # SLh LCS-Routing-Info-Answer - elif ( - packet_vars["command_code"] == 8388622 - and packet_vars["ApplicationId"] == 16777291 - ): - self.logger.info( - f"Received Request with command code 324 (3GPP SLh LCS-Routing-Info-Answer Request) from {origin_host}" - + "\n\tGenerating (MICA)" - ) - try: - response = diameter.Answer_16777291_8388622( - packet_vars, avps - ) # Generate Diameter packet - except Exception as e: - self.logger.info( - 
"Failed to generate Diameter Response for SLh LCS-Routing-Info-Answer" - ) - self.logger.info(e) - traceback.print_exc() - response = diameter.Respond_ResultCode( - packet_vars, avps, 4100 - ) # DIAMETER_USER_DATA_NOT_AVAILABLE - self.logger.info("Generated SLh LCS-Routing-Info-Answer") - - # Handle Responses generated by the Async functions - elif packet_vars["flags"] == "00": - self.logger.info( - "Got response back with command code " - + str(packet_vars["command_code"]) - ) - self.logger.info("response packet_vars: " + str(packet_vars)) - self.logger.info("response avps: " + str(avps)) - response = "" - else: - self.logger.error( - "\n\nRecieved unrecognised request with Command Code: " - + str(packet_vars["command_code"]) - + ", ApplicationID: " - + str(packet_vars["ApplicationId"]) - + " and flags " - + str(packet_vars["flags"]) - ) - for keys in packet_vars: - self.logger.error(keys) - self.logger.error("\t" + str(packet_vars[keys])) - self.logger.error(avps) - self.logger.error("Sending negative response") - response = diameter.Respond_ResultCode( - packet_vars, avps, 3001 - ) # Generate Diameter response with "Command Unsupported" (3001) - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) # Send it - - prom_diam_response_time_method.labels( - str(packet_vars["ApplicationId"]), - str(packet_vars["command_code"]), - origin_host, - "request", - ).observe(time.time() - start_time) - - # Diameter Transmission - retries = 0 - while retries < self.max_diameter_retries: - try: - send_queue.put(bytes.fromhex(response)) - break - except socket.error as e: - self.logger.error(f"Socket error for client {client_address}: {e}") - retries += 1 - if retries > self.max_diameter_retries: - self.logger.error( - f"Max retries reached for client {client_address}. Closing connection." 
- ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - break - time.sleep(1) # Wait for 1 second before retrying - except Exception as e: - self.logger.info("Failed to send Diameter Response") - self.logger.debug(f"Diameter Response Body: {str(response)}") - self.logger.info(e) - traceback.print_exc() - self.terminate_connection(clientsocket, client_address, thread_event) - self.logger.info("Thread terminated to " + str(client_address)) - break - - def manage_client( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - while True: - try: - data = clientsocket.recv(32) - if not data: - self.logger.info( - f"manage_client: Connection closed by {str(client_address)}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - self.process_Diameter_request( - clientsocket, - client_address, - diameter, - data, - thread_event, - send_queue, - ) - - except socket.timeout: - self.logger.warning( - f"manage_client: Socket timeout for client: {client_address}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except socket.error as e: - self.logger.error( - f"manage_client: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except KeyboardInterrupt: - # Clean up the connection on keyboard interrupt - response = ( - diameter.Request_282() - ) # Generate Disconnect Peer Request Diameter packet - send_queue.put(bytes.fromhex(response)) - # clientsocket.sendall(bytes.fromhex(response)) # Send it - self.terminate_connection(clientsocket, client_address, thread_event) - self.logger.info( - "manage_client: Connection closed nicely due to keyboard interrupt" - ) - sys.exit() - - except Exception as manage_client_exception: - self.logger.error( - f"manage_client: Exception in manage_client: {manage_client_exception}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - def manage_client_async( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - # # Sleep for 10 seconds to wait for the connection to come up - time.sleep(10) - self.logger.debug("manage_client_async: Getting ActivePeerDict") - self.logger.debug( - f"manage_client_async: Total PyHSS Active Threads: {threading.active_count()}" - ) - ActivePeerDict = self.logtool.GetDiameterPeers() - self.logger.debug( - f"manage_client_async: Got Active Peer dict in Async Thread: {str(ActivePeerDict)}" - ) - if client_address[0] in ActivePeerDict: - self.logger.debug( - "manage_client_async: This is host: " - + str(ActivePeerDict[str(client_address[0])]["DiameterHostname"]) - ) - DiameterHostname = str( - ActivePeerDict[str(client_address[0])]["DiameterHostname"] - ) - else: - self.logger.debug("manage_client_async: No matching Diameter Host found.") - return - - while True: - try: - if thread_event.is_set(): - self.logger.debug( - f"manage_client_async: Closing manage_client_async thread for client: {client_address}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - time.sleep(self.yaml_config["hss"]["async_check_interval"]) - self.logger.debug( - f"manage_client_async: Sleep interval expired for Diameter Peer {str(DiameterHostname)}" - ) - if int(self.yaml_config["hss"]["async_check_interval"]) == 0: - self.logger.error( - f"manage_client_async: No async_check_interval Timer set - Not checking Async Queue for host connection 
{str(DiameterHostname)}" - ) - return - try: - self.logger.debug( - "manage_client_async: Reading from request queue '" - + str(DiameterHostname) - + "_request_queue'" - ) - data_to_send = self.logtool.RedisHMGET( - str(DiameterHostname) + "_request_queue" - ) - for key in data_to_send: - data = data_to_send[key].decode("utf-8") - send_queue.put(bytes.fromhex(data)) - self.logtool.RedisHDEL( - str(DiameterHostname) + "_request_queue", key - ) - except Exception as redis_exception: - self.logger.error( - f"manage_client_async: Redis exception in manage_client_async: {redis_exception}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - - except socket.timeout: - self.logger.warning( - f"manage_client_async: Socket timeout for client: {client_address}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - except socket.error as e: - self.logger.error( - f"manage_client_async: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - except Exception: - self.logger.error( - f"manage_client_async: Terminating for host connection {str(DiameterHostname)}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - - def manage_client_dwr( - self, clientsocket, client_address, diameter, thread_event, send_queue - ): - while True: - try: - if thread_event.is_set(): - self.logger.debug( - f"Closing manage_client_dwr thread for client: {client_address}" - ) - self.terminate_connection( - clientsocket, client_address, thread_event - ) - return - if ( - int(self.yaml_config["hss"]["device_watchdog_request_interval"]) - != 0 - ): - time.sleep( - self.yaml_config["hss"]["device_watchdog_request_interval"] - ) - else: - self.logger.info("DWR Timer to set to 0 - Not sending DWRs") - return - - except: - self.logger.error( - "No DWR Timer set - Not sending Device Watchdog Requests" - ) - return - try: - self.logger.debug("Sending Keepalive to " + str(client_address) + "...") - request = diameter.Request_280() - send_queue.put(bytes.fromhex(request)) - # clientsocket.sendall(bytes.fromhex(request)) # Send it - self.logger.debug("Sent Keepalive to " + str(client_address) + "...") - except socket.error as e: - self.logger.error( - f"manage_client_dwr: Socket error for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - return - except Exception as e: - self.logger.error( - f"manage_client_dwr: General exception for client {client_address}: {e}" - ) - self.terminate_connection(clientsocket, client_address, thread_event) - - def get_socket_family(self): - if ":" in self.yaml_config["hss"]["bind_ip"][0]: - self.logger.info("IPv6 Address Specified") - return socket.AF_INET6 - else: - self.logger.info("IPv4 Address Specified") - return socket.AF_INET - - def send_data(self, clientsocket, send_queue, thread_event): - while not thread_event.is_set(): - try: - data = send_queue.get(timeout=1) - # Check if data is bytes, otherwise convert it using bytes.fromhex() - if not isinstance(data, bytes): - data = bytes.fromhex(data) - - clientsocket.sendall(data) - except ( - queue.Empty - ): # Catch the Empty exception when the queue is empty and the timeout has expired - continue - except Exception as e: - self.logger.error(f"send_data_thread: Exception: {e}") - return - - def start_server(self): - if self.yaml_config["hss"]["transport"] == "SCTP": - self.logger.debug("Using SCTP for 
Transport") - # Create a SCTP socket - sock = sctp.sctpsocket_tcp(self.get_socket_family()) - sock.initparams.num_ostreams = 64 - # Loop through the possible Binding IPs from the config and bind to each for Multihoming - server_addresses = [] - - # Prepend each entry into list, so the primary IP is bound first - for host in self.yaml_config["hss"]["bind_ip"]: - self.logger.info("Seting up SCTP binding on IP address " + str(host)) - this_IP_binding = [ - (str(host), int(self.yaml_config["hss"]["bind_port"])) - ] - server_addresses = this_IP_binding + server_addresses - - print("server_addresses are: " + str(server_addresses)) - sock.bindx(server_addresses) - self.logger.info("PyHSS listening on SCTP port " + str(server_addresses)) - systemd.daemon.notify("READY=1") - # Listen for up to 20 incoming SCTP connections - sock.listen(20) - elif self.yaml_config["hss"]["transport"] == "TCP": - self.logger.debug("Using TCP socket") - # Create a TCP/IP socket - sock = socket.socket(self.get_socket_family(), socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # Bind the socket to the port - server_address = ( - str(self.yaml_config["hss"]["bind_ip"][0]), - int(self.yaml_config["hss"]["bind_port"]), - ) - sock.bind(server_address) - self.logger.debug( - "PyHSS listening on TCP port " - + str(self.yaml_config["hss"]["bind_ip"][0]) - ) - systemd.daemon.notify("READY=1") - # Listen for up to 20 incoming TCP connections - sock.listen(20) - else: - self.logger.error("No valid transports found (No SCTP or TCP) - Exiting") - quit() - - while True: - # Wait for a connection - self.logger.info("Waiting for a connection...") - connection, client_address = sock.accept() - _thread.start_new_thread( - self.handle_new_connection, - ( - connection, - client_address, - ), - ) - - - def prom_async_function(self): - while True: - self.logger.debug("Running prom_async_function") - self.diameter_instance.Generate_Prom_Stats() - time.sleep(120) - - -if __name__ == "__main__": - pyHss = PyHSS() - pyHss.start_server() From 3a6e12dfef98df88661a7ea1b9874ed4c3d32c25 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 1 Sep 2023 07:33:59 +1000 Subject: [PATCH 10/43] Progress Freeze, update config.yaml, functional --- config.yaml | 38 ++++--- lib/database.py | 7 +- lib/diameter.py | 53 +++++---- lib/diameterAsync.py | 209 +++++++++++++++++++++++++++--------- lib/messaging.py | 2 +- lib/messagingAsync.py | 9 +- services/diameterService.py | 78 +++++++------- services/hssService.py | 2 +- 8 files changed, 271 insertions(+), 127 deletions(-) diff --git a/config.yaml b/config.yaml index 1cebdc0..54d96a0 100644 --- a/config.yaml +++ b/config.yaml @@ -33,6 +33,18 @@ hss: #The maximum time to wait, in seconds, before disconnecting a client when no data is received. client_socket_timeout: 120 + #Enable benchmarking log output for response times - set to False in production. + enable_benchmarking: False + + #The maximum time to wait, in seconds, before disconnecting a client when no data is received. + client_socket_timeout: 300 + + #The maximum time to wait, in seconds, before discarding a diameter request. + diameter_request_timeout: 3 + + #The amount of time, in seconds, before purging a disconnected client from the Active Diameter Peers key in redis. 
+ active_diameter_peers_timeout: 10 + #Prevent updates from being performed without a valid 'Provisioning-Key' in the header lock_provisioning: False @@ -59,22 +71,18 @@ hss: api: page_size: 200 -external: - external_webhook_notification_enabled: False - external_webhook_notification_url: https://api.example.com/webhook - eir: imsi_imei_logging: True #Store current IMEI / IMSI pair in backend - sim_swap_notify_webhook: http://localhost:5000/webhooks/sim_swap_notify/ no_match_response: 2 #Greylist tac_database_csv: '/etc/pyhss/tac_database_Nov2022.csv' logging: - level: DEBUG + level: INFO logfiles: - hss_logging_file: log/hss.log - diameter_logging_file: log/diameter.log - database_logging_file: log/db.log + hss_logging_file: /var/log/pyhss_hss.log + diameter_logging_file: /var/log/pyhss_diameter.log + geored_logging_file: /var/log/pyhss_geored.log + metric_logging_file: /var/log/pyhss_metrics.log log_to_terminal: True sqlalchemy_sql_echo: True sqlalchemy_pool_recycle: 15 @@ -89,18 +97,22 @@ database: password: password database: hss2 +## External Webhook Notifications +webhooks: + enabled: True + endpoints: + - http://10.5.5.66:8080 + ## Geographic Redundancy Parameters geored: enabled: False sync_actions: ['HSS', 'IMS', 'PCRF', 'EIR'] #What event actions should be synced - sync_endpoints: #List of PyHSS API Endpoints to update + endpoints: #List of PyHSS API Endpoints to update - 'http://hss01.mnc001.mcc001.3gppnetwork.org:8080' - 'http://hss02.mnc001.mcc001.3gppnetwork.org:8080' -## Stats Parameters +#Redis is required to run PyHSS. A locally running instance is recommended for production. redis: - enabled: False - clear_stats_on_boot: True host: localhost port: 6379 diff --git a/lib/database.py b/lib/database.py index 9381f81..64ee118 100755 --- a/lib/database.py +++ b/lib/database.py @@ -258,12 +258,15 @@ class SUBSCRIBER_ATTRIBUTES(Base): class Database: - def __init__(self, logTool, redisMessaging): + def __init__(self, logTool, redisMessaging=None): with open("../config.yaml", 'r') as stream: self.config = (yaml.safe_load(stream)) self.logTool = logTool - self.redisMessaging = redisMessaging + if redisMessaging: + self.redisMessaging = redisMessaging + else: + self.redisMessaging = RedisMessaging() db_string = 'mysql://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database'] + "?autocommit=true") self.engine = create_engine( diff --git a/lib/diameter.py b/lib/diameter.py index 52f462f..af578ab 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -8,11 +8,12 @@ import ipaddress import jinja2 from database import Database +from messaging import RedisMessaging import yaml class Diameter: - def __init__(self, redisMessaging, logTool, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999"): + def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999", redisMessaging=None): with open("../config.yaml", 'r') as stream: self.yaml_config = (yaml.safe_load(stream)) @@ -22,8 +23,11 @@ def __init__(self, redisMessaging, logTool, originHost: str="hss01", originRealm self.MNC = str(mnc) self.MCC = str(mcc) self.logTool = logTool - self.redisMessaging=redisMessaging - self.database = Database(logTool=logTool, redisMessaging=redisMessaging) + if redisMessaging: + 
self.redisMessaging=redisMessaging + else: + self.redisMessaging=RedisMessaging() + self.database = Database(logTool=logTool) self.logTool.log(service='HSS', level='info', message=f"Initialized Diameter Library", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='info', message=f"Origin Host: {str(originHost)}", redisClient=self.redisMessaging) @@ -276,20 +280,23 @@ def generate_vendor_avp(self, avp_code, avp_flags, avp_vendorid, avp_content): return avp def generate_diameter_packet(self, packet_version, packet_flags, packet_command_code, packet_application_id, packet_hop_by_hop_id, packet_end_to_end_id, avp): - #Placeholder that is updated later on - packet_length = 228 - packet_length = format(packet_length,"x").zfill(6) - - packet_command_code = format(packet_command_code,"x").zfill(6) - - packet_application_id = format(packet_application_id,"x").zfill(8) - - packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp - packet_length = int(round(len(packet_hex))/2) - packet_length = format(packet_length,"x").zfill(6) + try: + packet_length = 228 + packet_length = format(packet_length,"x").zfill(6) - packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp - return packet_hex + packet_command_code = format(packet_command_code,"x").zfill(6) + + packet_application_id = format(packet_application_id,"x").zfill(8) + + packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp + packet_length = int(round(len(packet_hex))/2) + packet_length = format(packet_length,"x").zfill(6) + + packet_hex = packet_version + packet_length + packet_flags + packet_command_code + packet_application_id + packet_hop_by_hop_id + packet_end_to_end_id + avp + return packet_hex + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [generate_diameter_packet] Exception: {e}", redisClient=self.redisMessaging) + def decode_diameter_packet(self, data): @@ -479,7 +486,8 @@ def generateDiameterResponse(self, binaryData: str) -> str: if 'flags' in diameterApplication: assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) response = diameterApplication["responseMethod"](packet_vars, avps) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Successfully generated response: {response}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully generated response: {response}", redisClient=self.redisMessaging) + break except Exception as e: continue @@ -748,7 +756,7 @@ def Answer_16777251_316(self, packet_vars, avps): except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) - self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777251_316] [ULR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) self.database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, 
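The generate_diameter_packet rewrite above seeds a placeholder Message-Length, assembles the header and AVPs as one hex string, then rewrites the 3-byte length field as len(packet_hex) // 2, since two hex digits encode one octet. A worked example of that arithmetic for a bare 20-octet header, with illustrative field values:

# Diameter header: version(1) + length(3) + flags(1) + command(3)
# + application-id(4) + hop-by-hop(4) + end-to-end(4) = 20 octets = 40 hex digits.
header = "01" + "000000" + "80" + "000101" + "00000000" + "00000000" + "00000000"
avps = ""  # no AVPs in this sketch
packet_hex = header + avps
packet_length = len(packet_hex) // 2                # 20 octets
length_field = format(packet_length, "x").zfill(6)
assert length_field == "000014"
packet_hex = packet_hex[0:2] + length_field + packet_hex[8:]  # rewrite the length field
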
serving_mme_realm=OriginRealm) @@ -1095,7 +1103,7 @@ def Answer_16777238_272(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) avp = '' #Initiate empty var AVP @@ -1241,7 +1249,7 @@ def Answer_16777216_300(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777216_300] [UAR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) try: self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) @@ -1318,7 +1326,6 @@ def Answer_16777216_300(self, packet_vars, avps): avp += self.generate_avp(297, 40, experimental_avp) #Expermental-Result response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - return response #3GPP Cx Server Assignment Answer @@ -1344,7 +1351,7 @@ def Answer_16777216_301(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.logTool.log(service='HSS', level='debug', message="Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777216_301] [SAR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) try: self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) @@ -1369,7 +1376,7 @@ def Answer_16777216_301(self, packet_vars, avps): #Cx-User-Data (XML) #This loads a Jinja XML template as the default iFC - templateLoader = jinja2.FileSystemLoader(searchpath="./") + templateLoader = jinja2.FileSystemLoader(searchpath="../") templateEnv = jinja2.Environment(loader=templateLoader) self.logTool.log(service='HSS', level='debug', message="Loading iFC from path " + str(ims_subscriber_details['ifc_path']), redisClient=self.redisMessaging) template = templateEnv.get_template(ims_subscriber_details['ifc_path']) diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index a8a60c7..fe1fc22 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -4,7 +4,7 @@ class DiameterAsync: - def __init__(self): + def __init__(self, redisMessaging, logTool): self.diameterCommandList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, {"commandCode": 272, "applicationId": 
16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, @@ -24,9 +24,12 @@ def __init__(self): {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, ] + self.redisMessaging = redisMessaging + self.logTool = logTool + #Generates rounding for calculating padding - async def myRoundAsync(self, n, base=4): + async def myRound(self, n, base=4): if(n > 0): return math.ceil(n/4.0) * 4 elif( n < 0): @@ -34,7 +37,7 @@ async def myRoundAsync(self, n, base=4): else: return 4 - async def getAvpDataAsync(self, avps, avp_code): + async def getAvpData(self, avps, avp_code): #Loops through list of dicts generated by the packet decoder, and returns the data for a specific AVP code in list (May be more than one AVP with same code but different data) misc_data = [] for keys in avps: @@ -42,10 +45,38 @@ async def getAvpDataAsync(self, avps, avp_code): misc_data.append(keys['misc_data']) return misc_data - async def decodeDiameterPacketAsync(self, data): + # async def decodeDiameterPacket(self, data): + # packet_vars = {} + # avps = [] + + # if type(data) is bytes: + # data = data.hex() + + # packet_vars['packet_version'] = data[0:2] + # packet_vars['length'] = int(data[2:8], 16) + # packet_vars['flags'] = data[8:10] + # packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8) + # packet_vars['command_code'] = int(data[10:16], 16) + # packet_vars['ApplicationId'] = int(data[16:24], 16) + # packet_vars['hop-by-hop-identifier'] = data[24:32] + # packet_vars['end-to-end-identifier'] = data[32:40] + + # avp_sum = data[40:] + + # avp_vars, remaining_avps = await(self.decodeAvpPacket(avp_sum)) + # avps.append(avp_vars) + + # while len(remaining_avps) > 0: + # avp_vars, remaining_avps = await(self.decodeAvpPacket(remaining_avps)) + # avps.append(avp_vars) + # else: + # pass + # return packet_vars, avps + + async def decodeDiameterPacket(self, data): packet_vars = {} avps = [] - + if type(data) is bytes: data = data.hex() @@ -58,77 +89,162 @@ async def decodeDiameterPacketAsync(self, data): packet_vars['hop-by-hop-identifier'] = data[24:32] packet_vars['end-to-end-identifier'] = data[32:40] - avp_sum = data[40:] + remaining_avps = data[40:] - avp_vars, remaining_avps = await(self.decodeAvpPacketAsync(avp_sum)) - avps.append(avp_vars) - while len(remaining_avps) > 0: - avp_vars, remaining_avps = await(self.decodeAvpPacketAsync(remaining_avps)) + avp_vars, remaining_avps = await self.decodeAvpPacket(remaining_avps) avps.append(avp_vars) else: pass + return packet_vars, avps - async def decodeAvpPacketAsync(self, data): + async def decodeAvpPacket(self, data): + avp_vars = {} + sub_avps = [] if len(data) <= 8: - #if length is less than 8 it is too short to be an AVP and is most likely the data from the last AVP being attempted to be parsed as another AVP raise ValueError("Length of data is too short to be valid AVP") - avp_vars = {} avp_vars['avp_code'] = int(data[0:8], 16) - + avp_vars['avp_flags'] = data[8:10] avp_vars['avp_length'] = int(data[10:16], 16) + avp_padded_length = (avp_vars['avp_length'] + 3) // 4 * 4 + if avp_vars['avp_flags'] == "c0": - #If c0 is present AVP is Vendor AVP avp_vars['vendor_id'] = int(data[16:24], 16) 
avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] else: - #if is not a vendor AVP avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] + sub_avp_data = avp_vars['misc_data'] + + while len(sub_avp_data) >= 16: + sub_avp_vars = {} + sub_avp_vars['avp_code'] = int(sub_avp_data[0:8], 16) + sub_avp_vars['avp_flags'] = sub_avp_data[8:10] + sub_avp_vars['avp_length'] = int(sub_avp_data[10:16], 16) + sub_avp_padded_length = (sub_avp_vars['avp_length'] + 3) // 4 * 4 + + if sub_avp_vars['avp_code'] > 9999: + break + + if '40' <= sub_avp_vars['avp_flags'] <= '7F': + sub_avp_vars['vendor_id'] = int(sub_avp_data[16:24], 16) + sub_avp_vars['misc_data'] = sub_avp_data[24:(24 + (sub_avp_vars['avp_length'] - 8) * 2)] + else: + sub_avp_vars['misc_data'] = sub_avp_data[16:(16 + (sub_avp_vars['avp_length'] - 8) * 2)] + + sub_avps.append(sub_avp_vars) + + sub_avp_data = sub_avp_data[(sub_avp_padded_length * 2):] + + avp_vars['sub_avps'] = sub_avps + if avp_vars['avp_length'] % 4 == 0: - #Multiple of 4 - No Padding needed avp_vars['padding'] = 0 else: - #Not multiple of 4 - Padding needed - rounded_value = await(self.myRoundAsync(avp_vars['avp_length'])) + rounded_value = await self.myRound(avp_vars['avp_length']) avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] + remaining_avps = data[(avp_padded_length * 2):] - #If body of avp_vars['misc_data'] contains AVPs, then decode each of them as a list of dicts like avp_vars['misc_data'] = [avp_vars, avp_vars] - try: - sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacketAsync(avp_vars['misc_data'])) - #Sanity check - If the avp code is greater than 9999 it's probably not an AVP after all... 
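The padding logic above rounds every AVP out to a 4-octet boundary, whether via (avp_length + 3) // 4 * 4 or via myRound. A worked example of the rounding, including the hex-string offset the decoder steps by:

def padded_length(avp_length: int) -> int:
    # Round up to the next multiple of 4 octets (RFC 6733 AVP alignment).
    return (avp_length + 3) // 4 * 4

assert padded_length(12) == 12   # already aligned, no padding octets
assert padded_length(13) == 16   # 3 padding octets follow the data
assert padded_length(14) == 16   # 2 padding octets follow the data
# In the hex-string representation the next AVP starts at padded_length * 2 digits.
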
- if int(sub_avp_vars['avp_code']) > 9999: - pass - else: - #If the decoded AVP is valid store it - avp_vars['misc_data'] = [] - avp_vars['misc_data'].append(sub_avp_vars) - #While there are more AVPs to be decoded, decode them: - while len(sub_remaining_avps) > 0: - sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacketAsync(sub_remaining_avps)) - avp_vars['misc_data'].append(sub_avp_vars) + return avp_vars, remaining_avps + + + + + + # async def decodeAvpPacket(self, data): + + # if len(data) <= 8: + # #if length is less than 8 it is too short to be an AVP and is most likely the data from the last AVP being attempted to be parsed as another AVP + # raise ValueError("Length of data is too short to be valid AVP") + + # avp_vars = {} + # avp_vars['avp_code'] = int(data[0:8], 16) + + # avp_vars['avp_flags'] = data[8:10] + # avp_vars['avp_length'] = int(data[10:16], 16) + # if avp_vars['avp_flags'] == "c0": + # #If c0 is present AVP is Vendor AVP + # avp_vars['vendor_id'] = int(data[16:24], 16) + # avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] + # else: + # #if is not a vendor AVP + # avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] + + # if avp_vars['avp_length'] % 4 == 0: + # #Multiple of 4 - No Padding needed + # avp_vars['padding'] = 0 + # else: + # #Not multiple of 4 - Padding needed + # rounded_value = await(self.myRound(avp_vars['avp_length'])) + # avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 + # avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] + + + # #If body of avp_vars['misc_data'] contains AVPs, then decode each of them as a list of dicts like avp_vars['misc_data'] = [avp_vars, avp_vars] + # try: + # sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacket(avp_vars['misc_data'])) + # #Sanity check - If the avp code is greater than 9999 it's probably not an AVP after all... 
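The while loop added above walks candidate sub-AVPs inside a Grouped AVP body using the fixed 8-octet AVP header. A standalone sketch of the same walk under RFC 6733 semantics; note that the Vendor-Specific ('V') bit is the top bit (0x80) of the flags octet, which this sketch tests directly instead of comparing flag strings:

def walk_sub_avps(body_hex: str) -> list:
    # Each AVP: code (4 octets), flags (1), length (3), then data,
    # with the next AVP starting on the padded 4-octet boundary.
    sub_avps = []
    while len(body_hex) >= 16:                   # 8-octet header = 16 hex digits
        code = int(body_hex[0:8], 16)
        flags = int(body_hex[8:10], 16)
        length = int(body_hex[10:16], 16)
        if length < 8 or length * 2 > len(body_hex):
            break                                # not a plausible AVP, stop walking
        data_start = 24 if flags & 0x80 else 16  # skip 4-octet Vendor-ID when V bit set
        sub_avps.append({"avp_code": code, "misc_data": body_hex[data_start:length * 2]})
        body_hex = body_hex[((length + 3) // 4 * 4) * 2:]
    return sub_avps

# A lone Result-Code AVP decodes to its 4-octet value:
assert walk_sub_avps("0000010c4000000c00000bb9") == [{"avp_code": 268, "misc_data": "00000bb9"}]
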
+ # if int(sub_avp_vars['avp_code']) > 9999: + # pass + # else: + # #If the decoded AVP is valid store it + # avp_vars['misc_data'] = [] + # avp_vars['misc_data'].append(sub_avp_vars) + # #While there are more AVPs to be decoded, decode them: + # while len(sub_remaining_avps) > 0: + # sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacket(sub_remaining_avps)) + # avp_vars['misc_data'].append(sub_avp_vars) - except Exception as e: - if str(e) == "invalid literal for int() with base 16: ''": - pass - elif str(e) == "Length of data is too short to be valid AVP": - pass - else: - #self.logger.warn("[Diameter] [decodeAvpPacketAsync] failed to decode sub-avp - error: " + str(e)) - pass + # except Exception as e: + # if str(e) == "invalid literal for int() with base 16: ''": + # pass + # elif str(e) == "Length of data is too short to be valid AVP": + # pass + # else: + # pass remaining_avps = data[(avp_vars['avp_length']*2)+avp_vars['padding']:] #returns remaining data in avp string back for processing again return avp_vars, remaining_avps + async def getPeerType(self, originHost: str) -> str: + try: + peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + + for peer in peerTypes: + if peer in originHost.lower(): + return peer + + except Exception as e: + return '' + + async def getConnectedPeersByType(self, peerType: str) -> list: + try: + peerType = peerType.lower() + peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + + if peerType not in peerTypes: + return [] + filteredConnectedPeers = [] + activePeers = await(self.redisMessaging.getValue(key="ActiveDiameterPeers")) + + for key, value in activePeers.items(): + if activePeers.get(key, {}).get('peerType', '') == 'pgw' and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': + filteredConnectedPeers.append(activePeers.get(key, {})) + + return filteredConnectedPeers + + except Exception as e: + return [] + - async def getDiameterMessageTypeAsync(self, binaryData: str) -> dict: - packet_vars, avps = await(self.decodeDiameterPacketAsync(binaryData)) + async def getDiameterMessageType(self, binaryData: str) -> dict: + packet_vars, avps = await(self.decodeDiameterPacket(binaryData)) response = {} for diameterApplication in self.diameterCommandList: @@ -137,19 +253,17 @@ async def getDiameterMessageTypeAsync(self, binaryData: str) -> dict: assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) response['inbound'] = diameterApplication["requestAcronym"] response['outbound'] = diameterApplication["responseAcronym"] - #self.logger.debug(f"[Diameter] [getDiameterMessageTypeAsync] Successfully got message type: {response}") except Exception as e: continue return response - async def generateDiameterResponseAsync(self, binaryData: str) -> str: - packet_vars, avps = await(self.decodeDiameterPacketAsync(binaryData)) + async def generateDiameterResponse(self, binaryData: str) -> str: + packet_vars, avps = await(self.decodeDiameterPacket(binaryData)) response = '' # Drop packet if it's a response packet: if packet_vars["flags_bin"][0:1] == "0": - #self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Got a Response, not a request - dropping it: {packet_vars}") return for diameterApplication in self.diameterCommandList: @@ -159,7 +273,6 @@ async def generateDiameterResponseAsync(self, binaryData: str) -> str: if 'flags' in diameterApplication: assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) response = diameterApplication["responseMethod"](packet_vars, avps) - 
#self.logger.debug(f"[Diameter] [generateDiameterResponseAsync] Successfully generated response: {response}") except Exception as e: continue diff --git a/lib/messaging.py b/lib/messaging.py index 2491f44..8e783dd 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -8,7 +8,7 @@ class RedisMessaging: """ def __init__(self, host: str='localhost', port: int=6379): - self.redisClient = Redis(host=host, port=port) + self.redisClient = Redis(unix_socket_path='/var/run/redis/redis-server.sock') pass def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index bc9c76d..a3eb297 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -9,7 +9,7 @@ class RedisMessagingAsync: """ def __init__(self, host: str='localhost', port: int=6379): - self.redisClient = redis.Redis(host=host, port=port) + self.redisClient = redis.Redis(unix_socket_path='/var/run/redis/redis-server.sock') async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: """ @@ -106,9 +106,12 @@ async def getNextQueue(self, pattern: str='*') -> str: Returns the next Queue (Key) in the list, asynchronously. """ try: - nextQueue = await(self.redisClient.keys(pattern)) - return nextQueue[0].decode() + result = [] + async for nextQueue in self.redisClient.scan_iter(match=pattern): + result.append(nextQueue) + return next(iter(result), '') if result else '' except Exception as e: + print(e) return '' async def deleteQueue(self, queue: str) -> bool: diff --git a/services/diameterService.py b/services/diameterService.py index 9d49d75..ca7eec6 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -7,6 +7,7 @@ from diameterAsync import DiameterAsync from banners import Banners from logtool import LogTool +import traceback class DiameterService: """ @@ -26,7 +27,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) self.banners = Banners() self.logTool = LogTool(config=self.config) - self.diameterLibrary = DiameterAsync() + self.diameterLibrary = DiameterAsync(redisMessaging=self.redisMessaging, logTool=self.logTool) self.activePeers = {} self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) @@ -36,9 +37,9 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb Asynchronously validates a given diameter inbound, and increments the 'Number of Diameter Inbounds' metric. 
""" try: - packetVars, avps = await(self.diameterLibrary.decodeDiameterPacketAsync(inboundData)) - messageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(inboundData)) - originHost = (await self.diameterLibrary.getAvpDataAsync(avps, 264))[0] + packetVars, avps = await(self.diameterLibrary.decodeDiameterPacket(inboundData)) + messageType = await(self.diameterLibrary.getDiameterMessageType(inboundData)) + originHost = (await self.diameterLibrary.getAvpData(avps, 264))[0] originHost = bytes.fromhex(originHost).decode("utf-8") peerType = await(self.diameterLibrary.getPeerType(originHost)) self.activePeers[f"{clientAddress}-{clientPort}"].update({'lastDwrTimestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S") if messageType['inbound'] == 'DWR' else self.activePeers[f"{clientAddress}-{clientPort}"]['lastDwrTimestamp'], @@ -55,9 +56,9 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb "type": "inbound"}, metricExpiry=60)) except Exception as e: - await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}\n{traceback.format_exc()}", redisClient=self.redisMessaging)) return False - return TruediameterHostname + return True async def handleActiveDiameterPeers(self): """ @@ -123,21 +124,27 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc if len(inboundData) > 0: await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) + if self.benchmarking: + diamteterValidationStartTime = time.perf_counter() if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)): - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, terminating connection.", redisClient=self.redisMessaging)) - return False - - diameterMessageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(binaryData=inboundData)) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.", redisClient=self.redisMessaging)) + await(asyncio.sleep(0)) + continue + if self.benchmarking: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to validate diameter request: {round(((time.perf_counter() - diamteterValidationStartTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + + + diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=inboundData)) diameterMessageType = diameterMessageType.get('inbound', '') inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{time.time_ns()}" inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()}) await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}", redisClient=self.redisMessaging)) - asyncio.ensure_future(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) + 
await(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) if self.benchmarking: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) - - + await(asyncio.sleep(0)) + except Exception as e: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}", redisClient=self.redisMessaging)) return False @@ -155,31 +162,30 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s if writer.transport.is_closing(): return False - pendingOutboundQueues = await(self.redisMessaging.getQueues(pattern='diameter-outbound*')) - if not len(pendingOutboundQueues) > 0: + pendingOutboundQueue = await(self.redisMessaging.getNextQueue(pattern=f'diameter-outbound-{clientAddress.replace(".", "*")}-{clientPort}-*')) + if not len(pendingOutboundQueue) > 0: await(asyncio.sleep(0)) continue - - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queues: {pendingOutboundQueues}", redisClient=self.redisMessaging)) - for outboundQueue in pendingOutboundQueues: - outboundQueueSplit = str(outboundQueue).split('-') - queuedMessageType = outboundQueueSplit[1] - diameterOutboundHost = outboundQueueSplit[2] - diameterOutboundPort = outboundQueueSplit[3] - - if str(diameterOutboundHost) == str(clientAddress) and str(diameterOutboundPort) == str(clientPort) and queuedMessageType == 'outbound': - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {outboundQueue} to host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) - diameterOutbound = json.loads(await(self.redisMessaging.getMessage(queue=outboundQueue))) - diameterOutboundBinary = bytes.fromhex(next(iter(diameterOutbound.values()))) - diameterMessageType = await(self.diameterLibrary.getDiameterMessageTypeAsync(binaryData=diameterOutboundBinary)) - diameterMessageType = diameterMessageType.get('outbound', '') - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.", redisClient=self.redisMessaging)) - writer.write(diameterOutboundBinary) - await(writer.drain()) - await(asyncio.sleep(0)) - if self.benchmarking: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) - + pendingOutboundQueue = pendingOutboundQueue.decode() + + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queue: {pendingOutboundQueue}", redisClient=self.redisMessaging)) + outboundQueueSplit = str(pendingOutboundQueue).split('-') + queuedMessageType = outboundQueueSplit[1] + diameterOutboundHost = outboundQueueSplit[2] + diameterOutboundPort = outboundQueueSplit[3] + + if str(diameterOutboundHost) == str(clientAddress) and str(diameterOutboundPort) == 
str(clientPort) and queuedMessageType == 'outbound': + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {pendingOutboundQueue} to host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) + diameterOutbound = json.loads(await(self.redisMessaging.getMessage(queue=pendingOutboundQueue))) + diameterOutboundBinary = bytes.fromhex(next(iter(diameterOutbound.values()))) + diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=diameterOutboundBinary)) + diameterMessageType = diameterMessageType.get('outbound', '') + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to {clientAddress} on {clientPort}.", redisClient=self.redisMessaging)) + writer.write(diameterOutboundBinary) + await(writer.drain()) + await(asyncio.sleep(0)) + if self.benchmarking: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) except Exception: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.", redisClient=self.redisMessaging)) diff --git a/services/hssService.py b/services/hssService.py index cc4c0b7..d8910ac 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -24,7 +24,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.originHost = self.config.get('hss', {}).get('OriginHost', f'hss01') self.productName = self.config.get('hss', {}).get('ProductName', f'PyHSS') self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) - self.diameterLibrary = Diameter(redisMessaging=self.redisMessaging, logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) + self.diameterLibrary = Diameter(logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) From a9d7ccdd67318bad11217a3a64485508e8fbd62f Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 1 Sep 2023 07:53:30 +1000 Subject: [PATCH 11/43] Update requirements.txt --- requirements.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 548af43..3fecc47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -construct==2.10.68 +aiohttp==3.8.5 Flask==2.2.3 flask_restx==1.1.0 Jinja2==3.1.2 @@ -9,11 +9,10 @@ pymongo==4.3.3 pysctp==0.7.2 pysnmp==4.4.12 PyYAML==6.0 -redis==4.5.4 -Requests==2.28.2 +redis==5.0.0 +Requests==2.31.0 SQLAlchemy==2.0.9 -sqlalchemy_utils -systemd-python==234 +SQLAlchemy_Utils==0.41.1 Werkzeug==2.2.3 mysqlclient prometheus_flask_exporter \ No newline at end of file From 276c94d3334d53fa31e8ca1d34dc08f26629e415 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Mon, 11 Sep 2023 17:43:56 +1000 Subject: [PATCH 12/43] Fix Redis hang, update CLR --- lib/diameter.py | 79 ++++++++++++++++++++++++++++++++++--- lib/diameterAsync.py | 6 ++- lib/logtool.py | 12 +++++- 
services/apiService.py | 19 +++++++++ services/diameterService.py | 58 ++++++++++++++------------- services/georedService.py | 75 ++++++++++++++++++----------------- services/hssService.py | 1 - 7 files changed, 174 insertions(+), 76 deletions(-) diff --git a/lib/diameter.py b/lib/diameter.py index af578ab..7fad040 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -10,6 +10,8 @@ from database import Database from messaging import RedisMessaging import yaml +import json +import time class Diameter: @@ -28,6 +30,7 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 else: self.redisMessaging=RedisMessaging() self.database = Database(logTool=logTool) + self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) self.logTool.log(service='HSS', level='info', message=f"Initialized Diameter Library", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='info', message=f"Origin Host: {str(originHost)}", redisClient=self.redisMessaging) @@ -55,8 +58,8 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 ] self.diameterRequestList = [ - {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, - {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, + {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer", "validPeerTypes": ['MME']}, + {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer", "validPeerTypes": ['MME']}, {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, ] @@ -425,7 +428,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: activePeers = self.redisMessaging.getValue(key="ActiveDiameterPeers") for key, value in activePeers.items(): - if activePeers.get(key, {}).get('peerType', '') == 'pgw' and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': + if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': filteredConnectedPeers.append(activePeers.get(key, {})) return filteredConnectedPeers @@ -433,6 +436,18 @@ def getConnectedPeersByType(self, peerType: str) -> list: except Exception as e: return [] + def getPeerByHostname(self, hostname: str) -> dict: + try: + hostname = hostname.lower() + activePeers = self.redisMessaging.getValue(key="ActiveDiameterPeers") + + for key, value in activePeers.items(): + if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': + return(activePeers.get(key, 
{})) + + except Exception as e: + return {} + def getDiameterMessageType(self, binaryData: str) -> dict: packet_vars, avps = self.decode_diameter_packet(binaryData) response = {} @@ -448,7 +463,7 @@ def getDiameterMessageType(self, binaryData: str) -> dict: continue return response - def generateDiameterRequest(self, requestType: str, **kwargs) -> str: + def generateDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: try: request = '' self.logTool.log(service='HSS', level='debug', message=f"Generating a diameter outbound request", redisClient=self.redisMessaging) @@ -456,10 +471,17 @@ def generateDiameterRequest(self, requestType: str, **kwargs) -> str: for diameterApplication in self.diameterRequestList: try: assert(requestType == diameterApplication["requestAcronym"]) - request = diameterApplication["requestMethod"](kwargs) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] Successfully generated request: {request}", redisClient=self.redisMessaging) except Exception as e: continue + connectedPeer = self.getPeerByHostname(hostname=hostname) + peerIp = connectedPeer['ipAddress'] + peerPort = connectedPeer['port'] + request = diameterApplication["requestMethod"](kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] Successfully generated request: {request}", redisClient=self.redisMessaging) + outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{time.time_ns()}" + outboundMessage = {'diameter-outbound': json.dumps(request)} + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) return request except Exception as e: return '' @@ -727,6 +749,21 @@ def Answer_16777251_316(self, packet_vars, avps): try: subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details self.logTool.log(service='HSS', level='debug', message="Got back subscriber_details: " + str(subscriber_details), redisClient=self.redisMessaging) + + if subscriber_details['enabled'] == 0: + self.logTool.log(service='HSS', level='debug', message=f"Subscriber {imsi} is disabled", redisClient=self.redisMessaging) + + #Experimental Result AVP(Response Code for Failure) + avp_experimental_result = '' + avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID + avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001) + avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) + + avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State + self.logTool.log(service='HSS', level='debug', message=f"Successfully Generated ULA for disabled Subscriber: {imsi}", redisClient=self.redisMessaging) + response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) + return response + except ValueError as e: self.logTool.log(service='HSS', level='error', message="failed to get data back from database for imsi " + str(imsi), redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='error', message="Error is " + str(e), redisClient=self.redisMessaging) @@ -922,6 +959,36 @@
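The disabled-subscriber handling added to the ULA (316) path above is mirrored in the AIA (318) hunk that follows: both signal DIAMETER_ERROR_USER_UNKNOWN (5001) through the same Experimental-Result grouped AVP. As a minimal illustrative sketch (not part of the patch), assuming the generate_avp, generate_vendor_avp and int_to_hex helpers defined on this class, the construction could be factored into a hypothetical shared method:

def build_experimental_result(self, result_code, avps, packet_vars):
    # Experimental-Result (297) groups Vendor-Id (266, 3GPP vendor 10415)
    # with Experimental-Result-Code (298), e.g. 5001 = DIAMETER_ERROR_USER_UNKNOWN.
    experimental_result = self.generate_vendor_avp(266, 40, 10415, '')
    experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4), avps=avps, packet_vars=packet_vars)
    return self.generate_avp(297, 40, experimental_result)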
def Answer_16777251_318(self, packet_vars, avps): try: subscriber_details = self.database.Get_Subscriber(imsi=imsi) #Get subscriber details + if subscriber_details['enabled'] == 0: + self.logTool.log(service='HSS', level='debug', message=f"Subscriber {imsi} is disabled", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #Result Code + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777251, + "diameter_cmd_code": 318, + "event": "Disabled User", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + + #Experimental Result AVP(Response Code for Failure) + avp_experimental_result = '' + avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID + avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(5001, 4), avps=avps, packet_vars=packet_vars) #AVP Experimental-Result-Code: DIAMETER_ERROR_USER_UNKNOWN (5001) + avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) + + avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State + avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) + response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + self.logTool.log(service='HSS', level='debug', message=f"Successfully Generated AIA for disabled Subscriber: {imsi}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"{response}", redisClient=self.redisMessaging) + return response except ValueError as e: self.logTool.log(service='HSS', level='info', message="Minor error getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='info', message=e, redisClient=self.redisMessaging) diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index fe1fc22..585f864 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -1,10 +1,12 @@ #Diameter Packet Decoder / Encoder & Tools import math import asyncio +from messagingAsync import RedisMessagingAsync + class DiameterAsync: - def __init__(self, redisMessaging, logTool): + def __init__(self, logTool): self.diameterCommandList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilities Exchange Request", "responseName": "Capabilities Exchange Answer"}, {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, @@ -24,7 +26,7 @@ def __init__(self, redisMessaging, logTool): {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622,
"failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, ] - self.redisMessaging = redisMessaging + self.redisMessaging = RedisMessagingAsync() self.logTool = logTool diff --git a/lib/logtool.py b/lib/logtool.py index 668a30e..b3528f4 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -4,6 +4,8 @@ from datetime import datetime sys.path.append(os.path.realpath('../')) import asyncio +from messagingAsync import RedisMessagingAsync +from messaging import RedisMessaging class TimestampFilter (logging.Filter): """ @@ -31,11 +33,15 @@ def __init__(self, config: dict): 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, } self.logLevel = config.get('logging', {}).get('level', 'INFO') + self.redisMessagingAsync = RedisMessagingAsync() + self.redisMessaging = RedisMessaging() - async def logAsync(self, service: str, level: str, message: str, redisClient) -> bool: + async def logAsync(self, service: str, level: str, message: str, redisClient=None) -> bool: """ Tests loglevel, prints to console and queues a log message to an asynchronous redis messaging client. """ + if redisClient == None: + redisClient = self.redisMessagingAsync configLogLevelVerbosity = self.logLevels.get(self.logLevel.upper(), {}).get('verbosity', 4) messageLogLevelVerbosity = self.logLevels.get(level.upper(), {}).get('verbosity', 4) if not messageLogLevelVerbosity <= configLogLevelVerbosity: @@ -46,10 +52,12 @@ async def logAsync(self, service: str, level: str, message: str, redisClient) -> asyncio.ensure_future(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)) return True - def log(self, service: str, level: str, message: str, redisClient) -> bool: + def log(self, service: str, level: str, message: str, redisClient=None) -> bool: """ Tests loglevel, prints to console and queues a log message to a synchronous redis messaging client. 
""" + if redisClient == None: + redisClient = self.redisMessaging configLogLevelVerbosity = self.logLevels.get(self.logLevel.upper(), {}).get('verbosity', 4) messageLogLevelVerbosity = self.logLevels.get(level.upper(), {}).get('verbosity', 4) if not messageLogLevelVerbosity <= configLogLevelVerbosity: diff --git a/services/apiService.py b/services/apiService.py index 00f9892..59d1b41 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -448,6 +448,24 @@ def patch(self, subscriber_id): operation_id = args.get('operation_id', None) data = databaseClient.UpdateObj(SUBSCRIBER, json_data, subscriber_id, False, operation_id) + #If the "enabled" flag on the subscriber is now disabled, trigger a CLR + if 'enabled' in json_data and json_data['enabled'] == False: + print("Subscriber is now disabled, checking to see if we need to trigger a CLR") + #See if we have a serving MME set + try: + assert(json_data['serving_mme']) + print("Serving MME set - Sending CLR") + diameterClient.generateDiameterRequest( + requestType='CLR', + imsi=json_data['imsi'], + DestinationHost=json_data['serving_mme'], + DestinationRealm=json_data['serving_mme_realm'], + CancellationType=1 + ) + print("Sent CLR via Peer " + str(json_data['serving_mme'])) + except: + print("No serving MME set - Not sending CLR") + print("Updated object") print(data) return data, 200 @@ -1414,6 +1432,7 @@ def put(self, imsi): if 'DestinationHost' not in json_data: json_data['DestinationHost'] = None diam_hex = diameterClient.sendDiameterRequest( + requestType='CLR', imsi=imsi, DestinationHost=json_data['DestinationHost'], DestinationRealm=json_data['DestinationRealm'], diff --git a/services/diameterService.py b/services/diameterService.py index ca7eec6..e22ea39 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -24,10 +24,12 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): print(f"[Diameter] [__init__] Fatal Error - config.yaml not found, exiting.") quit() - self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisReaderMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisWriterMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisPeerMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) self.banners = Banners() self.logTool = LogTool(config=self.config) - self.diameterLibrary = DiameterAsync(redisMessaging=self.redisMessaging, logTool=self.logTool) + self.diameterLibrary = DiameterAsync(logTool=self.logTool) self.activePeers = {} self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) @@ -46,7 +48,7 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb 'diameterHostname': originHost, 'peerType': peerType, }) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_inbound_count', + await(self.redisReaderMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_inbound_count', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Diameter Inbounds', metricLabels={ @@ -56,7 +58,7 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb "type": "inbound"}, metricExpiry=60)) except Exception as e: - await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}\n{traceback.format_exc()}", 
redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}\n{traceback.format_exc()}")) return False return True @@ -81,12 +83,12 @@ async def handleActiveDiameterPeers(self): stalePeers.append(key) if len(stalePeers) > 0: - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [handleActiveDiameterPeers] Pruning disconnected peers: {stalePeers}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [handleActiveDiameterPeers] Pruning disconnected peers: {stalePeers}")) for key in stalePeers: del self.activePeers[key] await(self.logActivePeers()) - await(self.redisMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400)) + await(self.redisPeerMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400)) await(asyncio.sleep(1)) except Exception as e: @@ -101,14 +103,14 @@ async def logActivePeers(self): activePeers = self.activePeers if not len(activePeers) > 0: activePeers = '' - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActivePeers] {len(self.activePeers)} Active Peers {activePeers}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActivePeers] {len(self.activePeers)} Active Peers {activePeers}")) async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ Reads and parses incoming data from a connected client. Validated diameter messages are sent to the redis queue for processing. Terminates the connection if diameter traffic is not received, or if the client disconnects. 
""" - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}")) while True: try: @@ -118,20 +120,20 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc startTime = time.perf_counter() if reader.at_eof(): - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.")) return False if len(inboundData) > 0: - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}")) if self.benchmarking: diamteterValidationStartTime = time.perf_counter() if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)): - await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.")) await(asyncio.sleep(0)) continue if self.benchmarking: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to validate diameter request: {round(((time.perf_counter() - diamteterValidationStartTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to validate diameter request: {round(((time.perf_counter() - diamteterValidationStartTime)*1000), 3)} ms")) diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=inboundData)) @@ -139,21 +141,21 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{time.time_ns()}" inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()}) - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}", redisClient=self.redisMessaging)) - await(self.redisMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}")) + await(self.redisReaderMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) 
if self.benchmarking: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) await(asyncio.sleep(0)) except Exception as e: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}")) return False async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ Continually polls the Redis queue for outbound messages. Received messages from the queue are validated against the connected client, and sent. """ - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}")) while True: try: if self.benchmarking: @@ -162,33 +164,33 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s if writer.transport.is_closing(): return False - pendingOutboundQueue = await(self.redisMessaging.getNextQueue(pattern=f'diameter-outbound-{clientAddress.replace(".", "*")}-{clientPort}-*')) + pendingOutboundQueue = await(self.redisWriterMessaging.getNextQueue(pattern=f'diameter-outbound-{clientAddress.replace(".", "*")}-{clientPort}-*')) if not len(pendingOutboundQueue) > 0: await(asyncio.sleep(0)) continue pendingOutboundQueue = pendingOutboundQueue.decode() - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queue: {pendingOutboundQueue}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queue: {pendingOutboundQueue}")) outboundQueueSplit = str(pendingOutboundQueue).split('-') queuedMessageType = outboundQueueSplit[1] diameterOutboundHost = outboundQueueSplit[2] diameterOutboundPort = outboundQueueSplit[3] if str(diameterOutboundHost) == str(clientAddress) and str(diameterOutboundPort) == str(clientPort) and queuedMessageType == 'outbound': - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {pendingOutboundQueue} to host {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) - diameterOutbound = json.loads(await(self.redisMessaging.getMessage(queue=pendingOutboundQueue))) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {pendingOutboundQueue} to host {clientAddress} on port {clientPort}")) + diameterOutbound = 
json.loads(await(self.redisWriterMessaging.getMessage(queue=pendingOutboundQueue))) diameterOutboundBinary = bytes.fromhex(next(iter(diameterOutbound.values()))) diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=diameterOutboundBinary)) diameterMessageType = diameterMessageType.get('outbound', '') - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to {clientAddress} on {clientPort}.", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to {clientAddress} on {clientPort}.")) writer.write(diameterOutboundBinary) await(writer.drain()) await(asyncio.sleep(0)) if self.benchmarking: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) except Exception: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.")) return False await(asyncio.sleep(0)) @@ -200,7 +202,7 @@ async def handleConnection(self, reader, writer): try: coroutineUuid = str(uuid.uuid4()) (clientAddress, clientPort) = writer.get_extra_info('peername') - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}")) if f"{clientAddress}-{clientPort}" not in self.activePeers: self.activePeers[f"{clientAddress}-{clientPort}"] = { "connectTimestamp": '', @@ -246,12 +248,12 @@ async def handleConnection(self, reader, writer): "connectionStatus": 'disconnected', "disconnectTimestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), }) - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}.", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}.")) await(self.logActivePeers()) return except Exception as e: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: {e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [handleConnection] [{coroutineUuid}] Unhandled exception in diameterService.handleConnection: 
{e}")) return async def startServer(self, host: str=None, port: int=None, type: str=None): @@ -276,7 +278,7 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): else: return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) - await(self.logTool.logAsync(service='Diameter', level='info', message=f"{self.banners.diameterService()}\n[Diameter] Serving on {servingAddresses}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Diameter', level='info', message=f"{self.banners.diameterService()}\n[Diameter] Serving on {servingAddresses}")) handleActiveDiameterPeerTask = asyncio.create_task(self.handleActiveDiameterPeers()) async with server: diff --git a/services/georedService.py b/services/georedService.py index 39db3e4..81d32ec 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -21,7 +21,8 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): quit() self.logTool = LogTool(self.config) self.banners = Banners() - self.redisMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisGeoredMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisWebhookMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) self.georedPeers = self.config.get('geored', {}).get('endpoints', []) self.webhookPeers = self.config.get('webhooks', {}).get('endpoints', []) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) @@ -65,9 +66,9 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr responseStatusCode = response.status if 200 <= responseStatusCode <= 299: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendGeored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendGeored] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}")) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -78,7 +79,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr metricExpiry=60)) break else: - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -89,9 +90,9 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr metricExpiry=60)) except aiohttp.ClientConnectionError as e: error_message = str(e) - await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) if "Name or service not known" in error_message: - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -101,7 +102,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr "error": "No matching DNS entry found"}, metricExpiry=60)) else: - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -111,8 +112,8 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr "error": "Connection Refused"}, metricExpiry=60)) except aiohttp.ServerTimeoutError: - await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendGeored] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -122,8 +123,8 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr "error": "Timeout"}, metricExpiry=60)) except Exception as e: - await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', + await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendGeored] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}")) + asyncio.ensure_future(self.redisGeoredMessaging.sendMetric(serviceName='geored', metricName='prom_http_geored', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -133,7 +134,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr "error": e}, metricExpiry=60)) if self.benchmarking: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendGeored] Time taken to send individual geored request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendGeored] Time taken to send individual geored request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) return True @@ -167,9 +168,9 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h responseStatusCode = response.status if 200 <= responseStatusCode <= 299: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendWebhook] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [sendWebhook] Operation {operation} executed successfully on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}")) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes', metricLabels={ @@ -180,7 +181,7 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h metricExpiry=60)) break else: - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -191,9 +192,9 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h metricExpiry=60)) except aiohttp.ClientConnectionError as e: error_message = str(e) - await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} failed on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}")) if "Name or service not known" in error_message: - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -203,7 +204,7 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h "error": "No matching DNS entry found"}, metricExpiry=60)) else: - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -213,8 +214,8 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h "error": "Connection Refused"}, metricExpiry=60)) except aiohttp.ServerTimeoutError: - await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + await(self.logTool.logAsync(service='Geored', level='warning', message=f"[Geored] [sendWebhook] Operation {operation} timed out on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}")) + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -224,8 +225,8 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h "error": "Timeout"}, metricExpiry=60)) except Exception as e: - await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendWebhook] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. Error Message: {e}", redisClient=self.redisMessaging)) - asyncio.ensure_future(self.redisMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', + await(self.logTool.logAsync(service='Geored', level='error', message=f"[Geored] [sendWebhook] Operation {operation} encountered unknown error on {url}, with body: ({body}) and transactionId {transactionId}. Response code: {responseStatusCode}. 
Error Message: {e}")) + asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ @@ -235,7 +236,7 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h "error": e}, metricExpiry=60)) if self.benchmarking: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendWebhook] Time taken to send individual webhook request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendWebhook] Time taken to send individual webhook request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) return True @@ -248,11 +249,11 @@ async def handleGeoredQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - georedQueue = await(self.redisMessaging.getNextQueue(pattern='geored-*')) - georedMessage = await(self.redisMessaging.getMessage(queue=georedQueue)) + georedQueue = await(self.redisGeoredMessaging.getNextQueue(pattern='geored-*')) + georedMessage = await(self.redisGeoredMessaging.getMessage(queue=georedQueue)) if len(georedMessage) > 0: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Queue: {georedQueue}", redisClient=self.redisMessaging)) - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Queue: {georedQueue}")) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}")) georedDict = json.loads(georedMessage) georedOperation = georedDict['operation'] @@ -263,12 +264,12 @@ async def handleGeoredQueue(self): georedTasks.append(self.sendGeored(asyncSession=session, url=remotePeer+'/geored/', operation=georedOperation, body=georedBody)) await asyncio.gather(*georedTasks) if self.benchmarking: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) await(asyncio.sleep(0)) except Exception as e: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}")) await(asyncio.sleep(0)) continue @@ -281,11 +282,11 @@ async def handleWebhookQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - webhookQueue = await(self.redisMessaging.getNextQueue(pattern='webhook-*')) - webhookMessage = await(self.redisMessaging.getMessage(queue=webhookQueue)) + webhookQueue = await(self.redisWebhookMessaging.getNextQueue(pattern='webhook-*')) + webhookMessage = await(self.redisWebhookMessaging.getMessage(queue=webhookQueue)) if len(webhookMessage) > 0: - 
await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Queue: {webhookQueue}", redisClient=self.redisMessaging)) - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Queue: {webhookQueue}")) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) webhookDict = json.loads(webhookMessage) webhookHeaders = webhookDict['headers'] @@ -297,12 +298,12 @@ async def handleWebhookQueue(self): webhookTasks.append(self.sendWebhook(asyncSession=session, url=remotePeer, operation=webhookOperation, body=webhookBody, headers=webhookHeaders)) await asyncio.gather(*webhookTasks) if self.benchmarking: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) await(asyncio.sleep(0)) except Exception as e: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Error handling webhook queue: {e}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Error handling webhook queue: {e}")) await(asyncio.sleep(0)) continue @@ -310,7 +311,7 @@ async def startService(self): """ Performs sanity checks on configuration and starts the geored and webhook tasks, when enabled. 
""" - await(self.logTool.logAsync(service='Geored', level='info', message=f"{self.banners.georedService()}", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='info', message=f"{self.banners.georedService()}")) while True: georedEnabled = self.config.get('geored', {}).get('enabled', False) @@ -323,7 +324,7 @@ async def startService(self): webhooksEnabled = False if not georedEnabled and not webhooksEnabled: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [startService] Geored and Webhook services both disabled or missing peers, exiting.", redisClient=self.redisMessaging)) + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [startService] Geored and Webhook services both disabled or missing peers, exiting.")) sys.exit() activeTasks = [] diff --git a/services/hssService.py b/services/hssService.py index d8910ac..23cf436 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -73,7 +73,6 @@ def handleQueue(self): if self.benchmarking: self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) - except Exception as e: continue From 011ad94d2961db0ab329065828cba0d2416872eb Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 13 Sep 2023 13:14:21 +1000 Subject: [PATCH 13/43] Update diameter.py --- lib/diameter.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/lib/diameter.py b/lib/diameter.py index 7fad040..4a4cfa6 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -17,7 +17,7 @@ class Diameter: def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc999.mcc999.3gppnetwork.org", productName: str="PyHSS", mcc: str="999", mnc: str="999", redisMessaging=None): with open("../config.yaml", 'r') as stream: - self.yaml_config = (yaml.safe_load(stream)) + self.config = (yaml.safe_load(stream)) self.OriginHost = self.string_to_hex(originHost) self.OriginRealm = self.string_to_hex(originRealm) @@ -663,7 +663,7 @@ def Answer_257(self, packet_vars, avps): for avps_to_check in avps: #Only include AVP 278 (Origin State) if inital request included it if avps_to_check['avp_code'] == 278: avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incrimented (Handled by AVP_278_Origin_State_Incriment)) - for host in self.yaml_config['hss']['bind_ip']: #Loop through all IPs from Config and add to response + for host in self.config['hss']['bind_ip']: #Loop through all IPs from Config and add to response avp += self.generate_avp(257, 40, self.ip_to_hex(host)) #Host-IP-Address (For this to work on Linux this is the IP defined in the hostsfile for localhost) avp += self.generate_avp(266, 40, "00000000") #Vendor-Id avp += self.generate_avp(269, "00", self.ProductName) #Product-Name @@ -792,7 +792,7 @@ def Answer_16777251_316(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777251_316] [ULR] Remote Peer is " + str(remote_peer), 
redisClient=self.redisMessaging) self.database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm) @@ -1171,7 +1171,7 @@ def Answer_16777238_272(self, packet_vars, avps): except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) - remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1213,7 +1213,7 @@ def Answer_16777238_272(self, packet_vars, avps): ue_ip = 'Failed to Decode / Get UE IP' #Store PGW location into Database - remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) #Supported-Features(628) (Gx feature list) @@ -1376,9 +1376,9 @@ def Answer_16777216_300(self, packet_vars, avps): avp += self.generate_avp(297, 40, experimental_avp) #Experimental-Result else: self.logTool.log(service='HSS', level='debug', message="No SCSCF Assigned from DB", redisClient=self.redisMessaging) - if 'scscf_pool' in self.yaml_config['hss']: + if 'scscf_pool' in self.config['hss']: try: - scscf = random.choice(self.yaml_config['hss']['scscf_pool']) + scscf = random.choice(self.config['hss']['scscf_pool']) self.logTool.log(service='HSS', level='debug', message="Randomly picked SCSCF address " + str(scscf) + " from pool", 
redisClient=self.redisMessaging) avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: @@ -1468,7 +1468,7 @@ def Answer_16777216_301(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="Subscriber is served by S-CSCF " + str(ServingCSCF), redisClient=self.redisMessaging) if (Server_Assignment_Type == 1) or (Server_Assignment_Type == 2): self.logTool.log(service='HSS', level='debug', message="SAR is Register / Re-Register", redisClient=self.redisMessaging) - remote_peer = remote_peer + ";" + str(self.yaml_config['hss']['OriginHost']) + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) self.database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer) else: self.logTool.log(service='HSS', level='debug', message="SAR is not Register", redisClient=self.redisMessaging) @@ -1501,9 +1501,9 @@ def Answer_16777216_302(self, packet_vars, avps): avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(str(ims_subscriber_details['scscf']))),'ascii')) else: self.logTool.log(service='HSS', level='debug', message="No SCSCF assigned - Using SCSCF Pool", redisClient=self.redisMessaging) - if 'scscf_pool' in self.yaml_config['hss']: + if 'scscf_pool' in self.config['hss']: try: - scscf = random.choice(self.yaml_config['hss']['scscf_pool']) + scscf = random.choice(self.config['hss']['scscf_pool']) self.logTool.log(service='HSS', level='debug', message="Randomly picked SCSCF address " + str(scscf) + " from pool", 
+ self.config['hss']['OriginHost'])),'ascii')) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000000") #Vendor-Specific-Application-ID for Cx @@ -2716,7 +2716,7 @@ def Request_16777217_307(self, msisdn): #This loads a Jinja XML template containing the Sh-User-Data templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) - sh_userdata_template = self.yaml_config['hss']['Default_Sh_UserData'] + sh_userdata_template = self.config['hss']['Default_Sh_UserData'] self.logTool.log(service='HSS', level='info', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) template = templateEnv.get_template(sh_userdata_template) #These variables are passed to the template for use From f0c811f83495bea3423deb4ebd66430aa1fb7487 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Thu, 14 Sep 2023 10:22:12 +1000 Subject: [PATCH 14/43] Add geored peers and webhooks to api --- services/apiService.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/services/apiService.py b/services/apiService.py index 59d1b41..0651115 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -1421,6 +1421,36 @@ def get(self): response_json = {'result': 'Failed', 'Reason' : "Unable to return Geored Schema: " + str(E)} return response_json +@ns_geored.route('/peers') +class PyHSS_Geored_Peers(Resource): + def get(self): + '''Return the configured geored peers''' + try: + georedEnabled = config.get('geored', {}).get('enabled', False) + if not georedEnabled: + return {'result': 'Failed', 'Reason' : "Geored not enabled"} + georedPeers = config.get('geored', {}).get('endpoints', []) + return {'peers': georedPeers}, 200 + except Exception as E: + print("Exception when returning geored peers: " + str(E)) + response_json = {'result': 'Failed', 'Reason' : "Unable to return Geored peers: " + str(E)} + return response_json + +@ns_geored.route('/webhooks') +class PyHSS_Geored_Webhooks(Resource): + def get(self): + '''Return the configured geored webhooks''' + try: + georedEnabled = config.get('webhooks', {}).get('enabled', False) + if not georedEnabled: + return {'result': 'Failed', 'Reason' : "Webhooks not enabled"} + georedWebhooks = config.get('webhooks', {}).get('endpoints', []) + return {'endpoints': georedWebhooks}, 200 + except Exception as E: + print("Exception when returning geored webhooks: " + str(E)) + response_json = {'result': 'Failed', 'Reason' : "Unable to return Geored webhooks: " + str(E)} + return response_json + @ns_push.route('/clr/') class PyHSS_Push_CLR(Resource): @ns_push.expect(Push_CLR_Model) From ac67d0eae8dbb304a9a5984ee6213318722e6cbc Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 19 Sep 2023 17:56:51 +1000 Subject: [PATCH 15/43] Add metrics.py --- lib/database.py | 4 ++-- lib/diameter.py | 13 +++++++------ lib/metrics.py | 41 +++++++++++++++++++++++++++++++++++++++ services/georedService.py | 2 +- 4 files changed, 51 insertions(+), 9 deletions(-) create mode 100644 lib/metrics.py diff --git a/lib/database.py b/lib/database.py index 64ee118..3a5c73f 100755 --- a/lib/database.py +++ b/lib/database.py @@ -320,14 +320,14 @@ def load_IMEI_database_into_Redis(self): model = result[2].lstrip() if count == 0: self.logTool.log(service='Database', level='info', 
message="Checking to see if entries are already present...", redisClient=self.redisMessaging) - redis_imei_result = self.redisMessaging.getMessage(key=str(tac_prefix)) + redis_imei_result = self.redisMessaging.getMessage(queue=str(tac_prefix)) if len(redis_imei_result) != 0: self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging) break else: self.logTool.log(service='Database', level='info', message="No data loaded into Redis, proceeding to load...", redisClient=self.redisMessaging) imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} - self.redisMessaging.sendMessage(key=str(tac_prefix), value_dict=imei_result) + self.redisMessaging.sendMessage(queue=str(tac_prefix), message=imei_result) count = count +1 self.logTool.log(service='Database', level='info', message="Loaded " + str(count) + " IMEI TAC entries into Redis", redisClient=self.redisMessaging) except Exception as E: diff --git a/lib/diameter.py b/lib/diameter.py index 4a4cfa6..df3fffd 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -32,6 +32,9 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.database = Database(logTool=logTool) self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) + self.templateLoader = jinja2.FileSystemLoader(searchpath="../") + self.templateEnv = jinja2.Environment(loader=self.templateLoader) + self.logTool.log(service='HSS', level='info', message=f"Initialized Diameter Library", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='info', message=f"Origin Host: {str(originHost)}", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='info', message=f"Realm: {str(originRealm)}", redisClient=self.redisMessaging) @@ -40,7 +43,7 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.diameterResponseList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, - {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication 
Information Answer"}, @@ -1147,7 +1150,7 @@ def Answer_16777251_323(self, packet_vars, avps): SupportedFeatures += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8)) #Auth-Application-ID Relay avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP response = self.generate_diameter_packet("01", "40", 323, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - self.logTool.log(service='HSS', level='debug', message="Successfully Generated PUA", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Successfully Generated NOA", redisClient=self.redisMessaging) return response #3GPP Gx Credit Control Answer @@ -1725,7 +1728,7 @@ def Answer_16777217_306(self, packet_vars, avps): self.logTool.log(service='HSS', level='error', message="No MSISDN", redisClient=self.redisMessaging) try: username = self.get_avp_data(avps, 601)[0] - except: + except Exception as e: self.logTool.log(service='HSS', level='error', message="No Username", redisClient=self.redisMessaging) if msisdn is not None: @@ -1768,11 +1771,9 @@ def Answer_16777217_306(self, packet_vars, avps): #Sh-User-Data (XML) #This loads a Jinja XML template containing the Sh-User-Data - templateLoader = jinja2.FileSystemLoader(searchpath="./") - templateEnv = jinja2.Environment(loader=templateLoader) sh_userdata_template = self.config['hss']['Default_Sh_UserData'] self.logTool.log(service='HSS', level='info', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) - template = templateEnv.get_template(sh_userdata_template) + template = self.templateEnv.get_template(sh_userdata_template) #These variables are passed to the template for use subscriber_details['mnc'] = self.MNC.zfill(3) subscriber_details['mcc'] = self.MCC.zfill(3) diff --git a/lib/metrics.py b/lib/metrics.py new file mode 100644 index 0000000..3db7918 --- /dev/null +++ b/lib/metrics.py @@ -0,0 +1,41 @@ +class Metrics: + + def __init__(self, redisMessaging): + self.redisMessaging = redisMessaging + + def initializeMetrics(self) -> bool: + """ + Preloads all metrics, and sets their initial value to 0. 
+ """ + + print("Initializing Metrics") + + metricList = [ + {'serviceName':'api', 'metricName':'prom_flask_http_geored_endpoints', 'metricType':'counter', 'metricHelp':'Number of Geored Pushes Received'}, + {'serviceName':'diameter', 'metricName':'prom_diam_inbound_count', 'metricType':'counter', 'metricHelp':'Number of Diameter Inbounds'}, + {'serviceName':'geored', 'metricName':'prom_http_geored', 'metricType':'counter', 'metricHelp':'Number of Geored Pushes'}, + {'serviceName':'webhook', 'metricName':'prom_http_webhook', 'metricType':'counter', 'metricHelp':'Number of Webhook Pushes'}, + {'serviceName':'database', 'metricName':'prom_eir_devices', 'metricType':'counter', 'metricHelp':'Profile of attached devices'}, + {'serviceName':'diameter', 'metricName':'prom_ims_subs', 'metricType':'gauge', 'metricHelp':'Number of attached IMS Subscribers'}, + {'serviceName':'diameter', 'metricName':'prom_mme_subs', 'metricType':'gauge', 'metricHelp':'Number of attached MME Subscribers'}, + {'serviceName':'diameter', 'metricName':'prom_pcrf_subs', 'metricType':'gauge', 'metricHelp':'Number of attached PCRF Subscribers'}, + {'serviceName':'diameter', 'metricName':'prom_diam_auth_event_count', 'metricType':'counter', 'metricHelp':'Diameter Authentication related Counters'}, + {'serviceName':'diameter', 'metricName':'prom_diam_response_count_successful', 'metricType':'counter', 'metricHelp':'Number of Successful Diameter Responses'}, + {'serviceName':'diameter', 'metricName':'prom_diam_response_count_fail', 'metricType':'counter', 'metricHelp':'Number of Failed Diameter Responses'} + ] + + for metric in metricList: + try: + self.redisMessaging.sendMetric(serviceName=metric['serviceName'], + metricName=metric['metricName'], + metricType=metric['metricType'], + metricAction='inc', + metricValue=0.0, + metricHelp=metric['metricHelp'], + metricLabels=metric['metricLabels'], + metricExpiry=60) + except Exception as e: + print(e) + pass + + return True \ No newline at end of file diff --git a/services/georedService.py b/services/georedService.py index 81d32ec..67e5b3f 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -172,7 +172,7 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h asyncio.ensure_future(self.redisWebhookMessaging.sendMetric(serviceName='webhook', metricName='prom_http_webhook', metricType='counter', metricAction='inc', - metricValue=1.0, metricHelp='Number of Geored Pushes', + metricValue=1.0, metricHelp='Number of Webhook Pushes', metricLabels={ "webhook_host": str(url.replace('https://', '').replace('http://', '')), "endpoint": "webhook", From c1996f3e268a5ae42c520e4a0375e4b239732bd4 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Thu, 21 Sep 2023 14:56:46 +1000 Subject: [PATCH 16/43] Memory leak fix --- lib/database.py | 3 +- lib/diameter.py | 718 +++++++++++++++++++++--------------- lib/diameterAsync.py | 282 +++++++------- lib/messagingAsync.py | 7 +- services/diameterService.py | 25 +- services/georedService.py | 78 ++-- services/hssService.py | 1 + services/logService.py | 2 + services/metricService.py | 1 + 9 files changed, 629 insertions(+), 488 deletions(-) diff --git a/lib/database.py b/lib/database.py index 3a5c73f..5fcd575 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1560,8 +1560,7 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ else: self.logTool.log(service='Database', level='debug', message="Config does not allow sync of HSS events", redisClient=self.redisMessaging) except 
Exception as E: - self.logTool.log(service='Database', level='error', message="Error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging) - raise + self.logTool.log(service='Database', level='error', message="Error occurred in Update_Serving_MME: " + str(E), redisClient=self.redisMessaging) finally: self.safe_close(session) diff --git a/lib/diameter.py b/lib/diameter.py index df3fffd..08e8568 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -12,6 +12,7 @@ import yaml import json import time +import traceback class Diameter: @@ -305,98 +306,200 @@ def generate_diameter_packet(self, packet_version, packet_flags, packet_command_ + def roundUpToMultiple(self, n, multiple): + return ((n + multiple - 1) // multiple) * multiple + + + def validateSingleAvp(self, data) -> bool: + """ + Attempts to validate a single hex string diameter AVP as being an AVP. + """ + try: + avpCode = int(data[0:8], 16) + # The next byte contains the AVP Flags + avpFlags = data[8:10] + # The next 3 bytes contain the AVP Length + avpLength = int(data[10:16], 16) + if avpFlags not in ['80', '40', '20', '00', 'c0']: + return False + if int(len(data[16:]) / 2) < ((avpLength - 8)): + return False + return True + except Exception as e: + return False + + def decode_diameter_packet(self, data): + """ + Handles decoding of a full diameter packet. + """ packet_vars = {} avps = [] - if type(data) is bytes: data = data.hex() - - + # One byte is 2 hex characters + # First Byte is the Diameter Packet Version packet_vars['packet_version'] = data[0:2] + # Next 3 Bytes are the length of the entire Diameter packet packet_vars['length'] = int(data[2:8], 16) + # Next Byte is the Diameter Flags packet_vars['flags'] = data[8:10] packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8) + # Next 3 Bytes are the Diameter Command Code packet_vars['command_code'] = int(data[10:16], 16) + # Next 4 Bytes are the Application Id packet_vars['ApplicationId'] = int(data[16:24], 16) + # Next 4 Bytes are the Hop By Hop Identifier packet_vars['hop-by-hop-identifier'] = data[24:32] + # Next 4 Bytes are the End to End Identifier packet_vars['end-to-end-identifier'] = data[32:40] - avp_sum = data[40:] - avp_vars, remaining_avps = self.decode_avp_packet(avp_sum) - avps.append(avp_vars) - - while len(remaining_avps) > 0: - avp_vars, remaining_avps = self.decode_avp_packet(remaining_avps) - avps.append(avp_vars) - else: - pass + lengthOfDiameterVars = int(len(data[:40]) / 2) + + #Length of all AVPs, in bytes + avpLength = int(packet_vars['length'] - lengthOfDiameterVars) + avpCharLength = int((avpLength * 2)) + remaining_avps = data[40:] + + avps = self.decodeAvpPacket(remaining_avps) + return packet_vars, avps + def decodeAvpPacket(self, data): + """ + Returns a list of decoded AVP Packet dictionaries. + This function is called at a high frequency; decoding methods should stick to iteration and not recursion, to avoid a memory leak. + """ + # Note: After spending hours on this, I'm leaving the following technical debt: + # Subavps and all their descendants are lifted up, flat, side by side into the parent's sub_avps list. + # It's definitely possible to keep the nested tree structure, if anyone wants to improve this function. But I can't figure out a simple way to do so, without invoking recursion.
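+ # Worked example (illustrative): the 12-byte hex string '0000010c4000000c000007d1' decodes to a single
+ # Result-Code AVP: avp_code=268, avp_flags='40', avp_length=12, misc_data='000007d1' (2001, DIAMETER_SUCCESS).
+ # Assuming an instance named diameter, diameter.decodeAvpPacket('0000010c4000000c000007d1') returns that
+ # AVP as a one-element list with an empty sub_avps list, since the payload is not itself a valid AVP.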
- def decode_avp_packet(self, data): - if len(data) <= 8: - #if length is less than 8 it is too short to be an AVP and is most likely the data from the last AVP being attempted to be parsed as another AVP - raise ValueError("Length of data is too short to be valid AVP") + # Our final list of AVP Dictionaries, which will be returned once processing is complete. + processed_avps = [] + # Initialize a failsafe counter, to prevent packets that pass validation but aren't AVPs from causing an infinite loop + failsafeCounter = 0 - avp_vars = {} - avp_vars['avp_code'] = int(data[0:8], 16) - - avp_vars['avp_flags'] = data[8:10] - avp_vars['avp_length'] = int(data[10:16], 16) - if avp_vars['avp_flags'] == "c0": - #If c0 is present AVP is Vendor AVP - avp_vars['vendor_id'] = int(data[16:24], 16) - avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] - else: - #if is not a vendor AVP - avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] + # If the avp data is 8 bytes (16 chars) or less, it's invalid. + if len(data) < 16: + return [] - if avp_vars['avp_length'] % 4 == 0: - #Multiple of 4 - No Padding needed - avp_vars['padding'] = 0 - else: - #Not multiple of 4 - Padding needed - rounded_value = self.myround(avp_vars['avp_length']) - avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 - avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] + # Working stack to aid in iterative processing of sub-avps. + subAvpUnprocessedStack = [] + # Keep processing AVPs until they're all dealt with + while len(data) > 16: + try: + failsafeCounter += 1 + + if failsafeCounter > 100: + break + avp_vars = {} + # The first 4 bytes contains the AVP code + avp_vars['avp_code'] = int(data[0:8], 16) + # The next byte contains the AVP Flags + avp_vars['avp_flags'] = data[8:10] + # The next 3 bytes contains the AVP Length + avp_vars['avp_length'] = int(data[10:16], 16) + # The remaining bytes (until the end, defined by avp_length) is the AVP payload. + # Padding is excluded from avp_length. It's calculated separately, and unknown by the AVP itself. + # Multiplying avp_length (the total AVP length in bytes, header included) by 2 gives the end offset of this AVP in hex chars, which is used below to slice out the payload. + avpPayloadLength = int((avp_vars['avp_length'])*2) + + # Work out our vendor id and add the payload itself (misc_data) + if avp_vars['avp_flags'] == 'c0' or avp_vars['avp_flags'] == '80': + avp_vars['vendor_id'] = int(data[16:24], 16) + avp_vars['misc_data'] = data[24:avpPayloadLength] + else: + avp_vars['vendor_id'] = '' + avp_vars['misc_data'] = data[16:avpPayloadLength] - #If body of avp_vars['misc_data'] contains AVPs, then decode each of them as a list of dicts like avp_vars['misc_data'] = [avp_vars, avp_vars] - try: - sub_avp_vars, sub_remaining_avps = self.decode_avp_packet(avp_vars['misc_data']) - #Sanity check - If the avp code is greater than 9999 it's probably not an AVP after all...
- if int(sub_avp_vars['avp_code']) > 9999: - pass - else: - #If the decoded AVP is valid store it - avp_vars['misc_data'] = [] - avp_vars['misc_data'].append(sub_avp_vars) - #While there are more AVPs to be decoded, decode them: - while len(sub_remaining_avps) > 0: - sub_avp_vars, sub_remaining_avps = self.decode_avp_packet(sub_remaining_avps) - avp_vars['misc_data'].append(sub_avp_vars) - - except Exception as e: - if str(e) == "invalid literal for int() with base 16: ''": - pass - elif str(e) == "Length of data is too short to be valid AVP": - pass - else: - self.logTool.log(service='HSS', level='debug', message="failed to decode sub-avp - error: " + str(e), redisClient=self.redisMessaging) - pass + payloadContainsSubAvps = self.validateSingleAvp(avp_vars['misc_data']) + if payloadContainsSubAvps: + # If the payload contains sub or grouped AVPs, append misc_data to the subAvpUnprocessedStack to start working through one or more sub-AVPs + subAvpUnprocessedStack.append(avp_vars["misc_data"]) + avp_vars["misc_data"] = '' + + # Rounds up the length to the nearest multiple of 4, which we can compare against the avp length to give us the padding length (if required) + avp_padded_length = int((self.roundUpToMultiple(avp_vars['avp_length'], 4))) + avpPaddingLength = ((avp_padded_length - avp_vars['avp_length']) * 2) + # Initialize a blank sub_avps list, regardless of whether or not we have any sub avps. + avp_vars['sub_avps'] = [] - remaining_avps = data[(avp_vars['avp_length']*2)+avp_vars['padding']:] #returns remaining data in avp string back for processing again - return avp_vars, remaining_avps + while payloadContainsSubAvps: + # Increment our failsafe counter, which will fail after 100 tries. This prevents a rare validation error from causing the function to hang permanently. + failsafeCounter += 1 + + if failsafeCounter > 100: + break + + # Pop the sub avp data from the list (remove from the end) + sub_avp_data = subAvpUnprocessedStack.pop() + + # Initialize our sub avp dictionary, and grab the usual values + sub_avp = {} + sub_avp['avp_code'] = int(sub_avp_data[0:8], 16) + sub_avp['avp_flags'] = sub_avp_data[8:10] + sub_avp['avp_length'] = int(sub_avp_data[10:16], 16) + sub_avpPayloadLength = int((sub_avp['avp_length'])*2) + + if sub_avp['avp_flags'] == 'c0' or sub_avp['avp_flags'] == '80': + sub_avp['vendor_id'] = int(sub_avp_data[16:24], 16) + sub_avp['misc_data'] = sub_avp_data[24:sub_avpPayloadLength] + else: + sub_avp['vendor_id'] = '' + sub_avp['misc_data'] = sub_avp_data[16:sub_avpPayloadLength] + + containsSubAvps = self.validateSingleAvp(sub_avp["misc_data"]) + if containsSubAvps: + subAvpUnprocessedStack.append(sub_avp["misc_data"]) + sub_avp["misc_data"] = '' + + avp_vars['sub_avps'].append(sub_avp) + + sub_avp_padded_length = int((self.roundUpToMultiple(sub_avp['avp_length'], 4))) + subAvpPaddingLength = ((sub_avp_padded_length - sub_avp['avp_length']) * 2) + + sub_avp_data = sub_avp_data[sub_avpPayloadLength+subAvpPaddingLength:] + containsNestedSubAvps = self.validateSingleAvp(sub_avp_data) + + # Check for nested sub avps and bring them to the top of the stack, for further processing.
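+ # For example (illustrative): a Grouped AVP whose payload carries AVPs A then B ends up with
+ # sub_avps == [A, B]; deeper descendants are popped off the same stack and land in the same
+ # flat list, per the technical debt note at the top of this function.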
+ if containsNestedSubAvps: + subAvpUnprocessedStack.append(sub_avp_data) + + if containsSubAvps or containsNestedSubAvps: + payloadContainsSubAvps = True + else: + payloadContainsSubAvps = False + + if avpPaddingLength > 0: + processed_avps.append(avp_vars) + data = data[avpPayloadLength+avpPaddingLength:] + else: + processed_avps.append(avp_vars) + data = data[avpPayloadLength:] + except Exception as e: + print(e) + continue + + return processed_avps def get_avp_data(self, avps, avp_code): #Loops through list of dicts generated by the packet decoder, and returns the data for a specific AVP code in list (May be more than one AVP with same code but different data) misc_data = [] - for keys in avps: - if keys['avp_code'] == avp_code: - misc_data.append(keys['misc_data']) + for avpObject in avps: + if int(avpObject['avp_code']) == int(avp_code): + if len(avpObject['misc_data']) == 0: + misc_data.append(avpObject['sub_avps']) + else: + misc_data.append(avpObject['misc_data']) + if 'sub_avps' in avpObject: + for sub_avp in avpObject['sub_avps']: + if int(sub_avp['avp_code']) == int(avp_code): + misc_data.append(sub_avp['misc_data']) return misc_data def decode_diameter_packet_length(self, data): @@ -461,7 +564,7 @@ def getDiameterMessageType(self, binaryData: str) -> dict: assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) response['inbound'] = diameterApplication["requestAcronym"] response['outbound'] = diameterApplication["responseAcronym"] - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Successfully generated response: {response}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Matched message types: {response}", redisClient=self.redisMessaging) except Exception as e: continue return response @@ -778,8 +881,6 @@ def Answer_16777251_316(self, packet_vars, avps): except Exception as ex: template = "An exception of type {0} occurred. Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - self.logTool.critical(message) - self.logTool.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise #Store MME Location into Database @@ -884,7 +985,7 @@ def Answer_16777251_316(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="Found static IP for UE " + str(subscriber_routing_dict['ip_address']), redisClient=self.redisMessaging) Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(subscriber_routing_dict['ip_address'])) except Exception as E: - self.logTool.log(service='HSS', level='debug', message="Error getting static UE IP: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="No static UE IP found: " + str(E), redisClient=self.redisMessaging) Served_Party_Address = "" @@ -956,6 +1057,7 @@ def Answer_16777251_316(self, packet_vars, avps): #3GPP S6a/S6d Authentication Information Answer def Answer_16777251_318(self, packet_vars, avps): + self.logTool.log(service='HSS', level='debug', message=f"AIA AVPS: {avps}", redisClient=self.redisMessaging) imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from User-Name AVP in request @@ -1026,77 +1128,78 @@ def Answer_16777251_318(self, packet_vars, avps): except Exception as ex: template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - self.logTool.critical(message) - self.logTool.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise + try: + requested_vectors = 1 + for avp in avps: + if avp['avp_code'] == 1408: + self.logTool.log(service='HSS', level='debug', message="AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP", redisClient=self.redisMessaging) + EUTRAN_Authentication_Info = avp['misc_data'] + self.logTool.log(service='HSS', level='debug', message="EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info), redisClient=self.redisMessaging) + for sub_avp in EUTRAN_Authentication_Info: + #If resync request + if sub_avp['avp_code'] == 1411: + self.logTool.log(service='HSS', level='debug', message="Re-Synchronization required - SQN is out of sync", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777251, + "diameter_cmd_code": 318, + "event": "Resync", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) + auts = str(sub_avp['misc_data'])[32:] + rand = str(sub_avp['misc_data'])[:32] + rand = binascii.unhexlify(rand) + #Calculate correct SQN + self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) + + #Get number of requested vectors + if sub_avp['avp_code'] == 1410: + self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging) + requested_vectors = int(sub_avp['misc_data'], 16) + if requested_vectors >= 32: + self.logTool.log(service='HSS', level='info', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) + requested_vectors = 32 + + self.logTool.log(service='HSS', level='debug', message="Generating " + str(requested_vectors) + " vectors as requested", redisClient=self.redisMessaging) + eutranvector_complete = '' + while requested_vectors != 0: + self.logTool.log(service='HSS', level='debug', message="Generating vector number " + str(requested_vectors), redisClient=self.redisMessaging) + plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from request + vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn) + eutranvector = '' #This goes into the payload of AVP 10415 (Authentication info) + eutranvector += self.generate_vendor_avp(1419, "c0", 10415, self.int_to_hex(requested_vectors, 4)) + eutranvector += self.generate_vendor_avp(1447, "c0", 10415, vector_dict['rand']) #And is made up of other AVPs joined together with RAND + eutranvector += self.generate_vendor_avp(1448, "c0", 10415, vector_dict['xres']) #XRes + eutranvector += self.generate_vendor_avp(1449, "c0", 10415, vector_dict['autn']) #AUTN + eutranvector += self.generate_vendor_avp(1450, "c0", 10415, vector_dict['kasme']) #And KASME + + requested_vectors = requested_vectors - 1 + eutranvector_complete += self.generate_vendor_avp(1414, "c0", 10415, eutranvector) #Put EUTRAN vectors in E-UTRAN-Vector AVP + + avp = '' #Initiate empty var AVP + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set + avp += self.generate_vendor_avp(1413, "c0", 10415, 
eutranvector_complete) #Authentication-Info (3GPP) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State + avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") + #avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) + + response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + self.logTool.log(service='HSS', level='debug', message="Successfully Generated AIA", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=response, redisClient=self.redisMessaging) + return response + except Exception as e: + self.logTool.log(service='HSS', level='error', message=traceback.format_exc(), redisClient=self.redisMessaging) - requested_vectors = 1 - for avp in avps: - if avp['avp_code'] == 1408: - self.logTool.log(service='HSS', level='debug', message="AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP", redisClient=self.redisMessaging) - EUTRAN_Authentication_Info = avp['misc_data'] - self.logTool.log(service='HSS', level='debug', message="EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info), redisClient=self.redisMessaging) - for sub_avp in EUTRAN_Authentication_Info: - #If resync request - if sub_avp['avp_code'] == 1411: - self.logTool.log(service='HSS', level='debug', message="Re-Synchronization required - SQN is out of sync", redisClient=self.redisMessaging) - self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', - metricType='counter', metricAction='inc', - metricValue=1.0, - metricLabels={ - "diameter_application_id": 16777251, - "diameter_cmd_code": 318, - "event": "Resync", - "imsi_prefix": str(imsi[0:6])}, - metricHelp='Diameter Authentication related Counters', - metricExpiry=60) - auts = str(sub_avp['misc_data'])[32:] - rand = str(sub_avp['misc_data'])[:32] - rand = binascii.unhexlify(rand) - #Calculate correct SQN - self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) - - #Get number of requested vectors - if sub_avp['avp_code'] == 1410: - self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging) - requested_vectors = int(sub_avp['misc_data'], 16) - if requested_vectors >= 32: - self.logTool.log(service='HSS', level='info', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) - requested_vectors = 32 - - self.logTool.log(service='HSS', level='debug', message="Generating " + str(requested_vectors) + " vectors as requested", redisClient=self.redisMessaging) - eutranvector_complete = '' - while requested_vectors != 0: - self.logTool.log(service='HSS', level='debug', message="Generating vector number " + str(requested_vectors), redisClient=self.redisMessaging) - plmn = self.get_avp_data(avps, 1407)[0] #Get PLMN from request - vector_dict = self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "air", plmn=plmn) - eutranvector = '' #This goes into the payload of AVP 10415 
(Authentication info) - eutranvector += self.generate_vendor_avp(1419, "c0", 10415, self.int_to_hex(requested_vectors, 4)) - eutranvector += self.generate_vendor_avp(1447, "c0", 10415, vector_dict['rand']) #And is made up of other AVPs joined together with RAND - eutranvector += self.generate_vendor_avp(1448, "c0", 10415, vector_dict['xres']) #XRes - eutranvector += self.generate_vendor_avp(1449, "c0", 10415, vector_dict['autn']) #AUTN - eutranvector += self.generate_vendor_avp(1450, "c0", 10415, vector_dict['kasme']) #And KASME - - requested_vectors = requested_vectors - 1 - eutranvector_complete += self.generate_vendor_avp(1414, "c0", 10415, eutranvector) #Put EUTRAN vectors in E-UTRAN-Vector AVP - - avp = '' #Initiate empty var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_vendor_avp(1413, "c0", 10415, eutranvector_complete) #Authentication-Info (3GPP) - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) - avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000023") - #avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777251),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (S6a) - - response = self.generate_diameter_packet("01", "40", 318, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet - self.logTool.log(service='HSS', level='debug', message="Successfully Generated AIA", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message=response, redisClient=self.redisMessaging) - return response #Purge UE Answer (PUA) def Answer_16777251_321(self, packet_vars, avps): @@ -1155,146 +1258,166 @@ def Answer_16777251_323(self, packet_vars, avps): #3GPP Gx Credit Control Answer def Answer_16777238_272(self, packet_vars, avps): - CC_Request_Type = self.get_avp_data(avps, 416)[0] - CC_Request_Number = self.get_avp_data(avps, 415)[0] - #Called Station ID - self.logTool.log(service='HSS', level='debug', message="Attempting to find APN in CCR", redisClient=self.redisMessaging) - apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8') - self.logTool.log(service='HSS', level='debug', message="CCR for APN " + str(apn), redisClient=self.redisMessaging) - - OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it - - OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP - OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it - - try: #Check if we have a record-route set as that's where we'll need to send the response - remote_peer = self.get_avp_data(avps, 282)[-1] #Get first record-route header - remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it - except: #If we don't have a record-route set, we'll send the response to the OriginHost - remote_peer = OriginHost - self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) - remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) - - avp = '' #Initiate empty 
var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set - avp += self.generate_avp(258, 40, "01000016") #Auth-Application-Id (3GPP Gx 16777238) - avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC-Request-Type - avp += self.generate_avp(415, 40, format(int(CC_Request_Number),"x").zfill(8)) #CC-Request-Number - - - #Get Subscriber info from Subscription ID - for SubscriptionIdentifier in self.get_avp_data(avps, 443): - for UniqueSubscriptionIdentifier in SubscriptionIdentifier: - self.logTool.log(service='HSS', level='debug', message="Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI", redisClient=self.redisMessaging) - if UniqueSubscriptionIdentifier['avp_code'] == 444: - imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') - self.logTool.log(service='HSS', level='debug', message="Found IMSI " + str(imsi), redisClient=self.redisMessaging) - - self.logTool.log(service='HSS', level='info', message="SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) try: - self.logTool.log(service='HSS', level='info', message="Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details - ChargingRules = self.database.Get_Charging_Rules(imsi=imsi, apn=apn) - self.logTool.log(service='HSS', level='info', message="Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) - except Exception as E: - #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - self.logTool.log(service='HSS', level='debug', message=E, redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging) + CC_Request_Type = self.get_avp_data(avps, 416)[0] + CC_Request_Number = self.get_avp_data(avps, 415)[0] + #Called Station ID + self.logTool.log(service='HSS', level='debug', message="Attempting to find APN in CCR", redisClient=self.redisMessaging) + apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message="CCR for APN " + str(apn), redisClient=self.redisMessaging) + + OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP + OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it + + OriginRealm = self.get_avp_data(avps, 296)[0] #Get OriginRealm from AVP + OriginRealm = binascii.unhexlify(OriginRealm).decode('utf-8') #Format it + + try: #Check if we have a record-route set as that's where we'll need to send the response + remote_peer = self.get_avp_data(avps, 282)[-1] #Get first record-route header + remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it + except: #If we don't have a record-route set, we'll send the response to the OriginHost + remote_peer = OriginHost + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) + avp = '' #Initiate empty var AVP + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + self.logTool.log(service='HSS', level='debug', message="[diameter.py] 
[Answer_16777238_272] [CCR] Session Id is " + str(binascii.unhexlify(session_id).decode()), redisClient=self.redisMessaging) + avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set + avp += self.generate_avp(258, 40, "01000016") #Auth-Application-Id (3GPP Gx 16777238) + avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC-Request-Type + avp += self.generate_avp(415, 40, format(int(CC_Request_Number),"x").zfill(8)) #CC-Request-Number + - if int(CC_Request_Type) == 1: - self.logTool.log(service='HSS', level='info', message="Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) + #Get Subscriber info from Subscription ID + for SubscriptionIdentifier in self.get_avp_data(avps, 443): + for UniqueSubscriptionIdentifier in SubscriptionIdentifier: + self.logTool.log(service='HSS', level='debug', message="Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI", redisClient=self.redisMessaging) + if UniqueSubscriptionIdentifier['avp_code'] == 444: + imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message="Found IMSI " + str(imsi), redisClient=self.redisMessaging) - #Get UE IP + self.logTool.log(service='HSS', level='info', message="SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) try: - ue_ip = self.get_avp_data(avps, 8)[0] - ue_ip = str(self.hex_to_ip(ue_ip)) + self.logTool.log(service='HSS', level='info', message="Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details + ChargingRules = self.database.Get_Charging_Rules(imsi=imsi, apn=apn) + self.logTool.log(service='HSS', level='info', message="Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) except Exception as E: - self.logTool.log(service='HSS', level='error', message="Failed to get UE IP", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) - ue_ip = 'Failed to Decode / Get UE IP' + #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" + self.logTool.log(service='HSS', level='debug', message=E, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging) - #Store PGW location into Database - remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) - self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) - #Supported-Features(628) (Gx feature list) - avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") + if int(CC_Request_Type) == 1: + self.logTool.log(service='HSS', level='info', message="Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) - #Default EPS Beaerer QoS (From database with fallback source CCR-I) - try: - apn_data = ChargingRules['apn_data'] - self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) - #AMBR - AMBR = '' #Initiate empty var AVP for AMBR - apn_ambr_ul = 
int(apn_data['apn_ambr_ul']) - apn_ambr_dl = int(apn_data['apn_ambr_dl']) - AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL - AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL - APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - - self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) - #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP - AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_vulnerability']), 4)) - AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) - avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) - except Exception as E: - self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message="Failed to populate default_EPS_QoS from DB for sub " + str(imsi), redisClient=self.redisMessaging) - default_EPS_QoS = self.get_avp_data(avps, 1049)[0][8:] - avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - - - self.logTool.log(service='HSS', level='info', message="Creating QoS Information", redisClient=self.redisMessaging) - #QoS-Information - try: - apn_data = ChargingRules['apn_data'] - apn_ambr_ul = int(apn_data['apn_ambr_ul']) - apn_ambr_dl = int(apn_data['apn_ambr_dl']) - QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) - QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - self.logTool.log(service='HSS', level='info', message="Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="Populated QoS_Information", redisClient=self.redisMessaging) - avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) - except Exception as E: - self.logTool.log(service='HSS', level='error', message="Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) - - QoS_Information = '' - for AMBR_Part in self.get_avp_data(avps, 1016)[0]: - self.logTool.log(service='HSS', level='debug', message=AMBR_Part, redisClient=self.redisMessaging) - AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:]) - QoS_Information += AMBR_AVP - self.logTool.log(service='HSS', level='debug', message="QoS_Information added " + str(AMBR_AVP), redisClient=self.redisMessaging) - avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) - self.logTool.log(service='HSS', level='debug', message="QoS information set statically", redisClient=self.redisMessaging) - - self.logTool.log(service='HSS', level='info', message="Added to AVP List", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message="QoS 
Information: " + str(QoS_Information), redisClient=self.redisMessaging) - - #If database returned an existing ChargingRule defintion add ChargingRule to CCA-I - if ChargingRules and ChargingRules['charging_rules'] is not None: + #Get UE IP try: - self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging) - for individual_charging_rule in ChargingRules['charging_rules']: - self.logTool.log(service='HSS', level='debug', message="Processing Charging Rule: " + str(individual_charging_rule), redisClient=self.redisMessaging) - avp += self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) + ue_ip = self.get_avp_data(avps, 8)[0] + ue_ip = str(self.hex_to_ip(ue_ip)) + except Exception as E: + self.logTool.log(service='HSS', level='error', message="Failed to get UE IP", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) + ue_ip = 'Failed to Decode / Get UE IP' + + #Store PGW location into Database + remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=OriginHost, subscriber_routing=str(ue_ip), serving_pgw_realm=OriginRealm, serving_pgw_peer=remote_peer) + + #Supported-Features(628) (Gx feature list) + avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") + #Default EPS Beaerer QoS (From database with fallback source CCR-I) + try: + apn_data = ChargingRules['apn_data'] + self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) + #AMBR + AMBR = '' #Initiate empty var AVP for AMBR + apn_ambr_ul = int(apn_data['apn_ambr_ul']) + apn_ambr_dl = int(apn_data['apn_ambr_dl']) + AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL + AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL + APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) + + self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) + #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP + AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) + AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) + AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_vulnerability']), 4)) + AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) + AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4)) + avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) except Exception as E: - self.logTool.log(service='HSS', level='debug', message="Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="Failed to populate default_EPS_QoS from DB for sub " + str(imsi), redisClient=self.redisMessaging) + default_EPS_QoS = 
self.get_avp_data(avps, 1049)[0][8:] + avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - - self.logTool.log(service='HSS', level='info', message="Creating QoS Information", redisClient=self.redisMessaging) + #QoS-Information + try: + apn_data = ChargingRules['apn_data'] + apn_ambr_ul = int(apn_data['apn_ambr_ul']) + apn_ambr_dl = int(apn_data['apn_ambr_dl']) + QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) + QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) + self.logTool.log(service='HSS', level='info', message="Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Populated QoS_Information", redisClient=self.redisMessaging) + avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) + except Exception as E: + self.logTool.log(service='HSS', level='error', message="Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) + + QoS_Information = '' + for AMBR_Part in self.get_avp_data(avps, 1016)[0]: + self.logTool.log(service='HSS', level='debug', message=AMBR_Part, redisClient=self.redisMessaging) + AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:]) + QoS_Information += AMBR_AVP + self.logTool.log(service='HSS', level='debug', message="QoS_Information added " + str(AMBR_AVP), redisClient=self.redisMessaging) + avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) + self.logTool.log(service='HSS', level='debug', message="QoS information set statically", redisClient=self.redisMessaging) + + self.logTool.log(service='HSS', level='info', message="Added to AVP List", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging) + + #If database returned an existing ChargingRule definition add ChargingRule to CCA-I + if ChargingRules and ChargingRules['charging_rules'] is not None: + try: + self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging) + for individual_charging_rule in ChargingRules['charging_rules']: + self.logTool.log(service='HSS', level='debug', message="Processing Charging Rule: " + str(individual_charging_rule), redisClient=self.redisMessaging) + avp += self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) + + except Exception as E: + self.logTool.log(service='HSS', level='debug', message="Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging) + + elif
int(CC_Request_Type) == 3: + self.logTool.log(service='HSS', level='info', message="Request type for CCA is 3 - Termination", redisClient=self.redisMessaging) + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None) + + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + except Exception as e: + #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" + self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777238, + "diameter_cmd_code": 272, + "event": "Unknown User", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) + experimental_result = self.generate_avp(298, 40, self.int_to_hex(5001, 4)) #Result Code (DIAMETER ERROR - User Unknown) + experimental_result = experimental_result + self.generate_vendor_avp(266, 40, 10415, "") + #Experimental Result (297) + avp += self.generate_avp(297, 40, experimental_result) + response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response #3GPP Cx User Authorization Answer @@ -1349,7 +1472,7 @@ def Answer_16777216_300(self, packet_vars, avps): avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - response = self.generate_diameter_packet("01", "40", 300, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + response = self.generate_diameter_packet("01", "40", 300, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response #Determine SAR Type & Store @@ -1534,7 +1657,7 @@ def Answer_16777216_302(self, packet_vars, avps): avp_experimental_result += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID avp_experimental_result += self.generate_avp(298, 40, self.int_to_hex(result_code, 4)) #AVP Experimental-Result-Code avp += self.generate_avp(297, 40, avp_experimental_result) #AVP Experimental-Result(297) - response = self.generate_diameter_packet("01", "40", 302, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + response = self.generate_diameter_packet("01", "40", 302, 16777216, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response avp += self.generate_avp(268, 40, "000007d1") #DIAMETER_SUCCESS @@ -1552,6 +1675,7 @@ def Answer_16777216_303(self, packet_vars, avps):
imsi = username.split('@')[0] #Strip Domain domain = username.split('@')[1] #Get Domain Part self.logTool.log(service='HSS', level='debug', message="Got MAR username: " + str(username), redisClient=self.redisMessaging) + auth_scheme = '' avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID @@ -1615,7 +1739,8 @@ def Answer_16777216_303(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="IMSI is " + str(imsi), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(601, "c0", 10415, str(binascii.hexlify(str.encode(public_identity)),'ascii')) #Public Identity (IMSI) avp += self.generate_avp(1, 40, str(binascii.hexlify(str.encode(imsi + "@" + domain)),'ascii')) #Username - + + #Determine Vectors to Generate if auth_scheme == "Digest-MD5": @@ -1828,6 +1953,7 @@ def Answer_16777252_324(self, packet_vars, avps): #Get IMSI try: + imei = '' imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI #avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) @@ -1836,36 +1962,40 @@ def Answer_16777252_324(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from LCS-Routing-Info-Request", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) - #Get IMEI - for sub_avp in self.get_avp_data(avps, 1401)[0]: - self.logTool.log(service='HSS', level='debug', message="Evaluating sub_avp AVP " + str(sub_avp) + " to find IMSI", redisClient=self.redisMessaging) - if sub_avp['avp_code'] == 1402: - imei = binascii.unhexlify(sub_avp['misc_data']).decode('utf-8') - self.logTool.log(service='HSS', level='debug', message="Found IMEI " + str(imei), redisClient=self.redisMessaging) - - avp = '' #Initiate empty var AVP - session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID - avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024") #Vendor-Specific-Application-ID for S13 - avp += self.generate_avp(277, 40, "00000001") #Auth Session State - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - #Experimental Result AVP(Response Code for Failure) - avp_experimental_result = '' - avp_experimental_result += self.generate_vendor_avp(266, 'c0', 10415, '') #AVP Vendor ID - avp_experimental_result += self.generate_avp(298, 'c0', self.int_to_hex(2001, 4)) #AVP Experimental-Result-Code: SUCESS (2001) - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + try: + #Get IMEI + for sub_avp in self.get_avp_data(avps, 1401)[0]: + self.logTool.log(service='HSS', level='debug', message="Evaluating sub_avp AVP " + str(sub_avp) + " to find IMEI", redisClient=self.redisMessaging) + if sub_avp['avp_code'] == 1402: + imei = binascii.unhexlify(sub_avp['misc_data']).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message="Found IMEI " + str(imei), redisClient=self.redisMessaging) + + avp = '' #Initiate empty var AVP + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(260, 40, "0000010a4000000c000028af000001024000000c01000024"
#Vendor-Specific-Application-ID for S13 + avp += self.generate_avp(277, 40, "00000001") #Auth Session State + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + #Experimental Result AVP(Response Code for Failure) + avp_experimental_result = '' + avp_experimental_result += self.generate_vendor_avp(266, 'c0', 10415, '') #AVP Vendor ID + avp_experimental_result += self.generate_avp(298, 'c0', self.int_to_hex(2001, 4)) #AVP Experimental-Result-Code: SUCCESS (2001) + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + + #Equipment-Status + EquipmentStatus = self.database.Check_EIR(imsi=imsi, imei=imei) + avp += self.generate_vendor_avp(1445, 'c0', 10415, self.int_to_hex(EquipmentStatus, 4)) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_eir_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "response": EquipmentStatus}, + metricHelp='Diameter EIR event related Counters', + metricExpiry=60) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=traceback.format_exc(), redisClient=self.redisMessaging) - #Equipment-Status - EquipmentStatus = self.database.Check_EIR(imsi=imsi, imei=imei) - avp += self.generate_vendor_avp(1445, 'c0', 10415, self.int_to_hex(EquipmentStatus, 4)) - self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_eir_event_count', - metricType='counter', metricAction='inc', - metricValue=1.0, - metricLabels={ - "response": EquipmentStatus}, - metricHelp='Diameter EIR event related Counters', - metricExpiry=60) response = self.generate_diameter_packet("01", "40", 324, 16777252, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response @@ -2185,8 +2315,6 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): except Exception as ex: template = "An exception of type {0} occurred.
Arguments:\n{1!r}" message = template.format(type(ex).__name__, ex.args) - self.logTool.critical(message) - self.logTool.critical("Unhandled general exception when getting subscriber details for IMSI " + str(imsi)) raise diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index 585f864..ff15735 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -39,6 +39,9 @@ async def myRound(self, n, base=4): else: return 4 + async def roundUpToMultiple(self, n, multiple): + return ((n + multiple - 1) // multiple) * multiple + async def getAvpData(self, avps, avp_code): #Loops through list of dicts generated by the packet decoder, and returns the data for a specific AVP code in list (May be more than one AVP with same code but different data) misc_data = [] @@ -47,172 +50,169 @@ async def getAvpData(self, avps, avp_code): misc_data.append(keys['misc_data']) return misc_data - # async def decodeDiameterPacket(self, data): - # packet_vars = {} - # avps = [] - - # if type(data) is bytes: - # data = data.hex() - - # packet_vars['packet_version'] = data[0:2] - # packet_vars['length'] = int(data[2:8], 16) - # packet_vars['flags'] = data[8:10] - # packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8) - # packet_vars['command_code'] = int(data[10:16], 16) - # packet_vars['ApplicationId'] = int(data[16:24], 16) - # packet_vars['hop-by-hop-identifier'] = data[24:32] - # packet_vars['end-to-end-identifier'] = data[32:40] - - # avp_sum = data[40:] - - # avp_vars, remaining_avps = await(self.decodeAvpPacket(avp_sum)) - # avps.append(avp_vars) - - # while len(remaining_avps) > 0: - # avp_vars, remaining_avps = await(self.decodeAvpPacket(remaining_avps)) - # avps.append(avp_vars) - # else: - # pass - # return packet_vars, avps + async def validateSingleAvp(self, data) -> bool: + """ + Attempts to validate a single hex string diameter AVP as being an AVP. + """ + try: + avpCode = int(data[0:8], 16) + # The next byte contains the AVP Flags + avpFlags = data[8:10] + # The next 3 bytes contain the AVP Length + avpLength = int(data[10:16], 16) + if avpFlags not in ['80', '40', '20', '00', 'c0']: + #print(f"[AVP VALIDATION] Failed to validate due to invalid Flag: {data}") + return False + if int(len(data[16:]) / 2) < ((avpLength - 8)): + #print(f"[AVP VALIDATION] Failed to validate due to invalid length: {data}") + return False + return True + except Exception as e: + return False + async def decodeDiameterPacket(self, data): + """ + Handles decoding of a full diameter packet. 
+ """ packet_vars = {} avps = [] if type(data) is bytes: data = data.hex() - + # One byte is 2 hex characters + # First Byte is the Diameter Packet Version packet_vars['packet_version'] = data[0:2] + # Next 3 Bytes are the length of the entire Diameter packet packet_vars['length'] = int(data[2:8], 16) + # Next Byte is the Diameter Flags packet_vars['flags'] = data[8:10] packet_vars['flags_bin'] = bin(int(data[8:10], 16))[2:].zfill(8) + # Next 3 Bytes are the Diameter Command Code packet_vars['command_code'] = int(data[10:16], 16) + # Next 4 Bytes are the Application Id packet_vars['ApplicationId'] = int(data[16:24], 16) + # Next 4 Bytes are the Hop By Hop Identifier packet_vars['hop-by-hop-identifier'] = data[24:32] + # Next 4 Bytes are the End to End Identifier packet_vars['end-to-end-identifier'] = data[32:40] - remaining_avps = data[40:] - - while len(remaining_avps) > 0: - avp_vars, remaining_avps = await self.decodeAvpPacket(remaining_avps) - avps.append(avp_vars) - else: - pass - - return packet_vars, avps - - async def decodeAvpPacket(self, data): - avp_vars = {} - sub_avps = [] - - if len(data) <= 8: - raise ValueError("Length of data is too short to be valid AVP") - - avp_vars['avp_code'] = int(data[0:8], 16) - - avp_vars['avp_flags'] = data[8:10] - avp_vars['avp_length'] = int(data[10:16], 16) - avp_padded_length = (avp_vars['avp_length'] + 3) // 4 * 4 - - if avp_vars['avp_flags'] == "c0": - avp_vars['vendor_id'] = int(data[16:24], 16) - avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] - else: - avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] - - sub_avp_data = avp_vars['misc_data'] + #We're enforcing correct length, and calculate the end byte based on the length of the remaining AVPs and the known 'length' packet var. 
- while len(sub_avp_data) >= 16: - sub_avp_vars = {} - sub_avp_vars['avp_code'] = int(sub_avp_data[0:8], 16) - sub_avp_vars['avp_flags'] = sub_avp_data[8:10] - sub_avp_vars['avp_length'] = int(sub_avp_data[10:16], 16) - sub_avp_padded_length = (sub_avp_vars['avp_length'] + 3) // 4 * 4 + lengthOfDiameterVars = int(len(data[:40]) / 2) + #print(f"Length of Diameter Vars (Bytes): {lengthOfDiameterVars}") - if sub_avp_vars['avp_code'] > 9999: - break - - if '40' <= sub_avp_vars['avp_flags'] <= '7F': - sub_avp_vars['vendor_id'] = int(sub_avp_data[16:24], 16) - sub_avp_vars['misc_data'] = sub_avp_data[24:(24 + (sub_avp_vars['avp_length'] - 8) * 2)] - else: - sub_avp_vars['misc_data'] = sub_avp_data[16:(16 + (sub_avp_vars['avp_length'] - 8) * 2)] - - sub_avps.append(sub_avp_vars) - - sub_avp_data = sub_avp_data[(sub_avp_padded_length * 2):] - - avp_vars['sub_avps'] = sub_avps - - if avp_vars['avp_length'] % 4 == 0: - avp_vars['padding'] = 0 - else: - rounded_value = await self.myRound(avp_vars['avp_length']) - avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 - avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] - - remaining_avps = data[(avp_padded_length * 2):] - - return avp_vars, remaining_avps + #Length of all AVPs, in bytes + avpLength = int(packet_vars['length'] - lengthOfDiameterVars) + #print(f"avpLength (bytes): {avpLength}") + avpCharLength = int((avpLength * 2)) + #print(f"avpCharLength (chars): {avpCharLength}") + #print(f"Total Data Length (bytes) {len(data) / 2}") + remaining_avps = data[40:] + #print(remaining_avps) + avps = await self.decodeAvpPacket(remaining_avps) + #print(f"Got Back: {avps}") + return packet_vars, avps + async def decodeAvpPacket(self, data): + """ + Returns a list of decoded AVP Packet dictionaries. + """ + processed_avps = [] + # Initialize a failsafe counter, to prevent packets that pass validation but aren't AVPs from causing an infinite loop + failsafeCounter = 0 - # async def decodeAvpPacket(self, data): + # If the AVP data is less than 8 bytes (16 chars), it's too short to be a valid AVP. + if len(data) < 16: + return [] - # if len(data) <= 8: - # #if length is less than 8 it is too short to be an AVP and is most likely the data from the last AVP being attempted to be parsed as another AVP - # raise ValueError("Length of data is too short to be valid AVP") + # Keep processing AVPs until they're all dealt with + while len(data) > 16: + try: + failsafeCounter += 1 + + if failsafeCounter > 100: + break + avp_vars = {} + #print(f"AVP Data: {data}") + # The first 4 bytes contain the AVP code + avp_vars['avp_code'] = int(data[0:8], 16) + # The next byte contains the AVP Flags + avp_vars['avp_flags'] = data[8:10] + # The next 3 bytes contain the AVP Length + avp_vars['avp_length'] = int(data[10:16], 16) + #print(f"Individual AVP Length: {avp_vars['avp_length']}") + # The remaining bytes (up to the end defined by avp_length) are the AVP payload. + # Padding is excluded from avp_length. It's calculated separately, and unknown by the AVP itself. + # We calculate the end offset of this AVP by multiplying avp_length (in bytes) by 2 to give us chars; + # slicing from char 16 (the end of the 8 byte header) up to that offset gives us the payload. 
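+ # Worked example with assumed values: an AVP with avp_length = 13 (bytes) spans chars 0-25 (13 * 2),
+ # so its payload is data[16:26]; it then pads to roundUpToMultiple(13, 4) -> 16 bytes,
+ # i.e. (16 - 13) * 2 = 6 padding chars.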
+ avpPayloadLength = int((avp_vars['avp_length'])*2) + #print(f"AVP Payload Length (Chars): {avpPayloadLength}") + + # Work out our vendor id and add the payload itself (misc_data) + if avp_vars['avp_code'] == 266: + avp_vars['vendor_id'] = int(data[16:24], 16) + avp_vars['misc_data'] = data[16:avpPayloadLength] + else: + avp_vars['vendor_id'] = '' + avp_vars['misc_data'] = data[16:avpPayloadLength] + + # Rounds up the length to the nearest multiple of 4, which we can diff against the avp length to give us the padding length (if required) + avp_padded_length = int((await(self.roundUpToMultiple(avp_vars['avp_length'], 4)))) + # avp_padded_length = (avp_vars['avp_length'] + 3) // 4 * 4 + avpPaddingLength = ((avp_padded_length - avp_vars['avp_length']) * 2) + #print(f"AVP Padding length (Chars): {avpPaddingLength}") + + avp_vars['sub_avps'] = [] + + # Check if the payload data contains sub or grouped AVPs inside + payloadContainsSubAvps = await(self.validateSingleAvp(avp_vars['misc_data'])) + + if payloadContainsSubAvps: + # If the payload contains sub or grouped AVPs, assign misc_data to sub_avps to start working through them + sub_avp_data = avp_vars['misc_data'] + + while payloadContainsSubAvps: + failsafeCounter += 1 + + if failsafeCounter > 100: + break + sub_avp = {} + sub_avp['avp_code'] = int(sub_avp_data[0:8], 16) + sub_avp['avp_flags'] = sub_avp_data[8:10] + sub_avp['avp_length'] = int(sub_avp_data[10:16], 16) + sub_avpPayloadLength = int((sub_avp['avp_length'])*2) + + if sub_avp['avp_code'] == 266: + sub_avp['vendor_id'] = int(sub_avp_data[16:24], 16) + sub_avp['misc_data'] = sub_avp_data[16:sub_avpPayloadLength] + else: + sub_avp['vendor_id'] = '' + sub_avp['misc_data'] = sub_avp_data[16:sub_avpPayloadLength] + + avp_vars['sub_avps'].append(sub_avp) + + #print(f"Sub Avp Data before trimming: {sub_avp_data}") + #print(f"Sub Avp payload length: {sub_avpPayloadLength}") + sub_avp_data = sub_avp_data[sub_avpPayloadLength:] + avp_vars['misc_data'] = avp_vars['misc_data'][sub_avpPayloadLength:] + #print(f"Sub Avp Data after trimming: {sub_avp_data}") + payloadContainsSubAvps = await(self.validateSingleAvp(sub_avp_data)) + + if avpPaddingLength > 0: + processed_avps.append(avp_vars) + data = data[avpPayloadLength+avpPaddingLength:] + else: + processed_avps.append(avp_vars) + data = data[avpPayloadLength:] + except Exception as e: + #print(f"EXCEPTION: {e}") + continue - # avp_vars = {} - # avp_vars['avp_code'] = int(data[0:8], 16) - - # avp_vars['avp_flags'] = data[8:10] - # avp_vars['avp_length'] = int(data[10:16], 16) - # if avp_vars['avp_flags'] == "c0": - # #If c0 is present AVP is Vendor AVP - # avp_vars['vendor_id'] = int(data[16:24], 16) - # avp_vars['misc_data'] = data[24:(avp_vars['avp_length']*2)] - # else: - # #if is not a vendor AVP - # avp_vars['misc_data'] = data[16:(avp_vars['avp_length']*2)] - - # if avp_vars['avp_length'] % 4 == 0: - # #Multiple of 4 - No Padding needed - # avp_vars['padding'] = 0 - # else: - # #Not multiple of 4 - Padding needed - # rounded_value = await(self.myRound(avp_vars['avp_length'])) - # avp_vars['padding'] = int( rounded_value - avp_vars['avp_length']) * 2 - # avp_vars['padded_data'] = data[(avp_vars['avp_length']*2):(avp_vars['avp_length']*2)+avp_vars['padding']] - - - # #If body of avp_vars['misc_data'] contains AVPs, then decode each of them as a list of dicts like avp_vars['misc_data'] = [avp_vars, avp_vars] - # try: - # sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacket(avp_vars['misc_data'])) - # #Sanity check - If the 
avp code is greater than 9999 it's probably not an AVP after all... - # if int(sub_avp_vars['avp_code']) > 9999: - # pass - # else: - # #If the decoded AVP is valid store it - # avp_vars['misc_data'] = [] - # avp_vars['misc_data'].append(sub_avp_vars) - # #While there are more AVPs to be decoded, decode them: - # while len(sub_remaining_avps) > 0: - # sub_avp_vars, sub_remaining_avps = await(self.decodeAvpPacket(sub_remaining_avps)) - # avp_vars['misc_data'].append(sub_avp_vars) - - # except Exception as e: - # if str(e) == "invalid literal for int() with base 16: ''": - # pass - # elif str(e) == "Length of data is too short to be valid AVP": - # pass - # else: - # pass - - remaining_avps = data[(avp_vars['avp_length']*2)+avp_vars['padding']:] #returns remaining data in avp string back for processing again - return avp_vars, remaining_avps + return processed_avps async def getPeerType(self, originHost: str) -> str: try: diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index a3eb297..af4f54c 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -106,13 +106,12 @@ async def getNextQueue(self, pattern: str='*') -> str: Returns the next Queue (Key) in the list, asynchronously. """ try: - result = [] async for nextQueue in self.redisClient.scan_iter(match=pattern): - result.append(nextQueue) - return next(iter(result), '') if result else '' + if nextQueue is not None: + return nextQueue.decode('utf-8') except Exception as e: print(e) - return '' + return '' async def deleteQueue(self, queue: str) -> bool: """ diff --git a/services/diameterService.py b/services/diameterService.py index e22ea39..b7ad54a 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -13,7 +13,7 @@ class DiameterService: """ PyHSS Diameter Service A class for handling diameter inbounds and replies on Port 3868, via TCP. - Functions in this class are high-performance, please edit with care. Last benchmarked on 24-08-2023. + Functions in this class are high-performance, please edit with care. Last profiled on 20-09-2023. 
""" def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): @@ -41,7 +41,7 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb try: packetVars, avps = await(self.diameterLibrary.decodeDiameterPacket(inboundData)) messageType = await(self.diameterLibrary.getDiameterMessageType(inboundData)) - originHost = (await self.diameterLibrary.getAvpData(avps, 264))[0] + originHost = (await(self.diameterLibrary.getAvpData(avps, 264)))[0] originHost = bytes.fromhex(originHost).decode("utf-8") peerType = await(self.diameterLibrary.getPeerType(originHost)) self.activePeers[f"{clientAddress}-{clientPort}"].update({'lastDwrTimestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S") if messageType['inbound'] == 'DWR' else self.activePeers[f"{clientAddress}-{clientPort}"]['lastDwrTimestamp'], @@ -59,6 +59,7 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb metricExpiry=60)) except Exception as e: await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}\n{traceback.format_exc()}")) + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] AVPs: {avps}\nPacketVars: {packetVars}")) return False return True @@ -70,7 +71,7 @@ async def handleActiveDiameterPeers(self): while True: try: if not len(self.activePeers) > 0: - await(asyncio.sleep(0)) + await(asyncio.sleep(1)) continue activeDiameterPeersTimeout = self.config.get('hss', {}).get('active_diameter_peers_timeout', 3600) @@ -130,7 +131,7 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc diamteterValidationStartTime = time.perf_counter() if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)): await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.")) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) continue if self.benchmarking: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to validate diameter request: {round(((time.perf_counter() - diamteterValidationStartTime)*1000), 3)} ms")) @@ -145,7 +146,7 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc await(self.redisReaderMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) if self.benchmarking: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) except Exception as e: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}")) @@ -166,11 +167,11 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s pendingOutboundQueue = await(self.redisWriterMessaging.getNextQueue(pattern=f'diameter-outbound-{clientAddress.replace(".", "*")}-{clientPort}-*')) if not len(pendingOutboundQueue) > 0: - await(asyncio.sleep(0)) + await(asyncio.sleep(0.01)) continue - pendingOutboundQueue = pendingOutboundQueue.decode() + pendingOutboundQueue = pendingOutboundQueue - 
await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queue: {pendingOutboundQueue}")) + # await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queue: {pendingOutboundQueue}")) outboundQueueSplit = str(pendingOutboundQueue).split('-') queuedMessageType = outboundQueueSplit[1] diameterOutboundHost = outboundQueueSplit[2] @@ -185,14 +186,14 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.")) writer.write(diameterOutboundBinary) await(writer.drain()) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) if self.benchmarking: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) - except Exception: + except Exception as e: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.")) return False - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) async def handleConnection(self, reader, writer): """ @@ -238,7 +239,7 @@ async def handleConnection(self, reader, writer): for pendingTask in pendingTasks: try: pendingTask.cancel() - await(asyncio.sleep(0)) + await(asyncio.sleep(0.1)) except asyncio.CancelledError: pass diff --git a/services/georedService.py b/services/georedService.py index 67e5b3f..c0d8879 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -250,27 +250,32 @@ async def handleGeoredQueue(self): if self.benchmarking: startTime = time.perf_counter() georedQueue = await(self.redisGeoredMessaging.getNextQueue(pattern='geored-*')) + if not len(georedQueue) > 0: + await(asyncio.sleep(0.01)) + continue georedMessage = await(self.redisGeoredMessaging.getMessage(queue=georedQueue)) - if len(georedMessage) > 0: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Queue: {georedQueue}")) - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}")) - - georedDict = json.loads(georedMessage) - georedOperation = georedDict['operation'] - georedBody = georedDict['body'] - georedTasks = [] - - for remotePeer in self.georedPeers: - georedTasks.append(self.sendGeored(asyncSession=session, url=remotePeer+'/geored/', operation=georedOperation, body=georedBody)) - await asyncio.gather(*georedTasks) - if self.benchmarking: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + if not len(georedMessage) > 0: + await(asyncio.sleep(0.01)) + continue + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Queue: {georedQueue}")) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}")) + + georedDict = json.loads(georedMessage) + georedOperation = georedDict['operation'] + 
georedBody = georedDict['body'] + georedTasks = [] + + for remotePeer in self.georedPeers: + georedTasks.append(self.sendGeored(asyncSession=session, url=remotePeer+'/geored/', operation=georedOperation, body=georedBody)) + await asyncio.gather(*georedTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) except Exception as e: await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}")) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) continue async def handleWebhookQueue(self): @@ -283,28 +288,33 @@ async def handleWebhookQueue(self): if self.benchmarking: startTime = time.perf_counter() webhookQueue = await(self.redisWebhookMessaging.getNextQueue(pattern='webhook-*')) + if not len(webhookQueue) > 0: + await(asyncio.sleep(0.01)) + continue webhookMessage = await(self.redisWebhookMessaging.getMessage(queue=webhookQueue)) - if len(webhookMessage) > 0: - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Queue: {webhookQueue}")) - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) - - webhookDict = json.loads(webhookMessage) - webhookHeaders = webhookDict['headers'] - webhookOperation = webhookDict['operation'] - webhookBody = webhookDict['body'] - webhookTasks = [] - - for remotePeer in self.webhookPeers: - webhookTasks.append(self.sendWebhook(asyncSession=session, url=remotePeer, operation=webhookOperation, body=webhookBody, headers=webhookHeaders)) - await asyncio.gather(*webhookTasks) - if self.benchmarking: - await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + if not len(webhookMessage) > 0: + await(asyncio.sleep(0.001)) + continue + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Queue: {webhookQueue}")) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) + + webhookDict = json.loads(webhookMessage) + webhookHeaders = webhookDict['headers'] + webhookOperation = webhookDict['operation'] + webhookBody = webhookDict['body'] + webhookTasks = [] + + for remotePeer in self.webhookPeers: + webhookTasks.append(self.sendWebhook(asyncSession=session, url=remotePeer, operation=webhookOperation, body=webhookBody, headers=webhookHeaders)) + await asyncio.gather(*webhookTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleWebhookQueue] Time taken to send webhook to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) except Exception as e: await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Error handling webhook queue: {e}")) - await(asyncio.sleep(0)) + await(asyncio.sleep(0.001)) continue async def startService(self): @@ -343,7 +353,7 @@ async def startService(self): for pendingTask in pendingTasks: try: pendingTask.cancel() - await(asyncio.sleep(0)) + 
await(asyncio.sleep(0.001)) except asyncio.CancelledError: pass diff --git a/services/hssService.py b/services/hssService.py index 23cf436..e7cd5f4 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -74,6 +74,7 @@ def handleQueue(self): self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) except Exception as e: + time.sleep(0.001) continue diff --git a/services/logService.py b/services/logService.py index 568c0d7..a6a4e03 100644 --- a/services/logService.py +++ b/services/logService.py @@ -1,5 +1,6 @@ import os, sys, json, yaml from datetime import datetime +import time import logging sys.path.append(os.path.realpath('../lib')) from messaging import RedisMessaging @@ -46,6 +47,7 @@ def handleLogs(self): logMessage = self.redisMessaging.getMessage(queue=logQueue) if not len(logMessage) > 0: + time.sleep(0.001) continue print(f"[Log] Queue: {logQueue}") diff --git a/services/metricService.py b/services/metricService.py index 09c6be2..d75902d 100644 --- a/services/metricService.py +++ b/services/metricService.py @@ -38,6 +38,7 @@ def handleMetrics(self): metric = self.redisMessaging.getMessage(queue=metricQueue) if not (len(metric) > 0): + time.sleep(0.001) return self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] Received Metric: {metric}", redisClient=self.redisMessaging) From db6f6e1628add84f5b4a8568340c7a0ab12a8328 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Thu, 21 Sep 2023 17:03:05 +1000 Subject: [PATCH 17/43] Fix for sqn resync --- lib/database.py | 1 + lib/diameter.py | 74 +++++++++++++++++++++++++------------------------ 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/lib/database.py b/lib/database.py index 5fcd575..6350e46 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1674,6 +1674,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber self.logTool.log(service='Database', level='info', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging) #Create if does not exist self.CreateObj(SERVING_APN, json_data, True) + ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) self.handleWebhook(objectData, 'PUT') diff --git a/lib/diameter.py b/lib/diameter.py index 08e8568..7dc4dea 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1133,38 +1133,39 @@ def Answer_16777251_318(self, packet_vars, avps): try: requested_vectors = 1 - for avp in avps: - if avp['avp_code'] == 1408: - self.logTool.log(service='HSS', level='debug', message="AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP", redisClient=self.redisMessaging) - EUTRAN_Authentication_Info = avp['misc_data'] - self.logTool.log(service='HSS', level='debug', message="EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info), redisClient=self.redisMessaging) - for sub_avp in EUTRAN_Authentication_Info: - #If resync request - if sub_avp['avp_code'] == 1411: - self.logTool.log(service='HSS', level='debug', message="Re-Synchronization required - SQN is out of sync", redisClient=self.redisMessaging) - self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', - metricType='counter', metricAction='inc', - metricValue=1.0, - metricLabels={ - "diameter_application_id": 16777251, - 
"diameter_cmd_code": 318, - "event": "Resync", - "imsi_prefix": str(imsi[0:6])}, - metricHelp='Diameter Authentication related Counters', - metricExpiry=60) - auts = str(sub_avp['misc_data'])[32:] - rand = str(sub_avp['misc_data'])[:32] - rand = binascii.unhexlify(rand) - #Calculate correct SQN - self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) - - #Get number of requested vectors - if sub_avp['avp_code'] == 1410: - self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging) - requested_vectors = int(sub_avp['misc_data'], 16) - if requested_vectors >= 32: - self.logTool.log(service='HSS', level='info', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) - requested_vectors = 32 + EUTRAN_Authentication_Info = self.get_avp_data(avps, 1408) + self.logTool.log(service='HSS', level='debug', message=f"authInfo: {EUTRAN_Authentication_Info}", redisClient=self.redisMessaging) + if len(EUTRAN_Authentication_Info) > 0: + EUTRAN_Authentication_Info = EUTRAN_Authentication_Info[0] + self.logTool.log(service='HSS', level='debug', message="AVP: Requested-EUTRAN-Authentication-Info(1408) l=44 f=VM- vnd=TGPP", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="EUTRAN_Authentication_Info is " + str(EUTRAN_Authentication_Info), redisClient=self.redisMessaging) + for sub_avp in EUTRAN_Authentication_Info: + #If resync request + if sub_avp['avp_code'] == 1411: + self.logTool.log(service='HSS', level='debug', message="Re-Synchronization required - SQN is out of sync", redisClient=self.redisMessaging) + self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', + metricType='counter', metricAction='inc', + metricValue=1.0, + metricLabels={ + "diameter_application_id": 16777251, + "diameter_cmd_code": 318, + "event": "Resync", + "imsi_prefix": str(imsi[0:6])}, + metricHelp='Diameter Authentication related Counters', + metricExpiry=60) + auts = str(sub_avp['misc_data'])[32:] + rand = str(sub_avp['misc_data'])[:32] + rand = binascii.unhexlify(rand) + #Calculate correct SQN + self.database.Get_Vectors_AuC(subscriber_details['auc_id'], "sqn_resync", auts=auts, rand=rand) + + #Get number of requested vectors + if sub_avp['avp_code'] == 1410: + self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging) + requested_vectors = int(sub_avp['misc_data'], 16) + if requested_vectors >= 32: + self.logTool.log(service='HSS', level='info', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) + requested_vectors = 32 self.logTool.log(service='HSS', level='debug', message="Generating " + str(requested_vectors) + " vectors as requested", redisClient=self.redisMessaging) eutranvector_complete = '' @@ -1403,6 +1404,8 @@ def Answer_16777238_272(self, packet_vars, avps): except Exception as e: #Get subscriber details #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=traceback.format_exc(), redisClient=self.redisMessaging) + 
self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', metricType='counter', metricAction='inc', metricValue=1.0, @@ -1413,10 +1416,9 @@ def Answer_16777238_272(self, packet_vars, avps): "imsi_prefix": str(imsi[0:6])}, metricHelp='Diameter Authentication related Counters', metricExpiry=60) - experimental_result = self.generate_avp(298, 40, self.int_to_hex(5001, 4)) #Result Code (DIAMETER ERROR - User Unknown) - experimental_result = experimental_result + self.generate_vendor_avp(266, 40, 10415, "") - #Experimental Result (297) - avp += self.generate_avp(297, 40, experimental_result) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) #Result Code (DIAMETER ERROR - User Unknown) response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response From e309932ceaa5b4c884d276aef1dff74f946087c6 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Sat, 23 Sep 2023 22:33:20 +1000 Subject: [PATCH 18/43] Fix for Tac Database handling --- lib/database.py | 72 ++++++++++++++++++++++++------------------ services/apiService.py | 1 - 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/lib/database.py b/lib/database.py index 6350e46..2f7fc75 100755 --- a/lib/database.py +++ b/lib/database.py @@ -287,8 +287,10 @@ def __init__(self, logTool, redisMessaging=None): #Load IMEI TAC database into Redis if enabled if ('tac_database_csv' in self.config['eir']): self.load_IMEI_database_into_Redis() + self.tacData = json.loads(self.redisMessaging.getValue(key="tacDatabase")) else: self.logTool.log(service='Database', level='info', message="Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config", redisClient=self.redisMessaging) + self.tacData = {} # Create individual tables if they do not exist. 
inspector = Inspector.from_engine(self.engine) @@ -310,25 +312,28 @@ def load_IMEI_database_into_Redis(self): return try: count = 0 + tacList = {"tacList": []} for line in csvfile: line = line.replace('"', '') #Strip excess invered commas line = line.replace("'", '') #Strip excess invered commas line = line.rstrip() #Strip newlines result = line.split(',') - tac_prefix = result[0] + tacPrefix = result[0] name = result[1].lstrip() model = result[2].lstrip() + if count == 0: self.logTool.log(service='Database', level='info', message="Checking to see if entries are already present...", redisClient=self.redisMessaging) - redis_imei_result = self.redisMessaging.getMessage(queue=str(tac_prefix)) - if len(redis_imei_result) != 0: - self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging) - break - else: + redis_imei_result = self.redisMessaging.getValue(key="tacDatabase") + if redis_imei_result is not None: + if len(redis_imei_result) > 0: + self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging) + return self.logTool.log(service='Database', level='info', message="No data loaded into Redis, proceeding to load...", redisClient=self.redisMessaging) - imei_result = {'tac_prefix': tac_prefix, 'name': name, 'model': model} - self.redisMessaging.sendMessage(queue=str(tac_prefix), message=imei_result) - count = count +1 + tacList['tacList'].append({str(tacPrefix): {'name': name, 'model': model}}) + count += 1 + self.redisMessaging.setValue(key="tacDatabase", value=json.dumps(tacList)) + self.tacData = tacList self.logTool.log(service='Database', level='info', message="Loaded " + str(count) + " IMEI TAC entries into Redis", redisClient=self.redisMessaging) except Exception as E: self.logTool.log(service='Database', level='error', message="Failed to load IMEI Database into Redis due to error: " + (str(E)), redisClient=self.redisMessaging) @@ -833,13 +838,17 @@ def get_last_operation_log(self, existingSession=None): self.safe_close(session) raise ValueError(E) - def handleGeored(self, jsonData): + def handleGeored(self, jsonData, operation: str): try: + operation = operation.upper() + if operation not in ['PUT', 'PATCH', 'DELETE']: + self.logTool.log(service='Database', level='warning', message="Failed to send Geored message, invalid operation type received: " + str(operation), redisClient=self.redisMessaging) + return georedDict = {} if self.config.get('geored', {}).get('enabled', False): if self.config.get('geored', {}).get('endpoints', []) is not None and len(self.config.get('geored', {}).get('endpoints', [])) > 0: georedDict['body'] = jsonData - georedDict['operation'] = 'PATCH' + georedDict['operation'] = operation self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) except Exception as E: self.logTool.log(service='Database', level='warning', message="Failed to send Geored message due to error: " + str(E), redisClient=self.redisMessaging) @@ -1122,7 +1131,7 @@ def Generate_JSON_Model_for_Flask(self, obj_type): self.logTool.log(service='Database', level='debug', message="Generating JSON model for Flask for object type: " + str(obj_type), redisClient=self.redisMessaging) dictty = dict(self.generate_json_schema(obj_type)) - pprint.pprint(dictty) + # pprint.pprint(dictty) #dictty['properties'] = 
dict(dictty['properties']) @@ -1853,7 +1862,7 @@ def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=Tru self.redisMessaging.sendMetric(serviceName='database', metricName='prom_eir_devices', metricType='counter', metricAction='inc', metricValue=1, metricHelp='Profile of attached devices', - metricLabels={'imei_prefix': device_info['tac_prefix'], + metricLabels={'imei_prefix': device_info['tacPrefix'], 'device_type': device_info['name'], 'device_name': device_info['model']}, metricExpiry=60) @@ -1999,31 +2008,32 @@ def dict_bytes_to_dict_string(self, dict_bytes): dict_string = {} for key, value in dict_bytes.items(): dict_string[key.decode()] = value.decode() - return dict_string - - - def get_device_info_from_TAC(self, imei): + return dict_string + + def find_imei_in_tac_list(self, imei, tacList): + """ + Iterate over every tac in the tacList and try to match the first 8 digits of the IMEI. + If that fails, try to match the first 6 digits of the IMEI. + """ + for tac in tacList['tacList']: + for key, value in tac.items(): + if str(key) == str(imei[0:8]): + return {'tacPrefix': key, 'name': tac[key]['name'], 'model': tac[key]['model']} + for key, value in tac.items(): + if str(key) == str(imei[0:6]): + return {'tacPrefix': key, 'name': tac[key]['name'], 'model': tac[key]['model']} + return {} + + def get_device_info_from_TAC(self, imei) -> dict: self.logTool.log(service='Database', level='debug', message="Getting Device Info from IMEI: " + str(imei), redisClient=self.redisMessaging) - #Try 8 digit TAC try: - self.logTool.log(service='Database', level='debug', message="Trying to match on 8 Digit IMEI", redisClient=self.redisMessaging) - imei_result = self.redisMessaging.RedisHGetAll(str(imei[0:8])) - imei_result = self.dict_bytes_to_dict_string(imei_result) + self.logTool.log(service='Database', level='debug', message=f"Taclist: {self.tacData}", redisClient=self.redisMessaging) + imei_result = self.find_imei_in_tac_list(imei, self.tacData) assert(len(imei_result) != 0) self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) return imei_result except: self.logTool.log(service='Database', level='debug', message="Failed to match on 8 digit IMEI", redisClient=self.redisMessaging) - - try: - self.logTool.log(service='Database', level='debug', message="Trying to match on 6 Digit IMEI", redisClient=self.redisMessaging) - imei_result = self.redisMessaging.RedisHGetAll(str(imei[0:6])) - imei_result = self.dict_bytes_to_dict_string(imei_result) - assert(len(imei_result) != 0) - self.logTool.log(service='Database', level='debug', message="Found match for IMEI " + str(imei) + " with result " + str(imei_result), redisClient=self.redisMessaging) - return imei_result - except: - self.logTool.log(service='Database', level='debug', message="Failed to match on 6 digit IMEI", redisClient=self.redisMessaging) raise ValueError("No matching TAC in IMEI Database") diff --git a/services/apiService.py b/services/apiService.py index 0651115..3483ad5 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -949,7 +949,6 @@ def get(self, imei): print(E) return handle_exception(E) - @ns_subscriber_attributes.route('/list') class PyHSS_Subscriber_Attributes_All(Resource): @ns_subscriber_attributes.expect(paginatorParser) From 3ea11b40157a7d20d39e62dfe56f6b70b8545c48 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Sun, 24 Sep 2023 13:58:01 +1000 Subject: [PATCH 19/43] Asymmetric geored, CLR 
fix, logic fixes --- lib/database.py | 29 +++++++--- lib/diameter.py | 101 +++++++++++++++++++++------------ services/apiService.py | 108 +++++++++++++++++++++++++++++++----- services/diameterService.py | 2 - services/georedService.py | 42 +++++++++++++- 5 files changed, 219 insertions(+), 63 deletions(-) diff --git a/lib/database.py b/lib/database.py index 2f7fc75..1bae5b2 100755 --- a/lib/database.py +++ b/lib/database.py @@ -838,7 +838,11 @@ def get_last_operation_log(self, existingSession=None): self.safe_close(session) raise ValueError(E) - def handleGeored(self, jsonData, operation: str): + def handleGeored(self, jsonData, operation: str="PATCH", asymmetric: bool=False, asymmetricUrls: list=[]) -> bool: + """ + Validate the request, check configuration and queue the geored message. + Asymmetric geored is supported (where one or more specific or foreign geored endpoints are specified). + """ try: operation = operation.upper() if operation not in ['PUT', 'PATCH', 'DELETE']: @@ -850,8 +854,17 @@ def handleGeored(self, jsonData, operation: str): georedDict['body'] = jsonData georedDict['operation'] = operation self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) + if asymmetric: + if len(asymmetricUrls) > 0: + georedDict['body'] = jsonData + georedDict['operation'] = operation + georedDict['urls'] = asymmetricUrls + self.redisMessaging.sendMessage(queue=f'asymmetric-geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) + return True + except Exception as E: self.logTool.log(service='Database', level='warning', message="Failed to send Geored message due to error: " + str(E), redisClient=self.redisMessaging) + return False def handleWebhook(self, objectData, operation: str="PATCH"): webhooksEnabled = self.config.get('webhooks', {}).get('enabled', False) @@ -1523,17 +1536,16 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ URL = 'http://' + serving_hss + '.' 
+ self.config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) self.logTool.log(service='Database', level='debug', message="Sending CLR to API at " + str(URL), redisClient=self.redisMessaging) - json_data = { + + self.logTool.log(service='Database', level='debug', message="Pushing CLR to API on " + str(URL), redisClient=self.redisMessaging) + transaction_id = str(uuid.uuid4()) + self.handleGeored({ + "imsi": str(imsi), "DestinationRealm": result.serving_mme_realm, "DestinationHost": result.serving_mme, "cancellationType": 2, "diameterPeer": serving_mme_peer, - } - - self.logTool.log(service='Database', level='debug', message="Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data), redisClient=self.redisMessaging) - transaction_id = str(uuid.uuid4()) - GeoRed_Push_thread = threading.Thread(target=self.GeoRed_Push_Request, args=(serving_hss, json_data, transaction_id, URL)) - GeoRed_Push_thread.start() + }, asymmetric=True, asymmetricUrls=[URL]) else: #No currently serving MME - No action to take self.logTool.log(service='Database', level='debug', message="No currently serving MME - No need to send CLR", redisClient=self.redisMessaging) @@ -2037,6 +2049,7 @@ def get_device_info_from_TAC(self, imei) -> dict: raise ValueError("No matching TAC in IMEI Database") + if __name__ == "__main__": import binascii,os,pprint DeleteAfter = True diff --git a/lib/diameter.py b/lib/diameter.py index 7dc4dea..ab38924 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -62,9 +62,11 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.diameterRequestList = [ - {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer", "validPeerTypes": ['MME']}, - {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer", "validPeerTypes": ['MME']}, + {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, + {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, + {"commandCode": 304, "applicationId": 16777216, "requestMethod": self.Request_16777216_304, "failureResultCode": 5012 ,"requestAcronym": "RTR", "responseAcronym": "RTA", "requestName": "Registration Termination Request", "responseName": "Registration Termination Answer"}, + ] #Generates rounding for calculating padding @@ -531,7 +533,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: if peerType not in peerTypes: return [] filteredConnectedPeers = [] - activePeers = 
self.redisMessaging.getValue(key="ActiveDiameterPeers") + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) for key, value in activePeers.items(): if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': @@ -545,7 +547,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: def getPeerByHostname(self, hostname: str) -> dict: try: hostname = hostname.lower() - activePeers = self.redisMessaging.getValue(key="ActiveDiameterPeers") + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) for key, value in activePeers.items(): if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': @@ -569,28 +571,60 @@ def getDiameterMessageType(self, binaryData: str) -> dict: continue return response - def generateDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: - try: - request = '' - self.logTool.log(service='HSS', level='debug', message=f"Generating a diameter outbound request", redisClient=self.redisMessaging) - - for diameterApplication in self.diameterRequestList: - try: - assert(requestType == diameterApplication["requestAcronym"]) - except Exception as e: - continue - connectedPeer = self.getPeerByHostname(hostname=hostname) + def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: + """ + Sends a given diameter request of requestType to the provided peer hostname, if the peer is connected. + """ + try: + request = '' + requestType = requestType.upper() + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Generating a diameter outbound request", redisClient=self.redisMessaging) + + for diameterApplication in self.diameterRequestList: + try: + assert(requestType == diameterApplication["requestAcronym"]) + except Exception as e: + continue + connectedPeer = self.getPeerByHostname(hostname=hostname) + peerIp = connectedPeer['ipAddress'] + peerPort = connectedPeer['port'] + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{time.time_ns()}" + outboundMessage = json.dumps({'diameter-outbound': request}) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) + return request + except Exception as e: + return '' + + def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> bool: + """ + Sends a diameter request of requestType to one or more connected peers, specified by peerType. 
+ """ + try: + request = '' + requestType = requestType.upper() + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Broadcasting a diameter outbound request of type: {requestType} to peers of type: {peerType}", redisClient=self.redisMessaging) + + for diameterApplication in self.diameterRequestList: + try: + assert(requestType == diameterApplication["requestAcronym"]) + except Exception as e: + continue + connectedPeerList = self.getConnectedPeersByType(peerType=peerType) + for connectedPeer in connectedPeerList: peerIp = connectedPeer['ipAddress'] peerPort = connectedPeer['port'] - request = diameterApplication["requestMethod"](kwargs) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] Successfully generated request: {request}", redisClient=self.redisMessaging) + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{time.time_ns()}" - outboundMessage = {'diameter-outbound': json.dumps(request)} + outboundMessage = json.dumps({'diameter-outbound': request}) self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) - return request - except Exception as e: - return '' + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Queueing for peer type: {peerType} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) + return connectedPeerList + except Exception as e: + return '' def generateDiameterResponse(self, binaryData: str) -> str: try: @@ -599,11 +633,11 @@ def generateDiameterResponse(self, binaryData: str) -> str: origin_host = binascii.unhexlify(origin_host).decode("utf-8") response = '' - self.logTool.log(service='HSS', level='debug', message=f"Generating a diameter response", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] Generating a diameter response", redisClient=self.redisMessaging) # Drop packet if it's a response packet: if packet_vars["flags_bin"][0:1] == "0": - self.logTool.log(service='HSS', level='debug', message="Got a Response, not a request - dropping it.", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [generateDiameterResponse] Got a Response, not a request - dropping it.", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message=packet_vars, redisClient=self.redisMessaging) return @@ -839,7 +873,6 @@ def Answer_16777251_316(self, packet_vars, avps): VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777251),"x").zfill(8)) #Auth-Application-ID Relay avp += self.generate_avp(260, 40, VendorSpecificApplicationId) #AVP: Auth-Application-Id(258) l=12 f=-M- val=3GPP S6a/S6d (16777251) - #AVP: Supported-Features(628) l=36 f=V-- vnd=TGPP SupportedFeatures = '' SupportedFeatures += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID @@ -847,7 +880,6 @@ def Answer_16777251_316(self, packet_vars, avps): 
SupportedFeatures += self.generate_vendor_avp(630, 80, 10415, "1c000607") #Feature-List Flags avp += self.generate_vendor_avp(628, "80", 10415, SupportedFeatures) #Supported-Features(628) l=36 f=V-- vnd=TGPP - #APNs from DB APN_Configuration = '' imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request @@ -2542,7 +2574,7 @@ def Request_16777216_303(self, imsi, domain): return response #3GPP Cx Registration Termination Request (RTR) - def Request_16777216_304(self, imsi, domain): + def Request_16777216_304(self, imsi, domain, destinationHost, destinationRealm): avp = '' #Initiate empty var AVP #Session-ID sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID AVP @@ -2553,25 +2585,20 @@ def Request_16777216_304(self, imsi, domain): #SIP-Deregistration-Reason reason_code_avp = self.generate_vendor_avp(616, "c0", 10415, "00000000") - reason_info_avp = self.generate_vendor_avp(617, "c0", 10415, self.string_to_hex("Test Reason")) + reason_info_avp = self.generate_vendor_avp(617, "c0", 10415, self.string_to_hex("Administrative Deregistration")) avp += self.generate_vendor_avp(615, "c0", 10415, reason_code_avp + reason_info_avp) - avp += self.generate_avp(283, 40, str(binascii.hexlify(b'localdomain'),'ascii')) #Destination Realm - avp += self.generate_avp(293, 40, str(binascii.hexlify(b'hss.localdomain'),'ascii')) #Destination Host + avp += self.generate_avp(283, 40, self.string_to_hex(destinationRealm)) #Destination Realm + avp += self.generate_avp(293, 40, self.string_to_hex(destinationHost)) #Destination Host avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained) avp += self.generate_avp(1, 40, self.string_to_hex(str(imsi) + "@" + domain)) #User-Name avp += self.generate_vendor_avp(601, "c0", 10415, self.string_to_hex("sip:" + str(imsi) + "@" + domain)) #Public-Identity avp += self.generate_vendor_avp(602, "c0", 10415, self.ProductName) #Server-Name - #* [ Proxy-Info ] - proxy_host_avp = self.generate_avp(280, "40", str(binascii.hexlify(b'localdomain'),'ascii')) - proxy_state_avp = self.generate_avp(33, "40", "0001") - avp += self.generate_avp(284, "40", proxy_host_avp + proxy_state_avp) #Proxy-Info AVP ( 284 ) #* [ Route-Record ] - avp += self.generate_avp(282, "40", str(binascii.hexlify(b'localdomain'),'ascii')) - - + avp += self.generate_avp(282, "40", self.OriginHost) + response = self.generate_diameter_packet("01", "c0", 304, 16777216, self.generate_id(4), self.generate_id(4), avp) #Generate Diameter packet return response diff --git a/services/apiService.py b/services/apiService.py index 3483ad5..7f00c3e 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -455,8 +455,10 @@ def patch(self, subscriber_id): try: assert(json_data['serving_mme']) print("Serving MME set - Sending CLR") - diameterClient.generateDiameterRequest( + + diameterClient.sendDiameterRequest( requestType='CLR', + hostname=json_data['serving_mme'], imsi=json_data['imsi'], DestinationHost=json_data['serving_mme'], DestinationRealm=json_data['serving_mme_realm'], @@ -1070,6 +1072,69 @@ def get(self): print(E) return handle_exception(E) +@ns_oam.route('/deregister/<string:imsi>') +class PyHSS_OAM_Deregister(Resource): + def get(self, imsi): + '''Deregisters a given IMSI from the entire network''' + try: + subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi)) + imsSubscriberInfo = 
databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) + servingMme = subscriberInfo.get('serving_mme', None) + servingMmeRealm = subscriberInfo.get('serving_mme_realm', None) + servingScscf = imsSubscriberInfo.get('scscf_peer', None) + servingScscfRealm = imsSubscriberInfo.get('scscf_realm', None) + if servingMme is not None and servingMmeRealm is not None: + diameterRequest = diameterClient.broadcastDiameterRequest( + requestType='CLR', + peerType='MME', + imsi=imsi, + DestinationHost=servingMme, + DestinationRealm=servingMmeRealm, + CancellationType=2 + ) + databaseClient.Update_Serving_MME(imsi=imsi, serving_mme=None) + if servingScscf and servingScscfRealm is not None: + servingScscf = servingScscf.split(';')[0] + diameterRequest = diameterClient.broadcastDiameterRequest( + requestType='RTR', + peerType='SCSCF', + imsi=imsi, + destinationHost=servingScscf, + destinationRealm=servingScscfRealm, + domain=servingScscfRealm + ) + databaseClient.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) + return {"result": f"Successfully deregistered {imsi} from the entire network"}, 200 + except Exception as E: + print(E) + return handle_exception(E) + + +# The below function is kept as a placeholder until Rx is implemented. +# @ns_oam.route('/deregister_ims/<string:imsi>') +# class PyHSS_OAM_DeregisterIMS(Resource): +# def get(self, imsi): +# '''Deregisters a given IMSI from the IMS network''' +# try: +# imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) +# servingScscf = imsSubscriberInfo.get('scscf_peer', None) +# servingScscfRealm = imsSubscriberInfo.get('scscf_realm', None) +# if servingScscf and servingScscfRealm is not None: +# servingScscf = servingScscf.split(';')[0] +# diameterRequest = diameterClient.broadcastDiameterRequest( +# requestType='RTR', +# peerType='SCSCF', +# imsi=imsi, +# destinationHost=servingScscf, +# destinationRealm=servingScscfRealm, +# domain=servingScscfRealm +# ) +# databaseClient.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) +# return {"result": f"Successfully deregistered {imsi} from the IMS network"}, 200 +# except Exception as E: +# print(E) +# return handle_exception(E) + @ns_oam.route("/ping") class PyHSS_OAM_Ping(Resource): def get(self): @@ -1455,20 +1520,33 @@ class PyHSS_Push_CLR(Resource): @ns_push.expect(Push_CLR_Model) @ns_push.doc('Push CLR (Cancel Location Request) to MME') def put(self, imsi): - '''Push CLR (Cancel Location Request) to MME''' - json_data = request.get_json(force=True) - print("JSON Data sent: " + str(json_data)) - if 'DestinationHost' not in json_data: - json_data['DestinationHost'] = None - diam_hex = diameterClient.sendDiameterRequest( - requestType='CLR', - imsi=imsi, - DestinationHost=json_data['DestinationHost'], - DestinationRealm=json_data['DestinationRealm'], - CancellationType=json_data['cancellationType'] - ) - return diam_hex, 200 + try: + '''Push CLR (Cancel Location Request) to MME''' + json_data = request.get_json(force=True) + print("JSON Data sent: " + str(json_data)) + if 'DestinationHost' not in json_data: + json_data['DestinationHost'] = None + diameterRequest = diameterClient.sendDiameterRequest( + requestType='CLR', + hostname=json_data['diameterPeer'], + imsi=imsi, + DestinationHost=json_data['DestinationHost'], + DestinationRealm=json_data['DestinationRealm'], + CancellationType=json_data['cancellationType'] + ) + if not len(diameterRequest) > 0: + return {'result': f'Failed queueing CLR to {json_data["diameterPeer"]}'}, 400 + + subscriber_details = databaseClient.Get_Subscriber(imsi=str(imsi)) + if 
subscriber_details['serving_mme'] == json_data['DestinationHost']: + databaseClient.Update_Serving_MME(imsi=imsi, serving_mme=None) + + return {'result': f'Successfully queued CLR to {json_data["diameterPeer"]}'}, 200 + except Exception as E: + print("Exception when sending CLR: " + str(E)) + response_json = {'result': 'Failed', 'Reason' : "Unable to send CLR: " + str(E)} + return response_json if __name__ == '__main__': - apiService.run(debug=False) + apiService.run(debug=False, host='0.0.0.0', port=8080) diff --git a/services/diameterService.py b/services/diameterService.py index b7ad54a..5d787d5 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -170,8 +170,6 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s await(asyncio.sleep(0.01)) continue pendingOutboundQueue = pendingOutboundQueue - - # await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Pending Outbound Queue: {pendingOutboundQueue}")) outboundQueueSplit = str(pendingOutboundQueue).split('-') queuedMessageType = outboundQueueSplit[1] diameterOutboundHost = outboundQueueSplit[2] diff --git a/services/georedService.py b/services/georedService.py index c0d8879..cfcac66 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -239,7 +239,46 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [sendWebhook] Time taken to send individual webhook request to {url}: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) return True - + + async def handleAsymmetricGeoredQueue(self): + """ + Collects and processes asymmetric geored messages. + """ + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: + while True: + try: + if self.benchmarking: + startTime = time.perf_counter() + asymmetricGeoredQueue = await(self.redisGeoredMessaging.getNextQueue(pattern='asymmetric-geored-*')) + if not len(asymmetricGeoredQueue) > 0: + await(asyncio.sleep(0.01)) + continue + georedMessage = await(self.redisGeoredMessaging.getMessage(queue=asymmetricGeoredQueue)) + if not len(georedMessage) > 0: + await(asyncio.sleep(0.01)) + continue + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Queue: {asymmetricGeoredQueue}")) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}")) + + georedDict = json.loads(georedMessage) + georedOperation = georedDict['operation'] + georedBody = georedDict['body'] + georedUrls = georedDict['urls'] + georedTasks = [] + + for georedEndpoint in georedUrls: + georedTasks.append(self.sendGeored(asyncSession=session, url=georedEndpoint, operation=georedOperation, body=georedBody)) + await asyncio.gather(*georedTasks) + if self.benchmarking: + await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleAsymmetricGeoredQueue] Time taken to send asymmetric geored message to specified peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) + + await(asyncio.sleep(0.001)) + + except Exception as e: + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Error handling asymmetric geored queue: {e}")) + await(asyncio.sleep(0.001)) + continue + async def handleGeoredQueue(self): """ Collects and processes queued geored messages. 
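For reference, a minimal sketch of the queue name and message shape these handlers consume, mirroring handleGeored() in lib/database.py (the body and URL values here are assumed examples):

    import json, time, uuid

    queue = f"asymmetric-geored-{uuid.uuid4()}-{time.time_ns()}"
    message = json.dumps({
        "operation": "PATCH",                 # one of PUT, PATCH or DELETE
        "body": {"imsi": "001001000000001"},  # assumed example payload
        "urls": ["http://hss02.localdomain:8080/push/clr/001001000000001"],  # assumed peer URL
    })
    # redisMessaging.sendMessage(queue=queue, message=message, queueExpiry=120)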
@@ -341,6 +380,7 @@ async def startService(self):
 
         if georedEnabled:
             georedTask = asyncio.create_task(self.handleGeoredQueue())
+            asymmetricGeoredTask = asyncio.create_task(self.asymmetricGeoredQueue())
             activeTasks.append(georedTask)
 
         if webhooksEnabled:

From eb603bf51313aae6b837cae164fdda62a2eff06b Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Sun, 24 Sep 2023 14:22:22 +1000
Subject: [PATCH 20/43] Update README.md

---
 README.md | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index cff724e..e8972ee 100644
--- a/README.md
+++ b/README.md
@@ -41,20 +41,28 @@ Basic configuration is set in the ``config.yaml`` file,
 
 You will need to set the IP address to bind to (IPv4 or IPv6), the Diameter hostname, realm, your PLMN and transport type to use (SCTP or TCP).
 
-Once the configuration is done you can run the HSS by running ``hss.py`` and the server will run using whichever transport (TCP/SCTP) you have selected.
+The Diameter service runs in a trusting mode, allowing Diameter connections from any other Diameter host.
 
-The service runs in a trusting mode allowing Diameter connections from any other Diameter hosts.
+To perform as a functioning HSS, the following services must be run at a minimum:
+- diameterService.py
+- hssService.py
 
-## Structure
+If you're provisioning the HSS for the first time, you'll also want to run:
+ - apiService.py
 
-The file *hss.py* runs a threaded Sockets based listener (SCTP or TCP) to receive Diameter requests, process them and send back Diameter responses.
+The remaining services aren't strictly necessary; your own configuration will dictate whether they are required.
 
-Most of the heavy lifting in this is managed by the Diameter class, in ``diameter.py``. This:
+## Structure
 
- * Decodes incoming packets (Requests)(Returns AVPs as an array, called *avp*, and a Dict containing the packet variables (called *packet_vars*)
- * Generates responses (Answer messages) to Requests (when provided with the AVP and packet_vars of the original Request)
- * Generates Requests to send to other peers
+PyHSS uses a queued microservices model. Each service performs a specific set of tasks, and uses Redis messages to communicate with the other services.
+The following services make up PyHSS:
+ - diameterService.py: Handles receiving and sending of Diameter messages, and Diameter client connection state.
+ - hssService.py: Provides decoding and encoding of Diameter requests and responses, as well as the logic to perform as an HSS.
+ - apiService.py: Provides the API, allowing management of PyHSS.
+ - georedService.py: Sends geographic redundancy messages to geored peers when defined. Also handles webhook messages.
+ - logService.py: Handles logging for all services.
+ - metricService.py: Exposes Prometheus metrics from other services.
 
 ## Subscriber Information Storage
 
@@ -71,12 +79,15 @@ Dependencies can be installed using Pip3:
 pip3 install -r requirements.txt
 ```
 
-Then after setting up the config, you can fire up the HSS itself by running:
+Then after setting up the config, you can fire up the necessary PyHSS services by running:
 
 ```shell
-python3 hss.py
+python3 diameterService.py
+python3 hssService.py
+python3 apiService.py
```
 
-All going well you'll have a functioning HSS at this point.
+All going well, you'll have a functioning HSS at this point. For production use, systemd service files are located in `./systemd`.
+The PyHSS API uses Flask and can be served by your favourite WSGI server.
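As a sketch, assuming the Flask application object in `apiService.py` is the `apiService` instance seen in the `apiService.run()` call above (check the module for the actual name), a gunicorn deployment might look like:

```shell
pip3 install gunicorn
gunicorn --bind 0.0.0.0:8080 --workers 4 apiService:apiService
```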
To get everything more production ready checkout [Monit with PyHSS](docs/monit.md) for more info. From b354a61bb38faf007e7b25f4702434f7a27651bc Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Sun, 24 Sep 2023 21:59:34 +1000 Subject: [PATCH 21/43] Fix method typo --- services/georedService.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/georedService.py b/services/georedService.py index cfcac66..6300a39 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -380,7 +380,7 @@ async def startService(self): if georedEnabled: georedTask = asyncio.create_task(self.handleGeoredQueue()) - asymmetricGeoredTask = asyncio.create_task(self.asymmetricGeoredQueue()) + asymmetricGeoredTask = asyncio.create_task(self.handleAsymmetricGeoredQueue()) activeTasks.append(georedTask) if webhooksEnabled: From 5179e0d0ed4f17b2e44bc817bbd1c7886b9e0b8d Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 26 Sep 2023 14:30:45 +1000 Subject: [PATCH 22/43] Basic Rx support --- lib/diameter.py | 216 +++++++++++++++++++++++++++++++++++++++++-- lib/diameterAsync.py | 28 ++++-- 2 files changed, 229 insertions(+), 15 deletions(-) diff --git a/lib/diameter.py b/lib/diameter.py index ab38924..d863eba 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -44,28 +44,34 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.diameterResponseList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, - {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, - {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, - {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, - {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, - {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, {"commandCode": 300, "applicationId": 16777216, "responseMethod": 
self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server Assignment Request", "responseName": "Server Assignment Answer"}, {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 ,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, + {"commandCode": 265, "applicationId": 16777236, "responseMethod": self.Answer_16777236_265, "failureResultCode": 4100 ,"requestAcronym": "AAR", "responseAcronym": "AAA", "requestName": "AA Request", "responseName": "AA Answer"}, + {"commandCode": 258, "applicationId": 16777236, "responseMethod": self.Answer_16777236_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, + {"commandCode": 275, "applicationId": 16777236, "responseMethod": self.Answer_16777236_275, "failureResultCode": 4100 ,"requestAcronym": "STR", "responseAcronym": "STA", "requestName": "Session Termination Request", "responseName": "Session Termination Answer"}, + {"commandCode": 274, "applicationId": 16777236, "responseMethod": self.Answer_16777236_274, "failureResultCode": 4100 ,"requestAcronym": "ASR", "responseAcronym": "ASA", "requestName": "Abort Session Request", "responseName": "Abort Session Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, + {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, + {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, + {"commandCode": 
323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, + + ] self.diameterRequestList = [ + {"commandCode": 304, "applicationId": 16777216, "requestMethod": self.Request_16777216_304, "failureResultCode": 5012 ,"requestAcronym": "RTR", "responseAcronym": "RTA", "requestName": "Registration Termination Request", "responseName": "Registration Termination Answer"}, + {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, - {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, - {"commandCode": 304, "applicationId": 16777216, "requestMethod": self.Request_16777216_304, "failureResultCode": 5012 ,"requestAcronym": "RTR", "responseAcronym": "RTA", "requestName": "Registration Termination Request", "responseName": "Registration Termination Answer"}, ] @@ -647,6 +653,7 @@ def generateDiameterResponse(self, binaryData: str) -> str: assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) if 'flags' in diameterApplication: assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Attempting to generate response", redisClient=self.redisMessaging) response = diameterApplication["responseMethod"](packet_vars, avps) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully generated response: {response}", redisClient=self.redisMessaging) break @@ -665,6 +672,32 @@ def generateDiameterResponse(self, binaryData: str) -> str: metricExpiry=60) return '' + def validateImsSubscriber(self, imsi=None, msisdn=None) -> bool: + """ + Ensures that a given IMSI or MSISDN (Or both, if specified) are associated with a subscriber that is enabled, and has an associated IMS Subscriber record. 
+ """ + if imsi == None and msisdn == None: + return False + + try: + if imsi is not None: + subscriberDetails = self.database.Get_Subscriber(imsi=imsi) + if not subscriberDetails.get('enabled', False): + return False + imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=imsi) + except Exception as e: + return False + try: + if msisdn is not None: + subscriberDetails = self.database.Get_Subscriber(msisdn=msisdn) + if not subscriberDetails.get('enabled', False): + return False + imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=msisdn) + except Exception as e: + return False + + return True + def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body for avp_dicts in avps: if avp_dicts['avp_code'] == 278: @@ -1977,11 +2010,176 @@ def Answer_16777217_307(self, packet_vars, avps): VendorSpecificApplicationId += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777217),"x").zfill(8)) #Auth-Application-ID Sh avp += self.generate_avp(260, 40, VendorSpecificApplicationId) - - response = self.generate_diameter_packet("01", "40", 307, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response + + ################################ + #### 3GPP RX #### + ################################ + + #3GPP Rx - AA Answer (AAA) + def Answer_16777236_265(self, packet_vars, avps): + try: + """ + Generates a response to a provided AAR. + The response is determined by whether or not the subscriber is enabled, and has a matching ims_subscriber entry. + """ + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8)) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + subscriptionId = bytes.fromhex(self.get_avp_data(avps, 444)[0]).decode('ascii') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Received subscription ID: {subscriptionId}", redisClient=self.redisMessaging) + subscriptionId = subscriptionId.replace('sip:', '') + imsi = None + msisdn = None + identifier = None + if '@' in subscriptionId: + subscriberIdentifier = subscriptionId.split('@')[0] + # Subscriber Identifier can be either an IMSI or an MSISDN + try: + subscriberDetails = self.database.Get_Subscriber(imsi=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=subscriberIdentifier) + identifier = 'imsi' + imsi = imsSubscriberDetails.get('imsi', None) + except Exception as e: + pass + try: + subscriberDetails = self.database.Get_Subscriber(msisdn=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=subscriberIdentifier) + identifier = 'msisdn' + msisdn = imsSubscriberDetails.get('msisdn', None) + except Exception as e: + pass + else: + imsi = None + msisdn = None + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) + imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) + + if imsEnabled: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request authorized", redisClient=self.redisMessaging) + avp += 
self.generate_avp(268, 40, self.int_to_hex(2001, 4)) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request unauthorized", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4)) + + response = self.generate_diameter_packet("01", "40", 265, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_265] [AAA] Error generating AAA: {traceback.format_exc()}", redisClient=self.redisMessaging) + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8)) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(5012, 4)) #Result Code 5012 UNABLE_TO_COMPLY + response = self.generate_diameter_packet("01", "40", 265, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + + #3GPP Rx - Re Auth Answer (RAA) + def Answer_16777236_258(self, packet_vars, avps): + try: + """ + Generates a response to a provided RAR. + The response is determined by whether or not the subscriber is enabled, and has a matching ims_subscriber entry. + """ + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8)) + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + subscriptionId = bytes.fromhex(self.get_avp_data(avps, 444)[0]).decode('ascii') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Received subscription ID: {subscriptionId}", redisClient=self.redisMessaging) + subscriptionId = subscriptionId.replace('sip:', '') + imsi = None + msisdn = None + identifier = None + if '@' in subscriptionId: + subscriberIdentifier = subscriptionId.split('@')[0] + # Subscriber Identifier can be either an IMSI or an MSISDN + try: + subscriberDetails = self.database.Get_Subscriber(imsi=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(imsi=subscriberIdentifier) + identifier = 'imsi' + imsi = imsSubscriberDetails.get('imsi', None) + except Exception as e: + pass + try: + subscriberDetails = self.database.Get_Subscriber(msisdn=subscriberIdentifier) + imsSubscriberDetails = self.database.Get_IMS_Subscriber(msisdn=subscriberIdentifier) + identifier = 'msisdn' + msisdn = imsSubscriberDetails.get('msisdn', None) + except Exception as e: + pass + else: + imsi = None + msisdn = None + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) + imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) + + if imsEnabled: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Request authorized", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) + else: 
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_258] [RAA] Request unauthorized", redisClient=self.redisMessaging)
+                avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4))
+
+            response = self.generate_diameter_packet("01", "40", 258, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_258] [RAA] Error generating RAA: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8))
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(5012, 4)) #Result Code 5012 UNABLE_TO_COMPLY
+            response = self.generate_diameter_packet("01", "40", 258, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+
+    #3GPP Rx - Session Termination Answer (STA)
+    def Answer_16777236_275(self, packet_vars, avps):
+        try:
+            """
+            Generates a response to a provided STR.
+            Returns Result-Code 2001.
+            """
+            avp = ''
+            session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID
+            avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID
+            avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code 2001 DIAMETER_SUCCESS, as per the docstring
+            response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+            return response
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA: {traceback.format_exc()}", redisClient=self.redisMessaging)
+
+    #3GPP Rx - Abort Session Answer (ASA)
+    def Answer_16777236_274(self, packet_vars, avps):
+        try:
+            """
+            Generates a response to a provided ASR.
+            Returns Result-Code 2001.
+ """ + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + response = self.generate_diameter_packet("01", "40", 274, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_274] [STA] Error generating STA: {traceback.format_exc()}", redisClient=self.redisMessaging) + + + #3GPP S13 - ME-Identity-Check Answer def Answer_16777252_324(self, packet_vars, avps): diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index ff15735..6ca952a 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -9,23 +9,27 @@ class DiameterAsync: def __init__(self, logTool): self.diameterCommandList = [ {"commandCode": 257, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_257, "failureResultCode": 5012 ,"requestAcronym": "CER", "responseAcronym": "CEA", "requestName": "Capabilites Exchange Request", "responseName": "Capabilites Exchange Answer"}, - {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCR", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 280, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_280, "failureResultCode": 5012 ,"requestAcronym": "DWR", "responseAcronym": "DWA", "requestName": "Device Watchdog Request", "responseName": "Device Watchdog Answer"}, {"commandCode": 282, "applicationId": 0, "flags": 80, "responseMethod": self.Answer_282, "failureResultCode": 5012 ,"requestAcronym": "DPR", "responseAcronym": "DPA", "requestName": "Disconnect Peer Request", "responseName": "Disconnect Peer Answer"}, - {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, - {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, - {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, - {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, {"commandCode": 300, "applicationId": 16777216, "responseMethod": self.Answer_16777216_300, "failureResultCode": 4100 ,"requestAcronym": "UAR", "responseAcronym": "UAA", "requestName": "User Authentication Request", "responseName": "User Authentication Answer"}, {"commandCode": 301, "applicationId": 16777216, "responseMethod": self.Answer_16777216_301, "failureResultCode": 4100 ,"requestAcronym": "SAR", "responseAcronym": "SAA", "requestName": "Server 
Assignment Request", "responseName": "Server Assignment Answer"}, {"commandCode": 302, "applicationId": 16777216, "responseMethod": self.Answer_16777216_302, "failureResultCode": 4100 ,"requestAcronym": "LIR", "responseAcronym": "LIA", "requestName": "Location Information Request", "responseName": "Location Information Answer"}, {"commandCode": 303, "applicationId": 16777216, "responseMethod": self.Answer_16777216_303, "failureResultCode": 4100 ,"requestAcronym": "MAR", "responseAcronym": "MAA", "requestName": "Multimedia Authentication Request", "responseName": "Multimedia Authentication Answer"}, {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, + {"commandCode": 265, "applicationId": 16777236, "responseMethod": self.Answer_16777236_265, "failureResultCode": 4100 ,"requestAcronym": "AAR", "responseAcronym": "AAA", "requestName": "AA Request", "responseName": "AA Answer"}, + {"commandCode": 258, "applicationId": 16777236, "responseMethod": self.Answer_16777236_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, + {"commandCode": 275, "applicationId": 16777236, "responseMethod": self.Answer_16777236_275, "failureResultCode": 4100 ,"requestAcronym": "STR", "responseAcronym": "STA", "requestName": "Session Termination Request", "responseName": "Session Termination Answer"}, + {"commandCode": 274, "applicationId": 16777236, "responseMethod": self.Answer_16777236_274, "failureResultCode": 4100 ,"requestAcronym": "ASR", "responseAcronym": "ASA", "requestName": "Abort Session Request", "responseName": "Abort Session Answer"}, + {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, + {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, + {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, + {"commandCode": 321, "applicationId": 16777251, "responseMethod": self.Answer_16777251_321, "failureResultCode": 5012 ,"requestAcronym": "PUR", "responseAcronym": "PUA", "requestName": "Purge UE Request", "responseName": "Purge UE Answer"}, + {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", 
"requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, ] - + self.redisMessaging = RedisMessagingAsync() self.logTool = logTool @@ -326,4 +330,16 @@ async def Answer_16777252_324(self): pass async def Answer_16777291_8388622(self): + pass + + async def Answer_16777236_265(self): + pass + + async def Answer_16777236_258(self): + pass + + async def Answer_16777236_275(self): + pass + + async def Answer_16777236_274(self): pass \ No newline at end of file From 5c12ee153f41d0f2611848e13d233110a4c201f2 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 27 Sep 2023 07:32:14 +1000 Subject: [PATCH 23/43] Fix in CCA for apn being in plmn-based string --- lib/diameter.py | 72 ++++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/lib/diameter.py b/lib/diameter.py index d863eba..ee022fb 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1328,9 +1328,17 @@ def Answer_16777238_272(self, packet_vars, avps): CC_Request_Type = self.get_avp_data(avps, 416)[0] CC_Request_Number = self.get_avp_data(avps, 415)[0] #Called Station ID - self.logTool.log(service='HSS', level='debug', message="Attempting to find APN in CCR", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Attempting to find APN in CCR", redisClient=self.redisMessaging) apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8') - self.logTool.log(service='HSS', level='debug', message="CCR for APN " + str(apn), redisClient=self.redisMessaging) + # Strip plmn based domain from apn, if present + try: + if '.' 
in apn: + assert('mcc' in apn) + assert('mnc' in apn) + apn = apn.split('.')[0] + except Exception as e: + apn = bytes.fromhex(self.get_avp_data(avps, 30)[0]).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] CCR for APN " + str(apn), redisClient=self.redisMessaging) OriginHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP OriginHost = binascii.unhexlify(OriginHost).decode('utf-8') #Format it @@ -1343,46 +1351,45 @@ def Answer_16777238_272(self, packet_vars, avps): remote_peer = binascii.unhexlify(remote_peer).decode('utf-8') #Format it except: #If we don't have a record-route set, we'll send the response to the OriginHost remote_peer = OriginHost - self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost']) avp = '' #Initiate empty var AVP session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID - self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCR] Session Id is " + str(binascii.unhexlify(session_id).decode()), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Session Id is " + str(binascii.unhexlify(session_id).decode()), redisClient=self.redisMessaging) avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set avp += self.generate_avp(258, 40, "01000016") #Auth-Application-Id (3GPP Gx 16777238) avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC-Request-Type avp += self.generate_avp(415, 40, format(int(CC_Request_Number),"x").zfill(8)) #CC-Request-Number - #Get Subscriber info from Subscription ID for SubscriptionIdentifier in self.get_avp_data(avps, 443): for UniqueSubscriptionIdentifier in SubscriptionIdentifier: - self.logTool.log(service='HSS', level='debug', message="Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Evaluating UniqueSubscriptionIdentifier AVP " + str(UniqueSubscriptionIdentifier) + " to find IMSI", redisClient=self.redisMessaging) if UniqueSubscriptionIdentifier['avp_code'] == 444: imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') - self.logTool.log(service='HSS', level='debug', message="Found IMSI " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Found IMSI " + str(imsi), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) try: - self.logTool.log(service='HSS', level='info', message="Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details + 
self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details ChargingRules = self.database.Get_Charging_Rules(imsi=imsi, apn=apn) - self.logTool.log(service='HSS', level='info', message="Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) except Exception as E: #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" self.logTool.log(service='HSS', level='debug', message=E, redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging) if int(CC_Request_Type) == 1: - self.logTool.log(service='HSS', level='info', message="Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) #Get UE IP try: ue_ip = self.get_avp_data(avps, 8)[0] ue_ip = str(self.hex_to_ip(ue_ip)) except Exception as E: - self.logTool.log(service='HSS', level='error', message="Failed to get UE IP", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to get UE IP", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) ue_ip = 'Failed to Decode / Get UE IP' @@ -1393,10 +1400,10 @@ def Answer_16777238_272(self, packet_vars, avps): #Supported-Features(628) (Gx feature list) avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") - #Default EPS Beaerer QoS (From database with fallback source CCR-I) + #Default EPS Bearer QoS (From database with fallback source CCR-I, then omission) try: apn_data = ChargingRules['apn_data'] - self.logTool.log(service='HSS', level='debug', message="Setting APN AMBR", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Setting APN AMBR", redisClient=self.redisMessaging) #AMBR AMBR = '' #Initiate empty var AVP for AMBR apn_ambr_ul = int(apn_data['apn_ambr_ul']) @@ -1405,7 +1412,7 @@ def Answer_16777238_272(self, packet_vars, avps): AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) - self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP 
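            #Per 3GPP TS 29.212, Allocation-Retention-Priority (1034) is a grouped AVP built from Priority-Level (1046), Pre-emption-Capability (1047) and Pre-emption-Vulnerability (1048), each encoded below as a 3GPP (vendor 10415) vendor-specific AVP.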
AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) @@ -1415,12 +1422,13 @@ def Answer_16777238_272(self, packet_vars, avps): avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) except Exception as E: self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message="Failed to populate default_EPS_QoS from DB for sub " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to populate default_EPS_QoS from DB for sub " + str(imsi), redisClient=self.redisMessaging) default_EPS_QoS = self.get_avp_data(avps, 1049)[0][8:] - avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) + if len(default_EPS_QoS) > 0: + avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - self.logTool.log(service='HSS', level='info', message="Creating QoS Information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Creating QoS Information", redisClient=self.redisMessaging) #QoS-Information try: apn_data = ChargingRules['apn_data'] @@ -1428,11 +1436,11 @@ def Answer_16777238_272(self, packet_vars, avps): apn_ambr_dl = int(apn_data['apn_ambr_dl']) QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - self.logTool.log(service='HSS', level='info', message="Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="Populated QoS_Information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Populated QoS_Information", redisClient=self.redisMessaging) avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) except Exception as E: - self.logTool.log(service='HSS', level='error', message="Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='error', message=E, redisClient=self.redisMessaging) QoS_Information = '' @@ -1440,26 +1448,28 @@ def Answer_16777238_272(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message=AMBR_Part, redisClient=self.redisMessaging) AMBR_AVP = self.generate_vendor_avp(AMBR_Part['avp_code'], "80", 10415, AMBR_Part['misc_data'][8:]) QoS_Information += AMBR_AVP - self.logTool.log(service='HSS', level='debug', message="QoS_Information added " + str(AMBR_AVP), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS_Information added " + str(AMBR_AVP), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(1016, "80", 10415, 
QoS_Information) - self.logTool.log(service='HSS', level='debug', message="QoS information set statically", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS information set statically", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="Added to AVP List", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message="QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Added to AVP List", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging) #If database returned an existing ChargingRule defintion add ChargingRule to CCA-I if ChargingRules and ChargingRules['charging_rules'] is not None: try: self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging) for individual_charging_rule in ChargingRules['charging_rules']: - self.logTool.log(service='HSS', level='debug', message="Processing Charging Rule: " + str(individual_charging_rule), redisClient=self.redisMessaging) - avp += self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Processing Charging Rule: " + str(individual_charging_rule), redisClient=self.redisMessaging) + chargingRule = self.Charging_Rule_Generator(ChargingRules=individual_charging_rule, ue_ip=ue_ip) + if len(chargingRule) > 0: + avp += chargingRule except Exception as E: - self.logTool.log(service='HSS', level='debug', message="Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging) elif int(CC_Request_Type) == 3: - self.logTool.log(service='HSS', level='info', message="Request type for CCA is 3 - Termination", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 3 - Termination", redisClient=self.redisMessaging) self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None) avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -1468,7 +1478,7 @@ def Answer_16777238_272(self, packet_vars, avps): response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet except Exception as e: #Get subscriber details #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " unknown in HSS for CCR", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Subscriber " + str(imsi) + " unknown in HSS for CCR", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message=traceback.format_exc(), redisClient=self.redisMessaging) 
self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count',

From d03d26c2546125d3921d4b28914c0b06c6a02ec5 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Wed, 27 Sep 2023 11:10:26 +1000
Subject: [PATCH 24/43] Improve /deregister, fix /oam/diameter_peers, add dra peerType

---
 lib/database.py        |  48 +++++++++++-
 lib/diameter.py        |  12 ++-
 lib/diameterAsync.py   |   2 +-
 services/apiService.py | 167 ++++++++++++++++++++++++++++++-----------
 4 files changed, 178 insertions(+), 51 deletions(-)

diff --git a/lib/database.py b/lib/database.py
index 1bae5b2..d3f4226 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -19,6 +19,7 @@ from messaging import RedisMessaging
 import yaml
 import json
+import traceback
 
 Base = declarative_base()
 
@@ -1225,7 +1226,14 @@ def Get_Subscriber(self, **kwargs):
     Session = sessionmaker(bind = self.engine)
     session = Session()
 
-    if 'msisdn' in kwargs:
+    if 'subscriber_id' in kwargs:
+        self.logTool.log(service='Database', level='debug', message="Get_Subscriber for id " + str(kwargs['subscriber_id']), redisClient=self.redisMessaging)
+        try:
+            result = session.query(SUBSCRIBER).filter_by(subscriber_id=int(kwargs['subscriber_id'])).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+    elif 'msisdn' in kwargs:
         self.logTool.log(service='Database', level='debug', message="Get_Subscriber for msisdn " + str(kwargs['msisdn']), redisClient=self.redisMessaging)
         try:
             result = session.query(SUBSCRIBER).filter_by(msisdn=str(kwargs['msisdn'])).one()
@@ -1737,6 +1745,44 @@ def Get_Serving_APN(self, subscriber_id, apn_id):
         self.safe_close(session)
         return result
 
+    def Get_Serving_APNs(self, subscriber_id: int) -> dict:
+        """
+        Returns a dictionary containing all APNs that a subscriber is configured for (subscriber/apn_list),
+        with active sessions being a populated dictionary, and inactive sessions being an empty dictionary.
+ """ + self.logTool.log(service='Database', level='debug', message=f"Getting Serving APNs for subscriber_id: {subscriber_id}", redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + apnDict = {'apns': {}} + + try: + subscriber = self.Get_Subscriber(subscriber_id=subscriber_id) + except: + self.logTool.log(service='Database', level='debug', message=f"Unable to get subscriber with ID: {subscriber_id}: {traceback.format_exc()} ", redisClient=self.redisMessaging) + return apnDict + + apnList = subscriber.get('apn_list', []).split(',') + for apnId in apnList: + try: + apnData = self.Get_APN(apnId) + apnName = apnData.get('apn', 'Unknown') + try: + servingApn = self.Sanitize_Datetime(self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apnId)) + self.logTool.log(service='Database', level='debug', message=f"Got serving APN: {servingApn}", redisClient=self.redisMessaging) + if len(servingApn) > 0: + apnDict['apns'][apnName] = servingApn + else: + apnDict['apns'][apnName] = {} + except Exception as e: + apnDict['apns'][apnName] = {} + continue + except Exception as E: + self.logTool.log(service='Database', level='debug', message=f"Error getting apn for subscriber id: {subscriber_id}: {traceback.format_exc()} ", redisClient=self.redisMessaging) + + self.logTool.log(service='Database', level='debug', message=f"Returning: {apnDict}", redisClient=self.redisMessaging) + + return apnDict + def Get_Charging_Rule(self, charging_rule_id): self.logTool.log(service='Database', level='debug', message="Called Get_Charging_Rule() for charging_rule_id " + str(charging_rule_id), redisClient=self.redisMessaging) Session = sessionmaker(bind = self.engine) diff --git a/lib/diameter.py b/lib/diameter.py index ee022fb..615014a 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -70,6 +70,7 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.diameterRequestList = [ {"commandCode": 304, "applicationId": 16777216, "requestMethod": self.Request_16777216_304, "failureResultCode": 5012 ,"requestAcronym": "RTR", "responseAcronym": "RTA", "requestName": "Registration Termination Request", "responseName": "Registration Termination Answer"}, {"commandCode": 258, "applicationId": 16777238, "requestMethod": self.Request_16777238_258, "failureResultCode": 5012 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, + {"commandCode": 272, "applicationId": 16777238, "requestMethod": self.Request_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, @@ -522,7 +523,7 @@ def decode_diameter_packet_length(self, data): def getPeerType(self, originHost: str) -> str: try: - peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] for peer in peerTypes: if peer in originHost.lower(): @@ 
-2914,9 +2915,12 @@ def Request_16777291_8388622(self, **kwargs): return response #3GPP Gx - Credit Control Request - def Request_16777238_272(self, imsi, apn, ccr_type): + def Request_16777238_272(self, imsi, apn, ccr_type, destinationHost, destinationRealm, sessionId=None): avp = '' - sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_gx' #Session state generate + if sessionId == None: + sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_gx' #Session state generate + else: + sessionid = sessionId avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M- VendorSpecificApplicationId = '' @@ -2924,7 +2928,7 @@ def Request_16777238_272(self, imsi, apn, ccr_type): VendorSpecificApplicationId += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8)) #Auth-Application-ID Gx avp += self.generate_avp(260, 40, VendorSpecificApplicationId) avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State (Not maintained) - avp += self.generate_avp(264, 40, self.string_to_hex('ExamplePGW.com')) #Origin Host + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8)) #Auth-Application-ID Gx diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index 6ca952a..c8c3557 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -220,7 +220,7 @@ async def decodeAvpPacket(self, data): async def getPeerType(self, originHost: str) -> str: try: - peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] for peer in peerTypes: if peer in originHost.lower(): diff --git a/services/apiService.py b/services/apiService.py index 7f00c3e..75eaf22 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -32,7 +32,7 @@ redisHost = config.get("redis", {}).get("host", "127.0.0.1") redisPort = int(config.get("redis", {}).get("port", 6379)) -redisMessaging = RedisMessaging(host=redisHost, port=redisPort) +redisMessaging = RedisMessaging() logTool = LogTool(config) @@ -1064,76 +1064,153 @@ def get(self, table_name): @ns_oam.route('/diameter_peers') class PyHSS_OAM_Peers(Resource): def get(self): - '''Get all Diameter Peers''' + '''Get active Diameter Peers''' try: - diameterPeers = redisMessaging.getValue("ActiveDiameterPeers") + diameterPeers = json.loads(redisMessaging.getValue("ActiveDiameterPeers")) return diameterPeers, 200 except Exception as E: + logTool.log(service='API', level='error', message=f"[API] An error occurred: {traceback.format_exc()}", redisClient=redisMessaging) print(E) return handle_exception(E) @ns_oam.route('/deregister/') class PyHSS_OAM_Deregister(Resource): def get(self, imsi): - '''Deregisters a given IMSI from the entire network''' + '''Deregisters a given IMSI from the entire network.''' try: - subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi)) + subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi)) imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) + subscriberId = subscriberInfo.get('subscriber_id', None) + servingMmePeer = subscriberInfo.get('serving_mme_peer', None) servingMme = subscriberInfo.get('serving_mme', None) servingMmeRealm = subscriberInfo.get('serving_mme_realm', None) - servingScscf = imsSubscriberInfo.get('scscf_peer', None) + servingScscf = subscriberInfo.get('scscf', 
None) + servingScscfPeer = imsSubscriberInfo.get('scscf_peer', None) servingScscfRealm = imsSubscriberInfo.get('scscf_realm', None) - if servingMme is not None and servingMmeRealm is not None: - diameterRequest = diameterClient.broadcastDiameterRequest( + if servingMmePeer is not None and servingMmeRealm is not None and servingMme is not None: + if ';' in servingMmePeer: + servingMmePeer = servingMmePeer.split(';')[0] + + # Send the CLR to the serving MME + diameterClient.sendDiameterRequest( requestType='CLR', - peerType='MME', + hostname=servingMmePeer, imsi=imsi, DestinationHost=servingMme, DestinationRealm=servingMmeRealm, CancellationType=2 ) - databaseClient.Update_Serving_MME(imsi=imsi, serving_mme=None) - if servingScscf and servingScscfRealm is not None: - servingScscf = servingScscf.split(';')[0] - diameterRequest = diameterClient.broadcastDiameterRequest( + + #Broadcast the CLR to all connected MME's, regardless of whether the subscriber is attached. + diameterClient.broadcastDiameterRequest( + requestType='CLR', + peerType='MME', + imsi=imsi, + DestinationHost=servingMme, + DestinationRealm=servingMmeRealm, + CancellationType=2 + ) + + databaseClient.Update_Serving_MME(imsi=imsi, serving_mme=None) + + if servingScscfPeer is not None and servingScscfRealm is not None and servingScscf is not None: + if ';' in servingScscfPeer: + servingScscfPeer = servingScscfPeer.split(';')[0] + servingScscf = servingScscf.replace('sip:', '') + if ';' in servingScscf: + servingScscf = servingScscf.split(';')[0] + diameterClient.sendDiameterRequest( requestType='RTR', - peerType='SCSCF', + peerType=servingScscfPeer, imsi=imsi, destinationHost=servingScscf, destinationRealm=servingScscfRealm, domain=servingScscfRealm ) - databaseClient.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) - return {"result": f"Successfully deregistered {imsi} from the entire network"}, 200 - except Exception as E: - print(E) - return handle_exception(E) - - -# The below function is kept as a placeholder until Rx is implemented. -# @ns_oam.route('/deregister_ims/') -# class PyHSS_OAM_DeregisterIMS(Resource): -# def get(self, imsi): -# '''Deregisters a given IMSI from the IMS network''' -# try: -# imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) -# servingScscf = imsSubscriberInfo.get('scscf_peer', None) -# servingScscfRealm = imsSubscriberInfo.get('scscf_realm', None) -# if servingScscf and servingScscfRealm is not None: -# servingScscf = servingScscf.split(';')[0] -# diameterRequest = diameterClient.broadcastDiameterRequest( -# requestType='RTR', -# peerType='SCSCF', -# imsi=imsi, -# destinationHost=servingScscf, -# destinationRealm=servingScscfRealm, -# domain=servingScscfRealm -# ) -# databaseClient.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) -# return {"result": f"Successfully deregistered {imsi} from the IMS network"}, 200 -# except Exception as E: -# print(E) -# return handle_exception(E) + + #Broadcast the RTR to all connected SCSCF's, regardless of whether the subscriber is attached. + diameterClient.broadcastDiameterRequest( + requestType='RTR', + peerType='SCSCF', + imsi=imsi, + destinationHost=servingScscf, + destinationRealm=servingScscfRealm, + domain=servingScscfRealm + ) + + databaseClient.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) + + # If a subscriber has an active serving apn, grab the pcrf session id for that apn and send a CCR-T, then a Registration Termination Request to the serving pgw peer. 
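+            # A CCR-T is a Gx Credit-Control-Request with CC-Request-Type 3 (TERMINATION_REQUEST, RFC 4006); reusing the stored pcrf_session_id ensures the PGW tears down the correct Gx session before the RTR follows.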
+            if subscriberId is not None:
+                servingApns = databaseClient.Get_Serving_APNs(subscriber_id=subscriberId)
+                if len(servingApns.get('apns', {})) > 0:
+                    for apnKey, apnDict in servingApns['apns'].items():
+                        pcrfSessionId = None
+                        servingPgwPeer = None
+                        servingPgwRealm = None
+                        servingPgw = None
+                        for apnDataKey, apnDataValue in servingApns['apns'][apnKey].items():
+                            if apnDataKey == 'pcrf_session_id':
+                                pcrfSessionId = apnDataValue
+                            if apnDataKey == 'serving_pgw_peer':
+                                servingPgwPeer = apnDataValue
+                            if apnDataKey == 'serving_pgw_realm':
+                                servingPgwRealm = apnDataValue
+                            if apnDataKey == 'serving_pgw':
+                                servingPgw = apnDataValue
+
+                        if pcrfSessionId is not None and servingPgwPeer is not None and servingPgwRealm is not None and servingPgw is not None:
+                            if ';' in servingPgwPeer:
+                                servingPgwPeer = servingPgwPeer.split(';')[0]
+
+                            diameterClient.sendDiameterRequest(
+                                requestType='CCR',
+                                hostname=servingPgwPeer,
+                                imsi=imsi,
+                                destinationHost=servingPgw,
+                                destinationRealm=servingPgwRealm,
+                                ccr_type=3,
+                                sessionId=pcrfSessionId,
+                                domain=servingPgwRealm
+                            )
+
+                            diameterClient.sendDiameterRequest(
+                                requestType='RTR',
+                                hostname=servingPgwPeer,
+                                imsi=imsi,
+                                destinationHost=servingPgw,
+                                destinationRealm=servingPgwRealm,
+                                domain=servingPgwRealm
+                            )
+
+                            diameterClient.broadcastDiameterRequest(
+                                requestType='CCR',
+                                peerType='PGW',
+                                imsi=imsi,
+                                destinationHost=servingPgw,
+                                destinationRealm=servingPgwRealm,
+                                ccr_type=3,
+                                sessionId=pcrfSessionId,
+                                domain=servingPgwRealm
+                            )
+
+                            diameterClient.broadcastDiameterRequest(
+                                requestType='RTR',
+                                peerType='PGW',
+                                imsi=imsi,
+                                destinationHost=servingPgw,
+                                destinationRealm=servingPgwRealm,
+                                domain=servingPgwRealm
+                            )
+
+            subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi))
+            imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi))
+            servingApns = databaseClient.Get_Serving_APNs(subscriber_id=subscriberId)
+
+            return {'subscriber': subscriberInfo, 'ims_subscriber': imsSubscriberInfo, 'pcrf': servingApns}, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
 
 @ns_oam.route("/ping")
 class PyHSS_OAM_Ping(Resource):

From 4c8022cae9845d73aec3352134de2d973781c4ff Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Wed, 27 Sep 2023 11:50:57 +1000
Subject: [PATCH 25/43] Remove threading import

---
 lib/database.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lib/database.py b/lib/database.py
index d3f4226..3737997 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -15,7 +15,6 @@
 import socket
 import pprint
 import S6a_crypt
-import threading
 from messaging import RedisMessaging
 import yaml
 import json

From d776dc96ea0f676b39a28a743371666446c54c2d Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Wed, 27 Sep 2023 11:56:25 +1000
Subject: [PATCH 26/43] Initial changelog

---
 CHANGELOG.md | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 CHANGELOG.md

diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..fbab39b
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,35 @@
+# Changelog
+
+All notable changes to PyHSS are documented in this file, beginning from [Service Overhaul #168](https://github.com/nickvsnetworking/pyhss/pull/168).
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [1.0.0] - 2023-09-27
+
+### Added
+
+ - Systemd service files for PyHSS services.
+ - /oam/diameter_peers endpoint.
+ - /oam/deregister/{imsi} endpoint. + - /geored/peers endpoint. + - /geored/webhooks endpoint. + - Dependency on Redis for inter-service messaging + - Significant performance improvements under load + - Basic Rx support for RAA, AAA, ASA and STA + +### Changed + +- Split logical functions of PyHSS into 6 service processes. +- Logtool no longer handles metric processing +- Updated config.yaml + +### Fixed + + - Memory leaking in diameter.py + +### Removed + +- Multithreading in all services, except for metricService. + +[1.0.0]: https://github.com/nickvsnetworking/pyhss/releases/tag/v1.0.0 \ No newline at end of file From 7d154490ac7e0cf4fde27c65458e91bf597c26af Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 29 Sep 2023 07:35:45 +1000 Subject: [PATCH 27/43] Fix APC and APV AVPs --- CHANGELOG.md | 4 + lib/database.py | 52 +++++++----- lib/diameter.py | 181 +++++++++++++++++++++++++++++++++++------ services/apiService.py | 3 + 4 files changed, 195 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fbab39b..7bfc2e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,16 +17,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Dependency on Redis for inter-service messaging - Significant performance improvements under load - Basic Rx support for RAA, AAA, ASA and STA + - Asymmetric geored support ### Changed - Split logical functions of PyHSS into 6 service processes. - Logtool no longer handles metric processing - Updated config.yaml +- Gx CCR-T now flushes PGW / IMS data, depending on Called-Station-Id ### Fixed - Memory leaking in diameter.py + - Gx CCA now supports apn inside a plmn based uri + - AVP_Preemption_Capability and AVP_Preemption_Vulnerability now presents correctly in all diameter messages ### Removed diff --git a/lib/database.py b/lib/database.py index 3737997..ba49831 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1636,7 +1636,6 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N finally: self.safe_close(session) - def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, propagate=True): self.logTool.log(service='Database', level='debug', message="Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn), redisClient=self.redisMessaging) self.logTool.log(service='Database', level='debug', message="PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing), redisClient=self.redisMessaging) @@ -1680,31 +1679,41 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber 'subscriber_routing' : str(subscriber_routing) } - try: - #Check if already a serving APN on record - self.logTool.log(service='Database', level='debug', message="Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id), redisClient=self.redisMessaging) - ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) - self.logTool.log(service='Database', level='debug', message="Existing Serving APN ID on record, updating", redisClient=self.redisMessaging) + if serving_pgw is None: try: - assert(type(serving_pgw) == str) - assert(len(serving_pgw) > 0) - assert("None" not in serving_pgw) - - self.UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True) - objectData = self.GetObj(SERVING_APN, 
ServingAPN['serving_apn_id']) - self.handleWebhook(objectData, 'PATCH') - except: + ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) self.logTool.log(service='Database', level='debug', message="Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']), redisClient=self.redisMessaging) objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) self.handleWebhook(objectData, 'DELETE') self.DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True) - except Exception as E: - self.logTool.log(service='Database', level='info', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging) - #Create if does not exist - self.CreateObj(SERVING_APN, json_data, True) - ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) - objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) - self.handleWebhook(objectData, 'PUT') + except Exception as e: + self.logTool.log(service='Database', level='debug', message=f"Error when trying to delete serving_apn id: {apn_id}", redisClient=self.redisMessaging) + else: + try: + #Check if already a serving APN on record + self.logTool.log(service='Database', level='debug', message="Checking to see if subscriber id " + str(subscriber_id) + " already has an active PCRF profile on APN id " + str(apn_id), redisClient=self.redisMessaging) + ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) + self.logTool.log(service='Database', level='debug', message="Existing Serving APN ID on record, updating", redisClient=self.redisMessaging) + try: + assert(type(serving_pgw) == str) + assert(len(serving_pgw) > 0) + assert("None" not in serving_pgw) + + self.UpdateObj(SERVING_APN, json_data, ServingAPN['serving_apn_id'], True) + objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) + self.handleWebhook(objectData, 'PATCH') + except: + self.logTool.log(service='Database', level='debug', message="Clearing PCRF session ID on serving_apn_id: " + str(ServingAPN['serving_apn_id']), redisClient=self.redisMessaging) + objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) + self.handleWebhook(objectData, 'DELETE') + self.DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True) + except Exception as E: + self.logTool.log(service='Database', level='info', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging) + #Create if does not exist + self.CreateObj(SERVING_APN, json_data, True) + ServingAPN = self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) + objectData = self.GetObj(SERVING_APN, ServingAPN['serving_apn_id']) + self.handleWebhook(objectData, 'PUT') #Sync state change with geored if propagate == True: @@ -1724,7 +1733,6 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber except Exception as E: self.logTool.log(service='Database', level='debug', message="Nothing synced to Geographic PyHSS instances for event PCRF", redisClient=self.redisMessaging) - return def Get_Serving_APN(self, subscriber_id, apn_id): diff --git a/lib/diameter.py b/lib/diameter.py index 615014a..47b7c3f 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -699,6 +699,126 @@ def validateImsSubscriber(self, imsi=None, msisdn=None) -> bool: return True + + def deregisterApn(self, imsi: str=None, msisdn: str=None, apn: str=None) -> bool: + """ + Revokes a given UE's session with the assigned PGW (If it exists), and sends a CLR to the MME. 
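+        More precisely: this helper sends a Gx CCR-T and an RTR to the serving PGW peer for
+        each active APN session; the CLR to the MME is left to the caller (see the
+        /oam/deregister endpoint in services/apiService.py). Hypothetical usage, with a
+        made-up IMSI:
+            deregistered = diameterClient.deregisterApn(imsi='001010000000001')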
+ """ + try: + if imsi is None and msisdn is None: + return False + + if imsi is not None: + subscriberDetails = self.database.Get_Subscriber(imsi=imsi) + if msisdn is not None: + subscriberDetails = self.database.Get_Subscriber(msisdn=msisdn) + imsi = subscriberDetails.get('imsi', '') + + if subscriberDetails is None: + return False + + subscriberId = subscriberDetails.get('subscriber_id', None) + + # If a subscriber has an active serving apn, grab the pcrf session id for that apn and send a CCR-T, then a Registration Termination Request to the serving pgw peer. + if subscriberId is not None: + servingApns = self.database.Get_Serving_APNs(subscriber_id=subscriberId) + if len(servingApns.get('apns', {})) > 0: + for apnKey, apnDict in servingApns['apns'].items(): + pcrfSessionId = None + servingPgwPeer = None + servingPgwRealm = None + servingPgw = None + for apnDataKey, apnDataValue in servingApns['apns'][apnKey].items(): + if apnDataKey == 'pcrf_session_id': + pcrfSessionId = apnDataValue + if apnDataKey == 'serving_pgw_peer': + servingPgwPeer = apnDataValue + if apnDataKey == 'serving_pgw_realm': + servingPgwRealm = apnDataValue + if apnDataKey == 'serving_pgw': + servingPgwRealm = apnDataValue + + if pcrfSessionId is not None and servingPgwPeer is not None and servingPgwRealm is not None and servingPgw is not None: + if ';' in servingPgwPeer: + servingPgwPeer = servingPgwPeer.split(';')[0] + + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [deregisterData] Sending CCR-T with Session-ID:{pcrfSessionId} to peer: {servingPgwPeer} {apnKey}", redisClient=self.redisMessaging) + + self.sendDiameterRequest( + requestType='CCR', + hostname=servingPgwPeer, + imsi=imsi, + destinationHost=servingPgw, + destinationRealm=servingPgwRealm, + ccr_type=3, + sessionId=pcrfSessionId, + domain=servingPgwRealm + ) + + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [deregisterData] Sending RTR to peer: {servingPgwPeer} {apnKey}", redisClient=self.redisMessaging) + + self.sendDiameterRequest( + requestType='RTR', + hostname=servingPgwPeer, + imsi=imsi, + destinationHost=servingPgw, + destinationRealm=servingPgwRealm, + domain=servingPgwRealm + ) + + self.database.Update_Serving_APN(imsi=imsi, apn=apnKey, pcrf_session_id=None, serving_pgw=None, subscriber_routing='') + + return True + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [deregisterIms] Error deregistering subscriber from IMS: {traceback.format_exc()}", redisClient=self.redisMessaging) + return False + + def deregisterIms(self, imsi=None, msisdn=None) -> bool: + """ + Revokes a given UE's IMS registration, and sends a RTR to the SCSCF (if defined). + Does not revoke the pgw session, or notify the mme. 
+ """ + try: + if imsi is None and msisdn is None: + return False + + if imsi is not None: + imsSubscriberDetails = self.database.Get_Subscriber(imsi=imsi) + if msisdn is not None: + imsSubscriberDetails = self.database.Get_Subscriber(msisdn=msisdn) + + if imsSubscriberDetails is None: + return False + + servingScscf = imsSubscriberDetails.get('scscf', None) + servingScscfPeer = imsSubscriberDetails.get('scscf_peer', None) + servingScscfRealm = imsSubscriberDetails.get('scscf_realm', None) + + if servingScscfPeer is not None and servingScscfRealm is not None and servingScscf is not None: + if ';' in servingScscfPeer: + servingScscfPeer = servingScscfPeer.split(';')[0] + servingScscf = servingScscf.replace('sip:', '') + if ';' in servingScscf: + servingScscf = servingScscf.split(';')[0] + self.sendDiameterRequest( + requestType='RTR', + peerType=servingScscfPeer, + imsi=imsi, + destinationHost=servingScscf, + destinationRealm=servingScscfRealm, + domain=servingScscfRealm + ) + + if imsi is not None: + self.database.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) + elif msisdn is not None: + self.database.Update_Serving_CSCF(msisdn=msisdn, serving_cscf=None) + + return True + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [deregisterIms] Error deregistering subscriber from IMS: {traceback.format_exc()}", redisClient=self.redisMessaging) + return False + def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body for avp_dicts in avps: if avp_dicts['avp_code'] == 278: @@ -741,8 +861,8 @@ def Charging_Rule_Generator(self, ChargingRules, ue_ip): #ARP self.logTool.log(service='HSS', level='info', message="Defining ARP information", redisClient=self.redisMessaging) AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_capability']), 4)) - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(ChargingRules['arp_preemption_vulnerability']), 4)) + AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_capability']), 4)) + AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_vulnerability']), 4)) ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) self.logTool.log(service='HSS', level='info', message="Defining MBR information", redisClient=self.redisMessaging) @@ -1038,8 +1158,8 @@ def Answer_16777251_316(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging) #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4)) - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "c0", 10415, self.int_to_hex(int(apn_data['arp_preemption_vulnerability']), 4)) + AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not apn_data['arp_preemption_capability']), 4)) + 
AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "c0", 10415, self.int_to_hex(int(not apn_data['arp_preemption_vulnerability']), 4))
             AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
             AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4))
             APN_EPS_Subscribed_QoS_Profile = self.generate_vendor_avp(1431, "c0", 10415, AVP_QoS + AVP_ARP)
@@ -1415,9 +1535,14 @@ def Answer_16777238_272(self, packet_vars, avps):
             self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Setting APN Allocation-Retention-Priority", redisClient=self.redisMessaging)
 
             #AVP: Allocation-Retention-Priority(1034) l=60 f=V-- vnd=TGPP
+            # Per TS 29.212, we need to flip our stored values for capability and vulnerability:
+            # PRE-EMPTION_CAPABILITY_ENABLED (0)
+            # PRE-EMPTION_CAPABILITY_DISABLED (1)
+            # PRE-EMPTION_VULNERABILITY_ENABLED (0)
+            # PRE-EMPTION_VULNERABILITY_DISABLED (1)
             AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(apn_data['arp_priority']), 4))
-            AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_capability']), 4))
-            AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(apn_data['arp_preemption_vulnerability']), 4))
+            AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not apn_data['arp_preemption_capability']), 4))
+            AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not apn_data['arp_preemption_vulnerability']), 4))
             AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
             AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(apn_data['qci']), 4))
             avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP)
@@ -1456,7 +1581,8 @@ def Answer_16777238_272(self, packet_vars, avps):
 
             self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Added to AVP List", redisClient=self.redisMessaging)
             self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging)
 
-            #If database returned an existing ChargingRule defintion add ChargingRule to CCA-I
+            # If the database returned an existing ChargingRule definition, add the ChargingRule to the CCA-I
+            # If a Charging Rule Install AVP is present, it may trigger the creation of a dedicated bearer.
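+            # (Whether a dedicated bearer is actually established is decided by the PCEF/PGW,
+            # based on the QCI and ARP carried inside the installed Charging-Rule-Definition.)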
if ChargingRules and ChargingRules['charging_rules'] is not None:
                 try:
                     self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging)
@@ -1470,9 +1596,18 @@ def Answer_16777238_272(self, packet_vars, avps):
                     self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging)
 
         elif int(CC_Request_Type) == 3:
-            self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 3 - Termination", redisClient=self.redisMessaging)
-            self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=binascii.unhexlify(session_id).decode(), serving_pgw=None, subscriber_routing=None)
-
+            self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 3 - Termination", redisClient=self.redisMessaging)
+            if 'ims' in apn:
+                if not self.deregisterIms(imsi=imsi):
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to deregister IMS", redisClient=self.redisMessaging)
+                else:
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Successfully deregistered IMS", redisClient=self.redisMessaging)
+            else:
+                if not self.deregisterApn(imsi=imsi):
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to deregister Data APNs", redisClient=self.redisMessaging)
+                else:
+                    self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Successfully deregistered Data APNs", redisClient=self.redisMessaging)
+
         avp += self.generate_avp(264, 40, self.OriginHost)                                      #Origin Host
         avp += self.generate_avp(296, 40, self.OriginRealm)                                     #Origin Realm
         avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))                             #Result Code (DIAMETER_SUCCESS (2001))
@@ -2389,7 +2524,7 @@ def Request_282(self):
 
     #3GPP S6a/S6d Authentication Information Request
     def Request_16777251_318(self, imsi, DestinationHost, DestinationRealm, requested_vectors=1):
         avp = ''                                                                                    #Initiate empty var AVP #Session-ID
-        sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a'                 #Session state generate
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a'  #Session state generate
         avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session State set AVP
         avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State
         avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
@@ -2413,7 +2548,7 @@ def Request_16777251_316(self, imsi, DestinationRealm):
         mcc = imsi[0:3]
         mnc = imsi[3:5]
         avp = ''                                                                                    #Initiate empty var AVP #Session-ID
-        sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a'                 #Session state generate
+        sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a'  #Session state generate
         avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii'))     #Session State set AVP
         avp += self.generate_avp(277, 40, "00000001")                                               #Auth-Session-State
         avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient.
+ self.config['hss']['OriginHost'])),'ascii')) @@ -2431,7 +2566,7 @@ def Request_16777251_316(self, imsi, DestinationRealm): #3GPP S6a/S6d Purge UE Request PUR def Request_16777251_321(self, imsi, DestinationRealm, DestinationHost): avp = '' - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -2446,7 +2581,7 @@ def Request_16777251_321(self, imsi, DestinationRealm, DestinationHost): #3GPP S6a/S6d NOtify Request NOR def Request_16777251_323(self, imsi, DestinationRealm, DestinationHost): avp = '' - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -2461,7 +2596,7 @@ def Request_16777251_323(self, imsi, DestinationRealm, DestinationHost): #3GPP S6a/S6d Cancel-Location-Request Request CLR def Request_16777251_317(self, imsi, DestinationRealm, DestinationHost=None, CancellationType=2): avp = '' - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -2480,7 +2615,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): avp = '' #Initiate empty var AVP avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_s6a' #Session ID generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_s6a' #Session ID generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID set AVP avp += self.generate_vendor_avp(266, 40, 10415, '') #AVP Vendor ID #AVP: Vendor-Specific-Application-Id(260) l=32 f=-M- @@ -2712,7 +2847,7 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): #ToDo - Check the command code here... 
def Request_16777216_302(self, sipaor): avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate #Auth Session state avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State @@ -2731,7 +2866,7 @@ def Request_16777216_302(self, sipaor): #3GPP Cx User Authorization Request (UAR) def Request_16777216_300(self, imsi, domain): avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm @@ -2747,7 +2882,7 @@ def Request_16777216_300(self, imsi, domain): #3GPP Cx Server Assignment Request (SAR) def Request_16777216_301(self, imsi, domain, server_assignment_type): avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session Session ID avp += self.generate_avp(264, 40, str(binascii.hexlify(str.encode("testclient." 
+ self.config['hss']['OriginHost'])),'ascii')) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm @@ -2765,7 +2900,7 @@ def Request_16777216_301(self, imsi, domain, server_assignment_type): #3GPP Cx Multimedia Authentication Request (MAR) def Request_16777216_303(self, imsi, domain): avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm @@ -2785,7 +2920,7 @@ def Request_16777216_303(self, imsi, domain): #3GPP Cx Registration Termination Request (RTR) def Request_16777216_304(self, imsi, domain, destinationHost, destinationRealm): avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_cx' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID AVP avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777216),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Cx) @@ -2815,7 +2950,7 @@ def Request_16777216_304(self, imsi, domain, destinationHost, destinationRealm): #3GPP Sh User-Data Request (UDR) def Request_16777217_306(self, **kwargs): avp = '' #Initiate empty var AVP #Session-ID - sessionid = str(self.OriginHost) + ';' + self.generate_id(5) + ';1;app_sh' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + ';' + self.generate_id(5) + ';1;app_sh' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session ID AVP avp += self.generate_avp(260, 40, "000001024000000c" + format(int(16777217),"x").zfill(8) + "0000010a4000000c000028af") #Vendor-Specific-Application-ID (Sh) @@ -2897,7 +3032,7 @@ def Request_16777291_8388622(self, **kwargs): avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - sessionid = 'nickpc.localdomain;' + self.generate_id(5) + ';1;app_slh' #Session state generate + sessionid = str(bytes.fromhex(self.OriginHost).decode('ascii')) + self.generate_id(5) + ';1;app_slh' #Session state generate avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session State set AVP #Username (IMSI) diff --git a/services/apiService.py b/services/apiService.py index 75eaf22..55cfa5b 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -1087,6 +1087,7 @@ def get(self, imsi): servingScscf = subscriberInfo.get('scscf', None) servingScscfPeer = imsSubscriberInfo.get('scscf_peer', None) servingScscfRealm = imsSubscriberInfo.get('scscf_realm', None) + if servingMmePeer is not None and servingMmeRealm is not None and servingMme is not None: if ';' in servingMmePeer: servingMmePeer = servingMmePeer.split(';')[0] @@ -1203,6 +1204,8 @@ def get(self, imsi): domain=servingPgwRealm ) + databaseClient.Update_Serving_APN(imsi=imsi, apn=apnKey, pcrf_session_id=None, serving_pgw=None, subscriber_routing='') 
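+                            # Passing serving_pgw=None makes Update_Serving_APN delete the stale
+                            # SERVING_APN record outright (see the matching lib/database.py change
+                            # earlier in this patch).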
+ subscriberInfo = databaseClient.Get_Subscriber(imsi=str(imsi)) imsSubscriberInfo = databaseClient.Get_IMS_Subscriber(imsi=str(imsi)) servingApns = databaseClient.Get_Serving_APNs(subscriber_id=subscriberId) From 7420eccfab000e841b06fbd96be8fc70471f7b63 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 29 Sep 2023 09:48:41 +1000 Subject: [PATCH 28/43] Fix for empty geored / webhook peers when defined --- CHANGELOG.md | 1 + services/georedService.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bfc2e4..90f49c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Memory leaking in diameter.py - Gx CCA now supports apn inside a plmn based uri - AVP_Preemption_Capability and AVP_Preemption_Vulnerability now presents correctly in all diameter messages + - Crash when webhook or geored endpoints enabled and no peers defined ### Removed diff --git a/services/georedService.py b/services/georedService.py index 6300a39..12471d5 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -30,9 +30,10 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): if not self.config.get('geored', {}).get('enabled'): self.logger.error("[Geored] Fatal Error - geored not enabled under geored.enabled, exiting.") quit() - if not (len(self.georedPeers) > 0): - self.logger.error("[Geored] Fatal Error - no peers defined under geored.sync_endpoints, exiting.") - quit() + if self.georedPeers is not None: + if not (len(self.georedPeers) > 0): + self.logger.error("[Geored] Fatal Error - no peers defined under geored.sync_endpoints, exiting.") + quit() async def sendGeored(self, asyncSession, url: str, operation: str, body: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: """ @@ -366,11 +367,13 @@ async def startService(self): georedEnabled = self.config.get('geored', {}).get('enabled', False) webhooksEnabled = self.config.get('webhooks', {}).get('enabled', False) - if not len(self.georedPeers) > 0: - georedEnabled = False + if self.georedPeers is not None: + if not len(self.georedPeers) > 0: + georedEnabled = False - if not len(self.webhookPeers) > 0: - webhooksEnabled = False + if self.webhookPeers is not None: + if not len(self.webhookPeers) > 0: + webhooksEnabled = False if not georedEnabled and not webhooksEnabled: await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [startService] Geored and Webhook services both disabled or missing peers, exiting.")) @@ -400,4 +403,4 @@ async def startService(self): if __name__ == '__main__': georedService = GeoredService() - asyncio.run(georedService.startService()) \ No newline at end of file + asyncio.run(georedService.startService()) From 32d32527e877ff0a2b72a5dc3ecfe1b90dcd20ca Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 29 Sep 2023 10:28:22 +1000 Subject: [PATCH 29/43] Fix webhook always sending post --- lib/database.py | 14 +++++++++----- services/georedService.py | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/lib/database.py b/lib/database.py index ba49831..a3ba587 100755 --- a/lib/database.py +++ b/lib/database.py @@ -850,10 +850,11 @@ def handleGeored(self, jsonData, operation: str="PATCH", asymmetric: bool=False, return georedDict = {} if self.config.get('geored', {}).get('enabled', False): - if self.config.get('geored', {}).get('endpoints', []) is not None and len(self.config.get('geored', 
{}).get('endpoints', [])) > 0: - georedDict['body'] = jsonData - georedDict['operation'] = operation - self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) + if self.config.get('geored', {}).get('endpoints', []) is not None: + if len(self.config.get('geored', {}).get('endpoints', [])) > 0: + georedDict['body'] = jsonData + georedDict['operation'] = operation + self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120) if asymmetric: if len(asymmetricUrls) > 0: georedDict['body'] = jsonData @@ -874,6 +875,9 @@ def handleWebhook(self, objectData, operation: str="PATCH"): if not webhooksEnabled: return False + if endpointList is None: + return False + if not len (endpointList) > 0: self.logTool.log(service='Database', level='error', message="Webhooks enabled, but endpoints are missing.", redisClient=self.redisMessaging) return False @@ -882,7 +886,7 @@ def handleWebhook(self, objectData, operation: str="PATCH"): webhook['body'] = self.Sanitize_Datetime(objectData) webhook['headers'] = webhookHeaders - webhook['operation'] = "POST" + webhook['operation'] = operation self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(webhook), queueExpiry=120) return True diff --git a/services/georedService.py b/services/georedService.py index 12471d5..2400263 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -254,11 +254,11 @@ async def handleAsymmetricGeoredQueue(self): if not len(asymmetricGeoredQueue) > 0: await(asyncio.sleep(0.01)) continue - georedMessage = await(self.redisGeoredMessaging.getMessage(queue=georedQueue)) + georedMessage = await(self.redisGeoredMessaging.getMessage(queue=asymmetricGeoredQueue)) if not len(georedMessage) > 0: await(asyncio.sleep(0.01)) continue - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Queue: {georedQueue}")) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Queue: {asymmetricGeoredQueue}")) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}")) georedDict = json.loads(georedMessage) @@ -385,6 +385,7 @@ async def startService(self): georedTask = asyncio.create_task(self.handleGeoredQueue()) asymmetricGeoredTask = asyncio.create_task(self.handleAsymmetricGeoredQueue()) activeTasks.append(georedTask) + activeTasks.append(asymmetricGeoredTask) if webhooksEnabled: webhookTask = asyncio.create_task(self.handleWebhookQueue()) From d65cd8de7cecf1357f2553e74980912ba2f0d2ab Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 29 Sep 2023 11:50:05 +1000 Subject: [PATCH 30/43] Add configurable redis connection in config.yaml --- CHANGELOG.md | 1 + config.yaml | 3 + lib/database.py | 15 ++-- lib/diameter.py | 154 +++++++++++++++++++----------------- lib/diameterAsync.py | 11 ++- lib/logtool.py | 10 ++- lib/messaging.py | 8 +- lib/messagingAsync.py | 8 +- services/apiService.py | 5 +- services/diameterService.py | 12 ++- services/georedService.py | 10 ++- services/hssService.py | 8 +- services/logService.py | 8 +- 13 files changed, 155 insertions(+), 98 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 90f49c6..dcf8267 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0 - Significant performance improvements under load - Basic Rx support for RAA, AAA, ASA and STA - Asymmetric geored support + - Configurable redis connection (Unix socket or TCP) ### Changed diff --git a/config.yaml b/config.yaml index 54d96a0..5f01224 100644 --- a/config.yaml +++ b/config.yaml @@ -113,6 +113,9 @@ geored: #Redis is required to run PyHSS. A locally running instance is recommended for production. redis: + # Whether to use a UNIX socket instead of a tcp connection to redis. Host and port is ignored if useUnixSocket is True. + useUnixSocket: False + unixSocketPath: '/var/run/redis/redis-server.sock' host: localhost port: 6379 diff --git a/lib/database.py b/lib/database.py index a3ba587..570ff50 100755 --- a/lib/database.py +++ b/lib/database.py @@ -261,12 +261,17 @@ class Database: def __init__(self, logTool, redisMessaging=None): with open("../config.yaml", 'r') as stream: self.config = (yaml.safe_load(stream)) + + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) self.logTool = logTool if redisMessaging: self.redisMessaging = redisMessaging else: - self.redisMessaging = RedisMessaging() + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) db_string = 'mysql://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database'] + "?autocommit=true") self.engine = create_engine( @@ -469,11 +474,11 @@ def log_changes_before_commit(self, session): changes = [] for attr in class_mapper(obj.__class__).column_attrs: hist = get_history(obj, attr.key) - self.logTool.log(service='Database', level='info', message=f"History {hist}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=f"History {hist}", redisClient=self.redisMessaging) if hist.has_changes() and hist.added and hist.deleted: old_value, new_value = hist.deleted[0], hist.added[0] - self.logTool.log(service='Database', level='info', message=f"Old Value {old_value}", redisClient=self.redisMessaging) - self.logTool.log(service='Database', level='info', message=f"New Value {new_value}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=f"Old Value {old_value}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=f"New Value {new_value}", redisClient=self.redisMessaging) changes.append((attr.key, old_value, new_value)) continue @@ -1712,7 +1717,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber self.handleWebhook(objectData, 'DELETE') self.DeleteObj(SERVING_APN, ServingAPN['serving_apn_id'], True) except Exception as E: - self.logTool.log(service='Database', level='info', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Failed to update existing APN " + str(E), redisClient=self.redisMessaging) #Create if does not exist self.CreateObj(SERVING_APN, json_data, True) ServingAPN = 
self.Get_Serving_APN(subscriber_id=subscriber_id, apn_id=apn_id) diff --git a/lib/diameter.py b/lib/diameter.py index 47b7c3f..7a8a822 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -26,10 +26,16 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.MNC = str(mnc) self.MCC = str(mcc) self.logTool = logTool + + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) if redisMessaging: - self.redisMessaging=redisMessaging + self.redisMessaging = redisMessaging else: - self.redisMessaging=RedisMessaging() + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.database = Database(logTool=logTool) self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) @@ -830,19 +836,19 @@ def AVP_278_Origin_State_Incriment(self, avps): def Charging_Rule_Generator(self, ChargingRules, ue_ip): self.logTool.log(service='HSS', level='debug', message="Called Charging_Rule_Generator", redisClient=self.redisMessaging) #Install Charging Rules - self.logTool.log(service='HSS', level='info', message="Naming Charging Rule", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Naming Charging Rule", redisClient=self.redisMessaging) Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii')) - self.logTool.log(service='HSS', level='info', message="Named Charging Rule", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Named Charging Rule", redisClient=self.redisMessaging) #Populate all Flow Information AVPs Flow_Information = '' for tft in ChargingRules['tft']: - self.logTool.log(service='HSS', level='info', message=tft, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=tft, redisClient=self.redisMessaging) #If {{ UE_IP }} in TFT splice in the real UE IP Value try: tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip)) tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip)) - self.logTool.log(service='HSS', level='info', message="Spliced in UE IP into TFT: " + str(tft['tft_string']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Spliced in UE IP into TFT: " + str(tft['tft_string']), redisClient=self.redisMessaging) except Exception as E: self.logTool.log(service='HSS', level='error', message="Failed to splice in UE IP into flow description", redisClient=self.redisMessaging) @@ -852,64 +858,64 @@ def Charging_Rule_Generator(self, ChargingRules, ue_ip): Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description) Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4)) - self.logTool.log(service='HSS', level='info', message="Defined Flow_Status: " + str(Flow_Status), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defined Flow_Status: " + str(Flow_Status), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="Defining QoS information", 
redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defining QoS information", redisClient=self.redisMessaging) #QCI QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4)) #ARP - self.logTool.log(service='HSS', level='info', message="Defining ARP information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defining ARP information", redisClient=self.redisMessaging) AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4)) AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_capability']), 4)) AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_vulnerability']), 4)) ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - self.logTool.log(service='HSS', level='info', message="Defining MBR information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defining MBR information", redisClient=self.redisMessaging) #Max Requested Bandwidth Bandwidth_info = '' Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4)) Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4)) - self.logTool.log(service='HSS', level='info', message="Defining GBR information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defining GBR information", redisClient=self.redisMessaging) #GBR if int(ChargingRules['gbr_ul']) != 0: Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4)) if int(ChargingRules['gbr_dl']) != 0: Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4)) - self.logTool.log(service='HSS', level='info', message="Defined Bandwith Info: " + str(Bandwidth_info), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defined Bandwith Info: " + str(Bandwidth_info), redisClient=self.redisMessaging) #Populate QoS Information QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info) - self.logTool.log(service='HSS', level='info', message="Defined QoS_Information: " + str(QoS_Information), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defined QoS_Information: " + str(QoS_Information), redisClient=self.redisMessaging) #Precedence - self.logTool.log(service='HSS', level='info', message="Defining Precedence information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defining Precedence information", redisClient=self.redisMessaging) Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4)) - self.logTool.log(service='HSS', level='info', message="Defined Precedence " + str(Precedence), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defined Precedence " + str(Precedence), redisClient=self.redisMessaging) #Rating Group - self.logTool.log(service='HSS', level='info', message="Defining Rating Group information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', 
level='debug', message="Defining Rating Group information", redisClient=self.redisMessaging) if ChargingRules['rating_group'] != None: RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID else: RatingGroup = '' - self.logTool.log(service='HSS', level='info', message="Defined Rating Group " + str(ChargingRules['rating_group']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Defined Rating Group " + str(ChargingRules['rating_group']), redisClient=self.redisMessaging) #Complete Charging Rule Defintion - self.logTool.log(service='HSS', level='info', message="Collating ChargingRuleDef", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging) ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef) #Charging Rule Install - self.logTool.log(service='HSS', level='info', message="Collating ChargingRuleDef", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging) return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef) def Get_IMS_Subscriber_Details_from_AVP(self, username): #Feed the Username AVP with Tel URI, SIP URI and either MSISDN or IMSI and this returns user data username = binascii.unhexlify(username).decode('utf-8') - self.logTool.log(service='HSS', level='info', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) username = username.split('@')[0] #Strip Domain to get User part username = username[4:] #Strip tel: or sip: prefix #Determine if dealing with IMSI or MSISDN @@ -1057,9 +1063,9 @@ def Answer_16777251_316(self, packet_vars, avps): return response except ValueError as e: - self.logTool.log(service='HSS', level='error', message="failed to get data backfrom database for imsi " + str(imsi), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message="Error is " + str(e), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="failed to get data backfrom database for imsi " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Error is " + str(e), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) response = self.generate_diameter_packet("01", "40", 316, 16777251, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet self.logTool.log(service='HSS', level='info', message="Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) @@ -1167,7 +1173,7 @@ def Answer_16777251_316(self, packet_vars, avps): #Try static IP allocation try: subscriber_routing_dict = self.database.Get_SUBSCRIBER_ROUTING(subscriber_id=subscriber_details['subscriber_id'], 
apn_id=apn_id) #Get subscriber details - self.logTool.log(service='HSS', level='info', message="Got static UE IP " + str(subscriber_routing_dict), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got static UE IP " + str(subscriber_routing_dict), redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="Found static IP for UE " + str(subscriber_routing_dict['ip_address']), redisClient=self.redisMessaging) Served_Party_Address = self.generate_vendor_avp(848, "c0", 10415, self.ip_to_hex(subscriber_routing_dict['ip_address'])) except Exception as E: @@ -1192,7 +1198,7 @@ def Answer_16777251_316(self, packet_vars, avps): #If static SMF / PGW-C defined if apn_data['pgw_address'] is not None: - self.logTool.log(service='HSS', level='info', message="MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="MIP6-Agent-Info present (Static SMF/PGW-C), value " + str(apn_data['pgw_address']), redisClient=self.redisMessaging) MIP_Home_Agent_Address = self.generate_avp(334, '40', self.ip_to_hex(apn_data['pgw_address'])) MIP6_Agent_Info = self.generate_avp(486, '40', MIP_Home_Agent_Address) else: @@ -1281,8 +1287,8 @@ def Answer_16777251_318(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message=f"{response}", redisClient=self.redisMessaging) return response except ValueError as e: - self.logTool.log(service='HSS', level='info', message="Minor getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message=e, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Error getting subscriber details for IMSI " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=e, redisClient=self.redisMessaging) self.redisMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_auth_event_count', metricType='counter', metricAction='inc', metricValue=1.0, @@ -1294,7 +1300,7 @@ def Answer_16777251_318(self, packet_vars, avps): metricHelp='Diameter Authentication related Counters', metricExpiry=60) #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" - self.logTool.log(service='HSS', level='info', message="Subscriber " + str(imsi) + " is unknown in database", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Subscriber " + str(imsi) + " is unknown in database", redisClient=self.redisMessaging) avp = '' session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set @@ -1350,7 +1356,7 @@ def Answer_16777251_318(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="Raw value of requested vectors is " + str(sub_avp['misc_data']), redisClient=self.redisMessaging) requested_vectors = int(sub_avp['misc_data'], 16) if requested_vectors >= 32: - self.logTool.log(service='HSS', level='info', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Client has requested " + str(requested_vectors) + " vectors, limiting this to 32", redisClient=self.redisMessaging) requested_vectors = 32 self.logTool.log(service='HSS', level='debug', message="Generating " + 
str(requested_vectors) + " vectors as requested", redisClient=self.redisMessaging) @@ -1491,11 +1497,11 @@ def Answer_16777238_272(self, packet_vars, avps): imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Found IMSI " + str(imsi), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] SubscriptionID: " + str(self.get_avp_data(avps, 443)), redisClient=self.redisMessaging) try: - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Getting Get_Charging_Rules for IMSI " + str(imsi) + " using APN " + str(apn) + " from database", redisClient=self.redisMessaging) #Get subscriber details ChargingRules = self.database.Get_Charging_Rules(imsi=imsi, apn=apn) - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Got Charging Rules: " + str(ChargingRules), redisClient=self.redisMessaging) except Exception as E: #Handle if the subscriber is not present in HSS return "DIAMETER_ERROR_USER_UNKNOWN" self.logTool.log(service='HSS', level='debug', message=E, redisClient=self.redisMessaging) @@ -1503,7 +1509,7 @@ def Answer_16777238_272(self, packet_vars, avps): if int(CC_Request_Type) == 1: - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) #Get UE IP try: @@ -1554,7 +1560,7 @@ def Answer_16777238_272(self, packet_vars, avps): avp += self.generate_vendor_avp(1049, "80", 10415, default_EPS_QoS) - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Creating QoS Information", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Creating QoS Information", redisClient=self.redisMessaging) #QoS-Information try: apn_data = ChargingRules['apn_data'] @@ -1562,8 +1568,8 @@ def Answer_16777238_272(self, packet_vars, avps): apn_ambr_dl = int(apn_data['apn_ambr_dl']) QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Populated QoS_Information", redisClient=self.redisMessaging) + 
self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Created both QoS AVPs from data from Database", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Populated QoS_Information", redisClient=self.redisMessaging) avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) except Exception as E: self.logTool.log(service='HSS', level='error', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to get QoS information dynamically for sub " + str(imsi), redisClient=self.redisMessaging) @@ -1578,7 +1584,7 @@ def Answer_16777238_272(self, packet_vars, avps): avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS information set statically", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="[diameter.py] [Answer_16777238_272] [CCA] Added to AVP List", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Added to AVP List", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] QoS Information: " + str(QoS_Information), redisClient=self.redisMessaging) # If database returned an existing ChargingRule defintion add ChargingRule to CCA-I @@ -1658,10 +1664,10 @@ def Answer_16777216_300(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777216_300] [UAR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) try: - self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Checking if username present", redisClient=self.redisMessaging) username = self.get_avp_data(avps, 1)[0] username = binascii.unhexlify(username).decode('utf-8') - self.logTool.log(service='HSS', level='info', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Username AVP is present, value is " + str(username), redisClient=self.redisMessaging) imsi = username.split('@')[0] #Strip Domain domain = username.split('@')[1] #Get Domain Part self.logTool.log(service='HSS', level='debug', message="Extracted imsi: " + str(imsi) + " now checking backend for this IMSI", redisClient=self.redisMessaging) @@ -1722,10 +1728,10 @@ def Answer_16777216_300(self, packet_vars, avps): avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.logTool.log(service='HSS', level='info', message="Using generated S-CSCF Address as failed to source from list due to " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Using generated S-CSCF Address as failed to source from list due to " + str(E), redisClient=self.redisMessaging) else: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - 
self.logTool.log(service='HSS', level='info', message="Using generated S-CSCF Address as none set in scscf_pool in config", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Using generated S-CSCF Address as none set in scscf_pool in config", redisClient=self.redisMessaging) experimental_avp = '' experimental_avp += experimental_avp + self.generate_avp(266, 40, format(int(10415),"x").zfill(8)) #3GPP Vendor ID experimental_avp = experimental_avp + self.generate_avp(298, 40, format(int(2001),"x").zfill(8)) #DIAMETER_FIRST_REGISTRATION (2001) @@ -1760,7 +1766,7 @@ def Answer_16777216_301(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777216_301] [SAR] Remote Peer is " + str(remote_peer), redisClient=self.redisMessaging) try: - self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Checking if username present", redisClient=self.redisMessaging) username = self.get_avp_data(avps, 601)[0] ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) self.logTool.log(service='HSS', level='debug', message="Got subscriber details: " + str(ims_subscriber_details), redisClient=self.redisMessaging) @@ -1831,7 +1837,7 @@ def Answer_16777216_302(self, packet_vars, avps): try: - self.logTool.log(service='HSS', level='info', message="Checking if username present", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Checking if username present", redisClient=self.redisMessaging) username = self.get_avp_data(avps, 601)[0] ims_subscriber_details = self.Get_IMS_Subscriber_Details_from_AVP(username) if ims_subscriber_details['scscf'] != None: @@ -1847,10 +1853,10 @@ def Answer_16777216_302(self, packet_vars, avps): avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode(scscf)),'ascii')) except Exception as E: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.logTool.log(service='HSS', level='info', message="Using generated iFC as failed to source from list due to " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Using generated iFC as failed to source from list due to " + str(E), redisClient=self.redisMessaging) else: avp += self.generate_vendor_avp(602, "c0", 10415, str(binascii.hexlify(str.encode("sip:scscf.ims.mnc" + str(self.MNC).zfill(3) + ".mcc" + str(self.MCC).zfill(3) + ".3gppnetwork.org")),'ascii')) - self.logTool.log(service='HSS', level='info', message="Using generated iFC", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Using generated iFC", redisClient=self.redisMessaging) except Exception as E: self.logTool.log(service='HSS', level='error', message="Threw Exception: " + str(E), redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='error', message="No known MSISDN or IMSI in Answer_16777216_302() input", redisClient=self.redisMessaging) @@ -1928,7 +1934,7 @@ def Answer_16777216_303(self, packet_vars, avps): #Determine if SQN Resync is required & auth type to use for sub_avp_612 in self.get_avp_data(avps, 612)[0]: if sub_avp_612['avp_code'] == 610: - self.logTool.log(service='HSS', level='info', message="SQN in HSS is out of sync - 
Performing resync", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="SQN in HSS is out of sync - Performing resync", redisClient=self.redisMessaging) auts = str(sub_avp_612['misc_data'])[32:] rand = str(sub_avp_612['misc_data'])[:32] rand = binascii.unhexlify(rand) @@ -1945,9 +1951,9 @@ def Answer_16777216_303(self, packet_vars, avps): metricHelp='Diameter Authentication related Counters', metricExpiry=60) if sub_avp_612['avp_code'] == 608: - self.logTool.log(service='HSS', level='info', message="Auth mechansim requested: " + str(sub_avp_612['misc_data']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Auth mechansim requested: " + str(sub_avp_612['misc_data']), redisClient=self.redisMessaging) auth_scheme = binascii.unhexlify(sub_avp_612['misc_data']).decode('utf-8') - self.logTool.log(service='HSS', level='info', message="Auth mechansim requested: " + str(auth_scheme), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Auth mechansim requested: " + str(auth_scheme), redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="IMSI is " + str(imsi), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(601, "c0", 10415, str(binascii.hexlify(str.encode(public_identity)),'ascii')) #Public Identity (IMSI) @@ -2007,7 +2013,7 @@ def Respond_ResultCode(self, packet_vars, avps, result_code): session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID except: - self.logTool.log(service='HSS', level='info', message="Failed to add SessionID into error", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Failed to add SessionID into error", redisClient=self.redisMessaging) for avps_to_check in avps: #Only include AVP 260 (Vendor-Specific-Application-ID) if inital request included it if avps_to_check['avp_code'] == 260: concat_subavp = '' @@ -2059,9 +2065,9 @@ def Answer_16777217_306(self, packet_vars, avps): try: user_identity_avp = self.get_avp_data(avps, 700)[0] msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request - self.logTool.log(service='HSS', level='info', message="Got raw MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got raw MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) msisdn = self.TBCD_decode(msisdn) - self.logTool.log(service='HSS', level='info', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) except: self.logTool.log(service='HSS', level='error', message="No MSISDN", redisClient=self.redisMessaging) try: @@ -2110,7 +2116,7 @@ def Answer_16777217_306(self, packet_vars, avps): #Sh-User-Data (XML) #This loads a Jinja XML template containing the Sh-User-Data sh_userdata_template = self.config['hss']['Default_Sh_UserData'] - self.logTool.log(service='HSS', level='info', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) template = self.templateEnv.get_template(sh_userdata_template) #These variables 
are passed to the template for use subscriber_details['mnc'] = self.MNC.zfill(3) @@ -2335,7 +2341,7 @@ def Answer_16777252_324(self, packet_vars, avps): imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI #avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - self.logTool.log(service='HSS', level='info', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging) except Exception as e: self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from LCS-Routing-Info-Request", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) @@ -2403,12 +2409,12 @@ def Answer_16777291_8388622(self, packet_vars, avps): #Try and get IMSI if present if 1 in present_avps: - self.logTool.log(service='HSS', level='info', message="IMSI AVP is present", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="IMSI AVP is present", redisClient=self.redisMessaging) try: imsi = self.get_avp_data(avps, 1)[0] #Get IMSI from User-Name AVP in request imsi = binascii.unhexlify(imsi).decode('utf-8') #Convert IMSI avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - self.logTool.log(service='HSS', level='info', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got IMSI with value " + str(imsi), redisClient=self.redisMessaging) except Exception as e: self.logTool.log(service='HSS', level='debug', message="Failed to get IMSI from LCS-Routing-Info-Request", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) @@ -2416,11 +2422,11 @@ def Answer_16777291_8388622(self, packet_vars, avps): #Try and get MSISDN if present try: msisdn = self.get_avp_data(avps, 701)[0] #Get MSISDN from AVP in request - self.logTool.log(service='HSS', level='info', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) avp += self.generate_vendor_avp(701, 'c0', 10415, self.get_avp_data(avps, 701)[0]) #MSISDN - self.logTool.log(service='HSS', level='info', message="Got MSISDN with encoded value " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got MSISDN with encoded value " + str(msisdn), redisClient=self.redisMessaging) msisdn = self.TBCD_decode(msisdn) - self.logTool.log(service='HSS', level='info', message="Got MSISDN with decoded value " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got MSISDN with decoded value " + str(msisdn), redisClient=self.redisMessaging) except Exception as e: self.logTool.log(service='HSS', level='debug', message="Failed to get MSISDN from LCS-Routing-Info-Request", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message="Error was: " + str(e), redisClient=self.redisMessaging) @@ -2437,9 +2443,9 @@ def Answer_16777291_8388622(self, packet_vars, avps): subscriber_details = self.database.Get_Subscriber(msisdn=msisdn) 
self.logTool.log(service='HSS', level='debug', message="Got subscriber_details from MSISDN: " + str(subscriber_details), redisClient=self.redisMessaging) except Exception as E: - self.logTool.log(service='HSS', level='error', message="No MSISDN or IMSI returned in Answer_16777291_8388622 input", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message="Error is " + str(E), redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='error', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="No MSISDN or IMSI returned in Answer_16777291_8388622 input", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Error is " + str(E), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='info', message="Responding with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) response = self.generate_diameter_packet("01", "40", 8388622, 16777291, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet self.logTool.log(service='HSS', level='info', message="Diameter user unknown - Sending ULA with DIAMETER_ERROR_USER_UNKNOWN", redisClient=self.redisMessaging) @@ -2447,12 +2453,12 @@ def Answer_16777291_8388622(self, packet_vars, avps): - self.logTool.log(service='HSS', level='info', message="Got subscriber_details for subscriber: " + str(subscriber_details), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got subscriber_details for subscriber: " + str(subscriber_details), redisClient=self.redisMessaging) if subscriber_details['serving_mme'] == None: #DB has no location on record for subscriber - self.logTool.log(service='HSS', level='info', message="No location on record for Subscriber", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="No location on record for Subscriber", redisClient=self.redisMessaging) result_code = 4201 #DIAMETER_ERROR_ABSENT_USER (4201) #This result code shall be sent by the HSS to indicate that the location of the targeted user is not known at this time to @@ -2643,10 +2649,10 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): try: user_identity_avp = self.get_avp_data(avps, 700)[0] - self.logTool.log(service='HSS', level='info', message=user_identity_avp, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=user_identity_avp, redisClient=self.redisMessaging) msisdn = self.get_avp_data(user_identity_avp, 701)[0] #Get MSISDN from AVP in request msisdn = self.TBCD_decode(msisdn) - self.logTool.log(service='HSS', level='info', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got MSISDN with value " + str(msisdn), redisClient=self.redisMessaging) except: self.logTool.log(service='HSS', level='error', message="No MSISDN present", redisClient=self.redisMessaging) return @@ -2655,11 +2661,11 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): self.logTool.log(service='HSS', level='debug', message="Got subscriber location: " + subscriber_location, redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='info', message="Getting IMSI for MSISDN " + str(msisdn), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', 
level='debug', message="Getting IMSI for MSISDN " + str(msisdn), redisClient=self.redisMessaging) imsi = self.database.Get_IMSI_from_MSISDN(msisdn) avp += self.generate_avp(1, 40, self.string_to_hex(imsi)) #Username (IMSI) - self.logTool.log(service='HSS', level='info', message="Got back location data: " + str(subscriber_location), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Got back location data: " + str(subscriber_location), redisClient=self.redisMessaging) #Populate Destination Host & Realm avp += self.generate_avp(293, 40, self.string_to_hex(subscriber_location)) #Destination Host #Destination-Host @@ -2774,26 +2780,26 @@ def Request_16777251_319(self, packet_vars, avps, **kwargs): Served_Party_Address = "" if 'MIP6-Agent-Info' in apn_profile: - self.logTool.log(service='HSS', level='info', message="MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="MIP6-Agent-Info present, value " + str(apn_profile['MIP6-Agent-Info']), redisClient=self.redisMessaging) MIP6_Destination_Host = self.generate_avp(293, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_HOST']))) MIP6_Destination_Realm = self.generate_avp(283, '40', self.string_to_hex(str(apn_profile['MIP6-Agent-Info']['MIP6_DESTINATION_REALM']))) MIP6_Home_Agent_Host = self.generate_avp(348, '40', MIP6_Destination_Host + MIP6_Destination_Realm) MIP6_Agent_Info = self.generate_avp(486, '40', MIP6_Home_Agent_Host) - self.logTool.log(service='HSS', level='info', message="MIP6 value is " + str(MIP6_Agent_Info), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="MIP6 value is " + str(MIP6_Agent_Info), redisClient=self.redisMessaging) else: MIP6_Agent_Info = '' if 'PDN_GW_Allocation_Type' in apn_profile: - self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="PDN_GW_Allocation_Type present, value " + str(apn_profile['PDN_GW_Allocation_Type']), redisClient=self.redisMessaging) PDN_GW_Allocation_Type = self.generate_vendor_avp(1438, 'c0', 10415, self.int_to_hex(int(apn_profile['PDN_GW_Allocation_Type']), 4)) - self.logTool.log(service='HSS', level='info', message="PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="PDN_GW_Allocation_Type value is " + str(PDN_GW_Allocation_Type), redisClient=self.redisMessaging) else: PDN_GW_Allocation_Type = '' if 'VPLMN_Dynamic_Address_Allowed' in apn_profile: - self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="VPLMN_Dynamic_Address_Allowed present, value " + str(apn_profile['VPLMN_Dynamic_Address_Allowed']), redisClient=self.redisMessaging) VPLMN_Dynamic_Address_Allowed = self.generate_vendor_avp(1432, 'c0', 10415, self.int_to_hex(int(apn_profile['VPLMN_Dynamic_Address_Allowed']), 4)) - self.logTool.log(service='HSS', level='info', message="VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed), redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', 
message="VPLMN_Dynamic_Address_Allowed value is " + str(VPLMN_Dynamic_Address_Allowed), redisClient=self.redisMessaging) else: VPLMN_Dynamic_Address_Allowed = '' @@ -3222,7 +3228,7 @@ def Request_16777217_307(self, msisdn): templateLoader = jinja2.FileSystemLoader(searchpath="./") templateEnv = jinja2.Environment(loader=templateLoader) sh_userdata_template = self.config['hss']['Default_Sh_UserData'] - self.logTool.log(service='HSS', level='info', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Using template " + str(sh_userdata_template) + " for SH user data", redisClient=self.redisMessaging) template = templateEnv.get_template(sh_userdata_template) #These variables are passed to the template for use subscriber_details['mnc'] = self.MNC.zfill(3) diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index c8c3557..8a1b128 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -1,6 +1,7 @@ #Diameter Packet Decoder / Encoder & Tools import math import asyncio +import yaml from messagingAsync import RedisMessagingAsync @@ -30,7 +31,15 @@ def __init__(self, logTool): {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, ] - self.redisMessaging = RedisMessagingAsync() + with open("../config.yaml", 'r') as stream: + self.config = (yaml.safe_load(stream)) + + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.logTool = logTool diff --git a/lib/logtool.py b/lib/logtool.py index b3528f4..caf20a2 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -33,8 +33,14 @@ def __init__(self, config: dict): 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, } self.logLevel = config.get('logging', {}).get('level', 'INFO') - self.redisMessagingAsync = RedisMessagingAsync() - self.redisMessaging = RedisMessaging() + + self.redisUseUnixSocket = config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = config.get('redis', {}).get('host', 'localhost') + self.redisPort = config.get('redis', {}).get('port', 6379) + + self.redisMessagingAsync = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) async def logAsync(self, service: str, level: str, message: str, redisClient=None) -> bool: """ diff --git a/lib/messaging.py b/lib/messaging.py index 8e783dd..b5b9762 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -7,9 +7,11 @@ class RedisMessaging: A class for sending and receiving redis messages. 
""" - def __init__(self, host: str='localhost', port: int=6379): - self.redisClient = Redis(unix_socket_path='/var/run/redis/redis-server.sock') - pass + def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock'): + if useUnixSocket: + self.redisClient = Redis(unix_socket_path=unixSocketPath) + else: + self.redisClient = Redis(host=host, port=port) def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: """ diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index af4f54c..aa3c003 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -8,8 +8,12 @@ class RedisMessagingAsync: A class for sending and receiving redis messages asynchronously. """ - def __init__(self, host: str='localhost', port: int=6379): - self.redisClient = redis.Redis(unix_socket_path='/var/run/redis/redis-server.sock') + def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock'): + if useUnixSocket: + self.redisClient = redis.Redis(unix_socket_path=unixSocketPath) + else: + self.redisClient = redis.Redis(host=host, port=port) + pass async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: """ diff --git a/services/apiService.py b/services/apiService.py index 55cfa5b..1cc1f18 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -32,7 +32,10 @@ redisHost = config.get("redis", {}).get("host", "127.0.0.1") redisPort = int(config.get("redis", {}).get("port", 6379)) -redisMessaging = RedisMessaging() +redisUseUnixSocket = config.get('redis', {}).get('useUnixSocket', False) +redisUnixSocketPath = config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + +redisMessaging = RedisMessaging(host=redisHost, port=redisPort, useUnixSocket=redisUseUnixSocket, redisUnixSocketPath=redisUnixSocketPath) logTool = LogTool(config) diff --git a/services/diameterService.py b/services/diameterService.py index 5d787d5..c7946c4 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -16,7 +16,7 @@ class DiameterService: Functions in this class are high-performance, please edit with care. Last profiled on 20-09-2023. 
""" - def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + def __init__(self): try: with open("../config.yaml", "r") as self.configFile: self.config = yaml.safe_load(self.configFile) @@ -24,9 +24,13 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): print(f"[Diameter] [__init__] Fatal Error - config.yaml not found, exiting.") quit() - self.redisReaderMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) - self.redisWriterMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) - self.redisPeerMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisReaderMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisWriterMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisPeerMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) self.banners = Banners() self.logTool = LogTool(config=self.config) self.diameterLibrary = DiameterAsync(logTool=self.logTool) diff --git a/services/georedService.py b/services/georedService.py index 2400263..01336fe 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -21,8 +21,14 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): quit() self.logTool = LogTool(self.config) self.banners = Banners() - self.redisGeoredMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) - self.redisWebhookMessaging = RedisMessagingAsync(host=redisHost, port=redisPort) + + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisGeoredMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.redisWebhookMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.georedPeers = self.config.get('geored', {}).get('endpoints', []) self.webhookPeers = self.config.get('webhooks', {}).get('endpoints', []) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) diff --git a/services/hssService.py b/services/hssService.py index e7cd5f4..b475fe7 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -7,7 +7,7 @@ class HssService: - def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + def __init__(self): try: with open("../config.yaml", "r") as self.configFile: @@ -15,7 +15,11 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): except: print(f"[HSS] Fatal Error - config.yaml not found, exiting.") quit() - self.redisMessaging = RedisMessaging(host=redisHost, 
port=redisPort) + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) self.logTool = LogTool(config=self.config) self.banners = Banners() self.mnc = self.config.get('hss', {}).get('MNC', '999') diff --git a/services/logService.py b/services/logService.py index a6a4e03..4828195 100644 --- a/services/logService.py +++ b/services/logService.py @@ -14,7 +14,7 @@ class LogService: This class is synchronous and not high-performance. """ - def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): + def __init__(self): try: with open("../config.yaml", "r") as self.configFile: self.config = yaml.safe_load(self.configFile) @@ -23,7 +23,11 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): quit() self.logTool = LogTool(config=self.config) self.banners = Banners() - self.redisMessaging = RedisMessaging(host=redisHost, port=redisPort) + self.redisUseUnixSocket = self.config.get('redis', {}).get('useUnixSocket', False) + self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') + self.redisHost = self.config.get('redis', {}).get('host', 'localhost') + self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) self.logFilePaths = self.config.get('logging', {}).get('logfiles', {}) self.logLevels = { 'CRITICAL': {'verbosity': 1, 'logging': logging.CRITICAL}, From b42a9ea7c8b5c34e2d99b3dde720657c65be93e4 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 29 Sep 2023 14:51:13 +1000 Subject: [PATCH 31/43] Add PCSCF state management, basic database upgrade support --- CHANGELOG.md | 16 +-- lib/database.py | 18 ++- lib/diameter.py | 23 ++-- services/apiService.py | 2 +- tools/databaseUpgrade/README.md | 21 ++++ tools/databaseUpgrade/alembic.ini | 110 ++++++++++++++++++ tools/databaseUpgrade/alembic/README | 1 + tools/databaseUpgrade/alembic/env.py | 89 ++++++++++++++ tools/databaseUpgrade/alembic/lib | 1 + tools/databaseUpgrade/alembic/script.py.mako | 24 ++++ .../2ad87e0c0c76_service_overhaul_revision.py | 34 ++++++ tools/databaseUpgrade/lib | 1 + tools/databaseUpgrade/requirements.txt | 2 + 13 files changed, 318 insertions(+), 24 deletions(-) create mode 100644 tools/databaseUpgrade/README.md create mode 100644 tools/databaseUpgrade/alembic.ini create mode 100644 tools/databaseUpgrade/alembic/README create mode 100644 tools/databaseUpgrade/alembic/env.py create mode 120000 tools/databaseUpgrade/alembic/lib create mode 100644 tools/databaseUpgrade/alembic/script.py.mako create mode 100644 tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py create mode 120000 tools/databaseUpgrade/lib create mode 100644 tools/databaseUpgrade/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index dcf8267..3de74be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,20 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - - Systemd service files for PyHSS 
services. - - /oam/diameter_peers endpoint. - - /oam/deregister/{imsi} endpoint. - - /geored/peers endpoint. - - /geored/webhooks endpoint. + - Systemd service files for PyHSS services + - /oam/diameter_peers endpoint + - /oam/deregister/{imsi} endpoint + - /geored/peers endpoint + - /geored/webhooks endpoint - Dependency on Redis for inter-service messaging - Significant performance improvements under load - Basic Rx support for RAA, AAA, ASA and STA - Asymmetric geored support - Configurable redis connection (Unix socket or TCP) + - Basic database upgrade support in tools/databaseUpgrade + - PCSCF state storage in ims_subscriber ### Changed -- Split logical functions of PyHSS into 6 service processes. +- Split logical functions of PyHSS into 6 service processes - Logtool no longer handles metric processing - Updated config.yaml - Gx CCR-T now flushes PGW / IMS data, depending on Called-Station-Id @@ -36,6 +38,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed -- Multithreading in all services, except for metricService. +- Multithreading in all services, except for metricService [1.0.0]: https://github.com/nickvsnetworking/pyhss/releases/tag/v1.0.0 \ No newline at end of file diff --git a/lib/database.py b/lib/database.py index 570ff50..876798e 100755 --- a/lib/database.py +++ b/lib/database.py @@ -191,9 +191,13 @@ class IMS_SUBSCRIBER(Base): msisdn_list = Column(String(1200), doc='Comma Separated list of additional MSISDNs for Subscriber') imsi = Column(String(18), unique=False, doc=SUBSCRIBER.imsi.doc) ifc_path = Column(String(18), doc='Path to template file for the Initial Filter Criteria') + pcscf = Column(String(512), doc='Proxy-CSCF serving this subscriber') + pcscf_realm = Column(String(512), doc='Realm of PCSCF') + pcscf_timestamp = Column(DateTime, doc='Timestamp of last ue attach to PCSCF') + pcscf_peer = Column(String(512), doc='Diameter peer used to reach PCSCF') sh_profile = Column(Text(12000), doc='Sh Subscriber Profile') scscf = Column(String(512), doc='Serving-CSCF serving this subscriber') - scscf_timestamp = Column(DateTime, doc='Timestamp of attach to S-CSCF') + scscf_timestamp = Column(DateTime, doc='Timestamp of last ue attach to SCSCF') scscf_realm = Column(String(512), doc='Realm of SCSCF') scscf_peer = Column(String(512), doc='Diameter peer used to reach SCSCF') last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') @@ -1552,16 +1556,18 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ URL = 'http://' + serving_hss + '.' 
+ self.config['hss']['OriginRealm'] + ':8080/push/clr/' + str(imsi) self.logTool.log(service='Database', level='debug', message="Sending CLR to API at " + str(URL), redisClient=self.redisMessaging) - - self.logTool.log(service='Database', level='debug', message="Pushing CLR to API on " + str(URL) + " with JSON body: " + str(json_data), redisClient=self.redisMessaging) - transaction_id = str(uuid.uuid4()) - self.handleGeored({ + + clrBody = { "imsi": str(imsi), "DestinationRealm": result.serving_mme_realm, "DestinationHost": result.serving_mme, "cancellationType": 2, "diameterPeer": serving_mme_peer, - }, asymmetric=True, asymmetricUrls=[URL]) + } + + self.logTool.log(service='Database', level='debug', message="Pushing CLR to API on " + str(URL) + " with JSON body: " + str(clrBody), redisClient=self.redisMessaging) + transaction_id = str(uuid.uuid4()) + self.handleGeored(clrBody, asymmetric=True, asymmetricUrls=[URL]) else: #No currently serving MME - No action to take self.logTool.log(service='Database', level='debug', message="No currently serving MME - No need to send CLR", redisClient=self.redisMessaging) diff --git a/lib/diameter.py b/lib/diameter.py index 7a8a822..295a28b 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1093,13 +1093,11 @@ def Answer_16777251_316(self, packet_vars, avps): self.database.Update_Serving_MME(imsi=imsi, serving_mme=OriginHost, serving_mme_peer=remote_peer, serving_mme_realm=OriginRealm) - #Boilerplate AVPs avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) avp += self.generate_avp(277, 40, "00000001") #Auth-Session-State avp += self.generate_vendor_avp(1406, "c0", 10415, "00000001") #ULA Flags - #Subscription Data: subscription_data = '' subscription_data += self.generate_vendor_avp(1426, "c0", 10415, "00000000") #Access Restriction Data @@ -1508,6 +1506,7 @@ def Answer_16777238_272(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Subscriber " + str(imsi) + " unknown in HSS for CCR - Check Charging Rule assigned to APN is set and exists", redisClient=self.redisMessaging) + # CCR - Initial Request if int(CC_Request_Type) == 1: self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 1 - Initial", redisClient=self.redisMessaging) @@ -1601,18 +1600,22 @@ def Answer_16777238_272(self, packet_vars, avps): except Exception as E: self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Error in populating dynamic charging rules: " + str(E), redisClient=self.redisMessaging) + # CCR - Termination Request elif int(CC_Request_Type) == 3: self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Request type for CCA is 3 - Termination", redisClient=self.redisMessaging) if 'ims' in apn: - if not self.deregisterIms(imsi=imsi): - self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to deregister IMS", redisClient=self.redisMessaging) - else: - self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Successfully deregistered IMS", redisClient=self.redisMessaging) + try: + self.database.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), subscriber_routing='') + self.logTool.log(service='HSS', 
level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Successfully cleared stored IMS state", redisClient=self.redisMessaging)
+ except Exception as e:
+ self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear stored IMS state: {traceback.format_exc()}", redisClient=self.redisMessaging)
 else:
- if not self.deregisterData(imsi=imsi):
- self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Failed to deregister Data APNs", redisClient=self.redisMessaging)
- else:
- self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Successfully deregistered Data APNs", redisClient=self.redisMessaging)
+ try:
+ self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), subscriber_routing='')
+ self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Successfully cleared stored state for: {apn}", redisClient=self.redisMessaging)
+ except Exception as e:
+ self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear apn state for {apn}: {traceback.format_exc()}", redisClient=self.redisMessaging)
 avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
 avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
diff --git a/services/apiService.py b/services/apiService.py
index 1cc1f18..fa0f85e 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -35,7 +35,7 @@
 redisUseUnixSocket = config.get('redis', {}).get('useUnixSocket', False)
 redisUnixSocketPath = config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock')
-redisMessaging = RedisMessaging(host=redisHost, port=redisPort, useUnixSocket=redisUseUnixSocket, redisUnixSocketPath=redisUnixSocketPath)
+redisMessaging = RedisMessaging(host=redisHost, port=redisPort, useUnixSocket=redisUseUnixSocket, unixSocketPath=redisUnixSocketPath)
 logTool = LogTool(config)
diff --git a/tools/databaseUpgrade/README.md b/tools/databaseUpgrade/README.md
new file mode 100644
index 0000000..4376b4a
--- /dev/null
+++ b/tools/databaseUpgrade/README.md
@@ -0,0 +1,21 @@
+# Database Upgrade
+
+Database upgrades are currently limited to semi-automation.
+
+Alembic is used to handle database schema upgrades.
+
+This will not give a foolproof upgrade; ensure you read the generated scripts.
+For best results (and in production environments), read lib/database.py and compare each base object to the table in your database.
+Types for columns should also be checked.
+
+# Usage
+
+1. Ensure that `config.yaml` is populated with the correct database credentials.
+
+2. Navigate to `tools/databaseUpgrade`
+
+3. `pip3 install -r requirements.txt`
+
+4. `alembic revision --autogenerate -m "Name your upgrade"`
+
+5. `alembic upgrade head`
\ No newline at end of file
diff --git a/tools/databaseUpgrade/alembic.ini b/tools/databaseUpgrade/alembic.ini
new file mode 100644
index 0000000..7bb0089
--- /dev/null
+++ b/tools/databaseUpgrade/alembic.ini
@@ -0,0 +1,110 @@
+# A generic, single database configuration.
+ +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +; sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/tools/databaseUpgrade/alembic/README b/tools/databaseUpgrade/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/tools/databaseUpgrade/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/tools/databaseUpgrade/alembic/env.py b/tools/databaseUpgrade/alembic/env.py new file mode 100644 index 0000000..b0dfb0f --- /dev/null +++ b/tools/databaseUpgrade/alembic/env.py @@ -0,0 +1,89 @@ +from logging.config import fileConfig +from sqlalchemy import create_engine +from alembic import context +import yaml +import sys +import os +sys.path.append(os.path.realpath('lib')) +from database import Base + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + +def get_url_from_config() -> str: + """ + Reads config.yaml and returns the database url. + """ + with open("config.yaml", 'r') as stream: + try: + config = yaml.safe_load(stream) + db_string = 'mysql://' + str(config['database']['username']) + ':' + str(config['database']['password']) + '@' + str(config['database']['server']) + '/' + str(config['database']['database']) + return db_string + except Exception as e: + print(e) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = create_engine(get_url_from_config()) + + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/tools/databaseUpgrade/alembic/lib b/tools/databaseUpgrade/alembic/lib new file mode 120000 index 0000000..a5bc743 --- /dev/null +++ b/tools/databaseUpgrade/alembic/lib @@ -0,0 +1 @@ +../../../lib \ No newline at end of file diff --git a/tools/databaseUpgrade/alembic/script.py.mako b/tools/databaseUpgrade/alembic/script.py.mako new file mode 100644 index 0000000..55df286 --- /dev/null +++ b/tools/databaseUpgrade/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py b/tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py new file mode 100644 index 0000000..d92fa9f --- /dev/null +++ b/tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py @@ -0,0 +1,34 @@ +"""Service Overhaul revision + +Revision ID: 2ad87e0c0c76 +Revises: +Create Date: 2023-09-29 04:28:33.635508 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2ad87e0c0c76' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('ims_subscriber', sa.Column('pcscf', sa.String(length=512), nullable=True)) + op.add_column('ims_subscriber', sa.Column('pcscf_realm', sa.String(length=512), nullable=True)) + op.add_column('ims_subscriber', sa.Column('pcscf_timestamp', sa.DateTime(), nullable=True)) + op.add_column('ims_subscriber', sa.Column('pcscf_peer', sa.String(length=512), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('ims_subscriber', 'pcscf_peer') + op.drop_column('ims_subscriber', 'pcscf_timestamp') + op.drop_column('ims_subscriber', 'pcscf_realm') + op.drop_column('ims_subscriber', 'pcscf') + # ### end Alembic commands ### diff --git a/tools/databaseUpgrade/lib b/tools/databaseUpgrade/lib new file mode 120000 index 0000000..58677dd --- /dev/null +++ b/tools/databaseUpgrade/lib @@ -0,0 +1 @@ +../../lib \ No newline at end of file diff --git a/tools/databaseUpgrade/requirements.txt b/tools/databaseUpgrade/requirements.txt new file mode 100644 index 0000000..691f7b6 --- /dev/null +++ b/tools/databaseUpgrade/requirements.txt @@ -0,0 +1,2 @@ +alembic==1.10.3 +zipp==3.17.0 From c0a7c43bc40989fdd09f79d495cfea5ff6936747 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 29 Sep 2023 19:54:22 +1000 Subject: [PATCH 32/43] Working Rx Call, dedicated bearer setup --- lib/database.py | 45 ++++++++++++++++- lib/diameter.py | 131 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 161 insertions(+), 15 deletions(-) diff --git a/lib/database.py b/lib/database.py index 876798e..160d3d3 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1512,7 +1512,7 @@ def Get_APN(self, apn_id): return result def Get_APN_by_Name(self, apn): - self.logTool.log(service='Database', level='debug', message="Getting APN named " + str(apn_id), redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message="Getting APN named " + str(apn), redisClient=self.redisMessaging) Session = sessionmaker(bind = self.engine) session = Session() try: @@ -1608,6 +1608,49 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ self.safe_close(session) + def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, propagate=True): + self.logTool.log(service='Database', level='debug', message="Update_Proxy_CSCF for sub " + str(imsi) + " to pcscf " + str(proxy_cscf) + " with realm " + str(pcscf_realm) + " and peer " + str(pcscf_peer), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + + try: + result = session.query(IMS_SUBSCRIBER).filter_by(imsi=imsi).one() + try: + assert(type(proxy_cscf) == str) + assert(len(proxy_cscf) > 0) + self.logTool.log(service='Database', level='debug', message="Setting Proxy CSCF", redisClient=self.redisMessaging) + #Strip duplicate SIP prefix before storing + proxy_cscf = proxy_cscf.replace("sip:sip:", "sip:") + result.pcscf = proxy_cscf + result.pcscf_timestamp = datetime.datetime.now(tz=timezone.utc) + result.pcscf_realm = pcscf_realm + result.pcscf_peer = str(pcscf_peer) + except: + #Clear values + self.logTool.log(service='Database', level='debug', message="Clearing Proxy CSCF", redisClient=self.redisMessaging) + result.pcscf = None + result.pcscf_timestamp = None + result.pcscf_realm = None + result.pcscf_peer = None + + session.commit() + objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) + self.handleWebhook(objectData, 'PATCH') + + #Sync state change with geored + if propagate == True: + if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: + self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging) + self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": str(result.pcscf_realm), "pcscf_peer": str(result.pcscf_peer)}) + else: + self.logTool.log(service='Database', level='debug', 
message="Config does not allow sync of IMS events", redisClient=self.redisMessaging) + except Exception as E: + self.logTool.log(service='Database', level='error', message="An error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging) + self.safe_rollback(session) + raise + finally: + self.safe_close(session) + def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=None, propagate=True): self.logTool.log(service='Database', level='debug', message="Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer), redisClient=self.redisMessaging) Session = sessionmaker(bind = self.engine) diff --git a/lib/diameter.py b/lib/diameter.py index 295a28b..cd2ff66 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -69,8 +69,6 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 {"commandCode": 323, "applicationId": 16777251, "responseMethod": self.Answer_16777251_323, "failureResultCode": 5012 ,"requestAcronym": "NOR", "responseAcronym": "NOA", "requestName": "Notify Request", "responseName": "Notify Answer"}, {"commandCode": 324, "applicationId": 16777252, "responseMethod": self.Answer_16777252_324, "failureResultCode": 4100 ,"requestAcronym": "ECR", "responseAcronym": "ECA", "requestName": "ME Identity Check Request", "responseName": "ME Identity Check Answer"}, {"commandCode": 8388622, "applicationId": 16777291, "responseMethod": self.Answer_16777291_8388622, "failureResultCode": 4100 ,"requestAcronym": "LRR", "responseAcronym": "LRA", "requestName": "LCS Routing Info Request", "responseName": "LCS Routing Info Answer"}, - - ] self.diameterRequestList = [ @@ -79,7 +77,6 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 {"commandCode": 272, "applicationId": 16777238, "requestMethod": self.Request_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 317, "applicationId": 16777251, "requestMethod": self.Request_16777251_317, "failureResultCode": 5012 ,"requestAcronym": "CLR", "responseAcronym": "CLA", "requestName": "Cancel Location Request", "responseName": "Cancel Location Answer"}, {"commandCode": 319, "applicationId": 16777251, "requestMethod": self.Request_16777251_319, "failureResultCode": 5012 ,"requestAcronym": "ISD", "responseAcronym": "ISA", "requestName": "Insert Subscriber Data Request", "responseName": "Insert Subscriber Data Answer"}, - ] #Generates rounding for calculating padding @@ -591,7 +588,7 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: try: request = '' requestType = requestType.upper() - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Generating a diameter outbound request", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Generating a diameter outbound request", redisClient=self.redisMessaging) for diameterApplication in self.diameterRequestList: try: @@ -609,6 +606,7 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) 
return request except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Error generating diameter outbound request: {traceback.format_exc()}", redisClient=self.redisMessaging) return '' def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> bool: @@ -1606,13 +1604,14 @@ def Answer_16777238_272(self, packet_vars, avps): if 'ims' in apn: try: self.database.Update_Serving_CSCF(imsi=imsi, serving_cscf=None) - self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), subscriber_routing='') + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=None) + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), serving_pgw=OriginHost, subscriber_routing='') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Successfully cleared stored IMS state", redisClient=self.redisMessaging) except Exception as e: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear stored IMS state: {traceback.format_exc()}", redisClient=self.redisMessaging) else: try: - self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), subscriber_routing='') + self.database.Update_Serving_APN(imsi=imsi, apn=apn, pcrf_session_id=str(binascii.unhexlify(session_id).decode()), serving_pgw=OriginHost, subscriber_routing='') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Successfully cleared stored state for: {apn}", redisClient=self.redisMessaging) except Exception as e: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear apn state for {apn}: {traceback.format_exc()}", redisClient=self.redisMessaging) @@ -2168,7 +2167,6 @@ def Answer_16777217_307(self, packet_vars, avps): response = self.generate_diameter_packet("01", "40", 307, 16777217, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response - ################################ #### 3GPP RX #### ################################ @@ -2216,7 +2214,108 @@ def Answer_16777236_265(self, packet_vars, avps): imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) if imsEnabled: + """ + Add the PCSCF to the IMS_Subscriber object, and set the result code to 2001. 
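+                    This AAR handling is the first leg of the Rx MO call flow
+                    (AAR -> RAR -> RAA -> AAA) used for dedicated bearer setup.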
+ """ self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request authorized", redisClient=self.redisMessaging) + + if imsi is None: + imsi = subscriberDetails.get('imsi', None) + + aarOriginHost = self.get_avp_data(avps, 264)[0] + aarOriginHost = bytes.fromhex(aarOriginHost).decode('ascii') + aarOriginRealm = self.get_avp_data(avps, 296)[0] + aarOriginRealm = bytes.fromhex(aarOriginRealm).decode('ascii') + #Check if we have a record-route set as that's where we'll need to send the response + try: + #Get first record-route header, then parse it + remotePeer = self.get_avp_data(avps, 282)[-1] + remotePeer = binascii.unhexlify(remotePeer).decode('utf-8') + except Exception as e: + #If we don't have a record-route set, we'll send the response to the OriginHost + remotePeer = aarOriginHost + + remotePeer = f"{remotePeer};{self.config['hss']['OriginHost']}" + + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer) + """ + Check for AVP's 504 (AF-Application-Identifier) and 520 (Media-Type), which indicates the UE is making a call. + Media-Type: 0 = Audio, 4 = Control + """ + try: + afApplicationIdentifier = self.get_avp_data(avps, 504)[0] + mediaType = self.get_avp_data(avps, 520)[0] + assert(bytes.fromhex(afApplicationIdentifier).decode('ascii') == "IMS Services") + assert(int(mediaType, 16) == 0) + + # At this point, we know the AAR is indicating a call setup, so we'll send get the serving pgw information, then send a + # RAR to the PGW over Gx, asking it to setup the dedicated bearer. + + subscriberId = subscriberDetails.get('subscriber_id', None) + apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) + servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) + servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0] + servingPgw = servingApn.get('serving_pgw', None) + servingPgwRealm = servingApn.get('serving_pgw_realm', None) + pcrfSessionId = servingApn.get('pcrf_session_id', None) + ueIp = servingApn.get('subscriber_routing', None) + + """ + The below charging rule needs to be replaced by the following logic: + 1. Grab the Flow Rules and bitrates from the PCSCF in the AAR, + 2. Compare it to a given backup rule + - If the flowrates are greater than the backup rule (UE is asking for more than allowed), use the backup rule + - If the flowrates are lesser than the backup rule, use the requested flowrates. This will allow for better utilization of radio resources. + 3. Maybe something to do with the TFT's + 4. Send the winning rule. 
+ """ + + chargingRule = { + "charging_rule_id": 1000, + "qci": 1, + "arp_preemption_capability": True, + "mbr_dl": 128000, + "mbr_ul": 128000, + "gbr_ul": 128000, + "precedence": 100, + "arp_priority": 2, + "rule_name": "GBR-Voice", + "arp_preemption_vulnerability": False, + "gbr_dl": 128000, + "tft_group_id": 1, + "rating_group": None, + "tft": [ + { + "tft_group_id": 1, + "direction": 1, + "tft_id": 1, + "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535", + "last_modified": "2023-09-29T05:09:26Z" + }, + { + "tft_group_id": 1, + "direction": 2, + "tft_id": 2, + "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535", + "last_modified": "2023-09-29T05:09:26Z" + } + ] + } + + self.sendDiameterRequest( + requestType='RAR', + hostname=servingPgwPeer, + sessionId=pcrfSessionId, + chargingRules=chargingRule, + ueIp=ueIp, + servingPgw=servingPgw, + servingRealm=servingPgwRealm + ) + + except Exception as e: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Error sending Gx RAR: {traceback.format_exc()}", redisClient=self.redisMessaging) + pass + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) else: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request unauthorized", redisClient=self.redisMessaging) @@ -3115,19 +3214,23 @@ def Request_16777238_272(self, imsi, apn, ccr_type, destinationHost, destination return response #3GPP Gx - Re Auth Request - def Request_16777238_258(self, sessionid, ChargingRules, ue_ip, Serving_PGW, Serving_Realm): + def Request_16777238_258(self, sessionId, chargingRules, ueIp, servingPgw, servingRealm): avp = '' - avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionid)),'ascii')) #Session-Id set AVP + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Creating Re Auth Request", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Charging Rules: {chargingRules}", redisClient=self.redisMessaging) - #Setup Charging Rule - self.logTool.log(service='HSS', level='debug', message=ChargingRules, redisClient=self.redisMessaging) - avp += self.Charging_Rule_Generator(ChargingRules=ChargingRules, ue_ip=ue_ip) + avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionId)),'ascii')) #Session-Id set AVP + + #Setup Charging Rule + self.logTool.log(service='HSS', level='debug', message=chargingRules, redisClient=self.redisMessaging) + avp += self.Charging_Rule_Generator(ChargingRules=chargingRules, ue_ip=ueIp) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Generated Charging Rules", redisClient=self.redisMessaging) avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - avp += self.generate_avp(293, 40, self.string_to_hex(Serving_PGW)) #Destination Host - avp += self.generate_avp(283, 40, self.string_to_hex(Serving_Realm)) #Destination Realm + avp += self.generate_avp(293, 40, self.string_to_hex(servingPgw)) #Destination Host + avp += self.generate_avp(283, 40, self.string_to_hex(servingRealm)) #Destination Realm avp += self.generate_avp(258, 40, format(int(16777238),"x").zfill(8)) #Auth-Application-ID Gx From dca0038a73afdd6b2b05e50600e5531424e71734 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Sat, 30 Sep 2023 10:40:23 +1000 Subject: 
[PATCH 33/43] Geored fixes

---
 lib/database.py        | 79 +++++++++++++++++++++++++++++++++++-------
 lib/diameter.py        |  2 +-
 services/apiService.py | 28 +++++++++++++--
 3 files changed, 93 insertions(+), 16 deletions(-)

diff --git a/lib/database.py b/lib/database.py
index 160d3d3..3656dc4 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -906,7 +906,10 @@ def Sanitize_Datetime(self, result):
                 continue
             else:
                 self.logTool.log(service='Database', level='debug', message="Key " + str(keys) + " is type DateTime with value: " + str(result[keys]) + " - Formatting to String", redisClient=self.redisMessaging)
-                result[keys] = str(result[keys])
+                try:
+                    result[keys] = result[keys].strftime('%Y-%m-%dT%H:%M:%SZ')
+                except Exception as e:
+                    result[keys] = str(result[keys])
         return result

 def Sanitize_Keys(self, result):
@@ -1530,7 +1533,7 @@ def Update_AuC(self, auc_id, sqn=1):
         self.logTool.log(service='Database', level='debug', message=self.UpdateObj(AUC, {'sqn': sqn}, auc_id, True), redisClient=self.redisMessaging)
         return

-    def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, propagate=True):
+    def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_mme_peer=None, serving_mme_timestamp=None, propagate=True):
         self.logTool.log(service='Database', level='debug', message="Updating Serving MME for sub " + str(imsi) + " to MME " + str(serving_mme), redisClient=self.redisMessaging)
         Session = sessionmaker(bind = self.engine)
         session = Session()
@@ -1575,7 +1578,17 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_
             if type(serving_mme) == str:
                 self.logTool.log(service='Database', level='debug', message="Updating serving MME & Timestamp", redisClient=self.redisMessaging)
                 result.serving_mme = serving_mme
-                result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc)
+                try:
+                    if serving_mme_timestamp is not None and serving_mme_timestamp != 'None':
+                        result.serving_mme_timestamp = datetime.datetime.strptime(serving_mme_timestamp, '%Y-%m-%dT%H:%M:%SZ')
+                        result.serving_mme_timestamp = result.serving_mme_timestamp.replace(tzinfo=timezone.utc)
+                        serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+                    else:
+                        result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc)
+                        serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+                except Exception as e:
+                    result.serving_mme_timestamp = datetime.datetime.now(tz=timezone.utc)
+                    serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
                 result.serving_mme_realm = serving_mme_realm
                 result.serving_mme_peer = serving_mme_peer
             else:
@@ -1585,11 +1598,15 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_
                 result.serving_mme_timestamp = None
                 result.serving_mme_realm = None
                 result.serving_mme_peer = None
+                serving_mme_timestamp_string = None

             session.commit()
             objectData = self.GetObj(SUBSCRIBER, result.subscriber_id)
             self.handleWebhook(objectData, 'PATCH')

+            if result.serving_mme_timestamp is not None:
+                result.serving_mme_timestamp = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+
             #Sync state change with geored
             if propagate == True:
                 if 'HSS' in self.config['geored'].get('sync_actions', []) and self.config['geored'].get('enabled', False) == True:
                     self.logTool.log(service='Database', level='debug', message="Propagate MME changes to Geographic PyHSS instances", redisClient=self.redisMessaging)
                     self.handleGeored({
@@ -1598,7 +1615,8 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_
                         "imsi": str(imsi),
"serving_mme": result.serving_mme, "serving_mme_realm": str(result.serving_mme_realm), - "serving_mme_peer": str(result.serving_mme_peer) + "serving_mme_peer": str(result.serving_mme_peer), + "serving_mme_timestamp": serving_mme_timestamp_string }) else: self.logTool.log(service='Database', level='debug', message="Config does not allow sync of HSS events", redisClient=self.redisMessaging) @@ -1608,7 +1626,7 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ self.safe_close(session) - def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, propagate=True): + def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, pcscf_timestamp=None, propagate=True): self.logTool.log(service='Database', level='debug', message="Update_Proxy_CSCF for sub " + str(imsi) + " to pcscf " + str(proxy_cscf) + " with realm " + str(pcscf_realm) + " and peer " + str(pcscf_peer), redisClient=self.redisMessaging) Session = sessionmaker(bind = self.engine) session = Session() @@ -1622,7 +1640,17 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, #Strip duplicate SIP prefix before storing proxy_cscf = proxy_cscf.replace("sip:sip:", "sip:") result.pcscf = proxy_cscf - result.pcscf_timestamp = datetime.datetime.now(tz=timezone.utc) + try: + if pcscf_timestamp is not None and pcscf_timestamp is not 'None': + result.pcscf_timestamp = datetime.strptime(pcscf_timestamp, '%Y-%m-%dT%H:%M:%SZ') + result.pcscf_timestamp = result.pcscf_timestamp.replace(tzinfo=timezone.utc) + pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + else: + result.pcscf_timestamp = datetime.datetime.now(tz=timezone.utc) + pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') + except Exception as e: + result.pcscf_timestamp = datetime.datetime.now(tz=timezone.utc) + pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ') result.pcscf_realm = pcscf_realm result.pcscf_peer = str(pcscf_peer) except: @@ -1632,6 +1660,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, result.pcscf_timestamp = None result.pcscf_realm = None result.pcscf_peer = None + pcscf_timestamp_string = None session.commit() objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id) @@ -1641,7 +1670,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, if propagate == True: if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging) - self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": str(result.pcscf_realm), "pcscf_peer": str(result.pcscf_peer)}) + self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": str(result.pcscf_realm), "pcscf_timestamp": pcscf_timestamp_string, "pcscf_peer": str(result.pcscf_peer)}) else: self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging) except Exception as E: @@ -1651,7 +1680,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, finally: self.safe_close(session) - def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=None, propagate=True): + def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=None, 
 scscf_timestamp=None, propagate=True):
         self.logTool.log(service='Database', level='debug', message="Update_Serving_CSCF for sub " + str(imsi) + " to SCSCF " + str(serving_cscf) + " with realm " + str(scscf_realm) + " and peer " + str(scscf_peer), redisClient=self.redisMessaging)
         Session = sessionmaker(bind = self.engine)
         session = Session()
@@ -1665,7 +1694,17 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N
                 #Strip duplicate SIP prefix before storing
                 serving_cscf = serving_cscf.replace("sip:sip:", "sip:")
                 result.scscf = serving_cscf
-                result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc)
+                try:
+                    if scscf_timestamp is not None and scscf_timestamp != 'None':
+                        result.scscf_timestamp = datetime.datetime.strptime(scscf_timestamp, '%Y-%m-%dT%H:%M:%SZ')
+                        result.scscf_timestamp = result.scscf_timestamp.replace(tzinfo=timezone.utc)
+                        scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+                    else:
+                        result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc)
+                        scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+                except Exception as e:
+                    result.scscf_timestamp = datetime.datetime.now(tz=timezone.utc)
+                    scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
                 result.scscf_realm = scscf_realm
                 result.scscf_peer = str(scscf_peer)
             except:
@@ -1675,6 +1714,7 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N
                 result.scscf_timestamp = None
                 result.scscf_realm = None
                 result.scscf_peer = None
+                scscf_timestamp_string = None

             session.commit()
             objectData = self.GetObj(IMS_SUBSCRIBER, result.ims_subscriber_id)
@@ -1684,7 +1724,7 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N
             if propagate == True:
                 if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True:
                     self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging)
-                    self.handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_peer": str(result.scscf_peer)})
+                    self.handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_timestamp": scscf_timestamp_string, "scscf_peer": str(result.scscf_peer)})
                 else:
                     self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging)
         except Exception as E:
             self.logTool.log(service='Database', level='error', message="An error occurred, rolling back session: " + str(E), redisClient=self.redisMessaging)
             self.safe_rollback(session)
             raise
         finally:
             self.safe_close(session)

-    def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, propagate=True):
+    def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber_routing, serving_pgw_realm=None, serving_pgw_peer=None, serving_pgw_timestamp=None, propagate=True):
         self.logTool.log(service='Database', level='debug', message="Called Update_Serving_APN() for imsi " + str(imsi) + " with APN " + str(apn), redisClient=self.redisMessaging)
         self.logTool.log(service='Database', level='debug', message="PCRF Session ID " + str(pcrf_session_id) + " and serving PGW " + str(serving_pgw) + " and subscriber routing " + str(subscriber_routing), redisClient=self.redisMessaging)
         self.logTool.log(service='Database', level='debug', message="Serving PGW Realm is: " + str(serving_pgw_realm) + " and peer is: " +
 str(serving_pgw_peer), redisClient=self.redisMessaging)
@@ -1726,6 +1766,20 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber
                     break
         self.logTool.log(service='Database', level='debug', message="APN ID is " + str(apn_id), redisClient=self.redisMessaging)

+        try:
+            if serving_pgw_timestamp is not None and serving_pgw_timestamp != 'None':
+                serving_pgw_timestamp = datetime.datetime.strptime(serving_pgw_timestamp, '%Y-%m-%dT%H:%M:%SZ')
+                serving_pgw_timestamp = serving_pgw_timestamp.replace(tzinfo=timezone.utc)
+                serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+            else:
+                serving_pgw_timestamp = datetime.datetime.now(tz=timezone.utc)
+                serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+        except Exception as e:
+            serving_pgw_timestamp = datetime.datetime.now(tz=timezone.utc)
+            serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
+
         json_data = {
             'apn' : apn_id,
             'subscriber_id' : subscriber_id,
@@ -1733,7 +1787,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber
             'serving_pgw' : str(serving_pgw),
             'serving_pgw_realm' : str(serving_pgw_realm),
             'serving_pgw_peer' : str(serving_pgw_peer),
-            'serving_pgw_timestamp' : datetime.datetime.now(tz=timezone.utc),
+            'serving_pgw_timestamp' : serving_pgw_timestamp,
             'subscriber_routing' : str(subscriber_routing)
         }

@@ -1784,6 +1838,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber
                     'serving_pgw': str(serving_pgw),
                     'serving_pgw_realm': str(serving_pgw_realm),
                     'serving_pgw_peer': str(serving_pgw_peer),
+                    'serving_pgw_timestamp': serving_pgw_timestamp_string,
                     'subscriber_routing': str(subscriber_routing)
                 })
             else:
diff --git a/lib/diameter.py b/lib/diameter.py
index cd2ff66..98f72a2 100644
--- a/lib/diameter.py
+++ b/lib/diameter.py
@@ -1814,7 +1814,7 @@ def Answer_16777216_301(self, packet_vars, avps):
             ServingCSCF = binascii.unhexlify(ServingCSCF).decode('utf-8')      #Format it
             self.logTool.log(service='HSS', level='debug', message="Subscriber is served by S-CSCF " + str(ServingCSCF), redisClient=self.redisMessaging)
             if (Server_Assignment_Type == 1) or (Server_Assignment_Type == 2):
-                self.logTool.log(service='HSS', level='debug', message="SAR is Register / Re-Register", redisClient=self.redisMessaging)
+                self.logTool.log(service='HSS', level='debug', message="SAR is Register / Re-Register", redisClient=self.redisMessaging)
                 remote_peer = remote_peer + ";" + str(self.config['hss']['OriginHost'])
                 self.database.Update_Serving_CSCF(imsi, serving_cscf=ServingCSCF, scscf_realm=OriginRealm, scscf_peer=remote_peer)
             else:
diff --git a/services/apiService.py b/services/apiService.py
index fa0f85e..db59956 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -1507,6 +1507,8 @@ def patch(self):
                     json_data['serving_pgw_realm'] = None
                 if 'serving_pgw_peer' not in json_data:
                     json_data['serving_pgw_peer'] = None
+                if 'serving_pgw_timestamp' not in json_data:
+                    json_data['serving_pgw_timestamp'] = None
                 response_data.append(databaseClient.Update_Serving_APN(
                     imsi=str(json_data['imsi']),
                     apn=json_data['serving_apn'],
                     pcrf_session_id=json_data['pcrf_session_id'],
                     serving_pgw=json_data['serving_pgw'],
                     subscriber_routing=json_data['subscriber_routing'],
                     serving_pgw_realm=json_data['serving_pgw_realm'],
                     serving_pgw_peer=json_data['serving_pgw_peer'],
+                    serving_pgw_timestamp=json_data['serving_pgw_timestamp'],
                     propagate=False))
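                # The geored PATCH body handled by this branch is shaped like the
                # following (all values illustrative):
                # {"imsi": "001010000000001", "serving_apn": "ims",
                #  "pcrf_session_id": "pgw01;1234567890", "serving_pgw": "pgw01",
                #  "subscriber_routing": "10.45.0.10",
                #  "serving_pgw_realm": "epc.mnc001.mcc001.3gppnetwork.org",
                #  "serving_pgw_peer": "pgw01",
                #  "serving_pgw_timestamp": "2023-09-30T00:00:00Z"}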
                redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints',
                                    metricType='counter', metricAction='inc',
                                    metricValue=1.0, metricHelp='Number of Geored Pushes Received',
                                    metricLabels={
                                        "endpoint": "APN",
                                        "geored_host": request.remote_addr,
                                    },
                                    metricExpiry=60)
            if 'scscf' in json_data:
-                print("Updating serving SCSCF")
+                print("Updating Serving SCSCF")
                if 'scscf_realm' not in json_data:
                    json_data['scscf_realm'] = None
                if 'scscf_peer' not in json_data:
                    json_data['scscf_peer'] = None
-                response_data.append(databaseClient.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), propagate=False))
+                if 'scscf_timestamp' not in json_data:
+                    json_data['scscf_timestamp'] = None
+                response_data.append(databaseClient.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), scscf_timestamp=json_data['scscf_timestamp'], propagate=False))
                redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints',
                                    metricType='counter', metricAction='inc',
                                    metricValue=1.0, metricHelp='Number of Geored Pushes Received',
                                    metricLabels={
-                                        "endpoint": "IMS",
+                                        "endpoint": "IMS_SCSCF",
                                        "geored_host": request.remote_addr,
                                    },
                                    metricExpiry=60)
+            if 'pcscf' in json_data:
+                print("Updating Proxy CSCF")
+                if 'pcscf_realm' not in json_data:
+                    json_data['pcscf_realm'] = None
+                if 'pcscf_peer' not in json_data:
+                    json_data['pcscf_peer'] = None
+                if 'pcscf_timestamp' not in json_data:
+                    json_data['pcscf_timestamp'] = None
+                response_data.append(databaseClient.Update_Proxy_CSCF(imsi=str(json_data['imsi']), proxy_cscf=json_data['pcscf'], pcscf_realm=str(json_data['pcscf_realm']), pcscf_peer=str(json_data['pcscf_peer']), pcscf_timestamp=json_data['pcscf_timestamp'], propagate=False))
+                redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints',
+                                    metricType='counter', metricAction='inc',
+                                    metricValue=1.0, metricHelp='Number of Geored Pushes Received',
+                                    metricLabels={
+                                        "endpoint": "IMS_PCSCF",
+                                        "geored_host": request.remote_addr,
+                                    },
+                                    metricExpiry=60)

From b3a7cb1ac15aafd756f1158889e9e7e3919bdee5 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Mon, 2 Oct 2023 15:53:58 +1000
Subject: [PATCH 34/43] Rx Call setup awaiting RAA per correct flow

---
 lib/diameter.py             | 320 ++++++++++++++++++++++++------------
 lib/diameterAsync.py        |  24 ++-
 lib/messaging.py            |  21 ++-
 lib/messagingAsync.py       |   2 +-
 services/apiService.py      |   1 +
 services/diameterService.py |   2 +-
 services/hssService.py      |   3 +-
 7 files changed, 253 insertions(+), 120 deletions(-)

diff --git a/lib/diameter.py b/lib/diameter.py
index 98f72a2..dbd4fb5 100644
--- a/lib/diameter.py
+++ b/lib/diameter.py
@@ -59,9 +59,9 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99
         {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", "responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"},
         {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"},
         {"commandCode": 265, "applicationId": 16777236, "responseMethod": self.Answer_16777236_265, "failureResultCode": 4100 ,"requestAcronym": "AAR", "responseAcronym": "AAA", "requestName": "AA
Request", "responseName": "AA Answer"}, - {"commandCode": 258, "applicationId": 16777236, "responseMethod": self.Answer_16777236_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, {"commandCode": 275, "applicationId": 16777236, "responseMethod": self.Answer_16777236_275, "failureResultCode": 4100 ,"requestAcronym": "STR", "responseAcronym": "STA", "requestName": "Session Termination Request", "responseName": "Session Termination Answer"}, {"commandCode": 274, "applicationId": 16777236, "responseMethod": self.Answer_16777236_274, "failureResultCode": 4100 ,"requestAcronym": "ASR", "responseAcronym": "ASA", "requestName": "Abort Session Request", "responseName": "Abort Session Answer"}, + {"commandCode": 258, "applicationId": 16777238, "responseMethod": self.Answer_16777238_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, @@ -525,61 +525,68 @@ def decode_diameter_packet_length(self, data): return False def getPeerType(self, originHost: str) -> str: - try: - peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] + try: + peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] - for peer in peerTypes: - if peer in originHost.lower(): - return peer - - except Exception as e: - return '' + for peer in peerTypes: + if peer in originHost.lower(): + return peer + + except Exception as e: + return '' def getConnectedPeersByType(self, peerType: str) -> list: - try: - peerType = peerType.lower() - peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + try: + peerType = peerType.lower() + peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] - if peerType not in peerTypes: - return [] - filteredConnectedPeers = [] - activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) + if peerType not in peerTypes: + return [] + filteredConnectedPeers = [] + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) - for key, value in activePeers.items(): - if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': - filteredConnectedPeers.append(activePeers.get(key, {})) - - return filteredConnectedPeers + for key, value in activePeers.items(): + if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': + filteredConnectedPeers.append(activePeers.get(key, {})) + + return filteredConnectedPeers - except Exception as e: - return [] + except Exception as e: + 
return [] def getPeerByHostname(self, hostname: str) -> dict: - try: - hostname = hostname.lower() - activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) + try: + hostname = hostname.lower() + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) - for key, value in activePeers.items(): - if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': - return(activePeers.get(key, {})) + for key, value in activePeers.items(): + if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': + return(activePeers.get(key, {})) - except Exception as e: - return {} + except Exception as e: + return {} def getDiameterMessageType(self, binaryData: str) -> dict: - packet_vars, avps = self.decode_diameter_packet(binaryData) - response = {} - - for diameterApplication in self.diameterResponseList: - try: - assert(packet_vars["command_code"] == diameterApplication["commandCode"]) - assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) + """ + Determines whether a message is a request or a response, and the appropriate acronyms for each type. + """ + packet_vars, avps = self.decode_diameter_packet(binaryData) + response = {} + + for diameterApplication in self.diameterResponseList: + try: + assert(packet_vars["command_code"] == diameterApplication["commandCode"]) + assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) + if packet_vars["flags_bin"][0:1] == "1": response['inbound'] = diameterApplication["requestAcronym"] response['outbound'] = diameterApplication["responseAcronym"] - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Matched message types: {response}", redisClient=self.redisMessaging) - except Exception as e: - continue - return response + else: + response['inbound'] = diameterApplication["responseAcronym"] + response['outbound'] = diameterApplication["requestAcronym"] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] Matched message types: {response}", redisClient=self.redisMessaging) + except Exception as e: + continue + return response def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: """ @@ -599,11 +606,11 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: peerIp = connectedPeer['ipAddress'] peerPort = connectedPeer['port'] request = diameterApplication["requestMethod"](**kwargs) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{time.time_ns()}" outboundMessage = json.dumps({'diameter-outbound': request}) self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] 
[sendDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) return request except Exception as e: self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Error generating diameter outbound request: {traceback.format_exc()}", redisClient=self.redisMessaging) @@ -637,6 +644,81 @@ def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> except Exception as e: return '' + def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeout: float=0.12, **kwargs) -> str: + """ + Sends a given diameter request of requestType to the provided peer hostname. + Ensures the peer is connected, sends the request, then waits on and returns the response. + If the timeout is reached, the function fails. + + Diameter lacks a unique identifier for all message types, the closest being Session-ID which exists for most. + We attempt to get the associated response given the following logic: + - If sessionId is none, attempt to return the first response that matches the expected response method (eg AAA, CEA, etc.) which has a timestamp greater than sendTime. + - If sessionId is not none, perform the logic above, and also ensure that sessionId matches. + + Returns an empty string if fails. + + Until diameter.py is rewritten to be asynchronous, this method should be called only when strictly necessary. It potentially adds up to 120ms of delay per invocation. + """ + try: + request = '' + requestType = requestType.upper() + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Generating a diameter outbound request", redisClient=self.redisMessaging) + + for diameterApplication in self.diameterRequestList: + try: + assert(requestType == diameterApplication["requestAcronym"]) + except Exception as e: + continue + connectedPeer = self.getPeerByHostname(hostname=hostname) + peerIp = connectedPeer['ipAddress'] + peerPort = connectedPeer['port'] + request = diameterApplication["requestMethod"](**kwargs) + responseType = diameterApplication["responseAcronym"] + sessionId = kwargs.get('sessionId', None) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + sendTime = time.time_ns() + outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{sendTime}" + outboundMessage = json.dumps({'diameter-outbound': request}) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) + startTimer = time.time() + while True: + try: + if not time.time() >= startTimer + timeout: + if sessionId is None: + responseQueues = self.redisMessaging.getQueues(pattern=f"diameter-inbound-{peerIp.replace('.', '*')}-{peerPort}-{responseType}*") + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] responseQueues(NoSessionId): {responseQueues}", redisClient=self.redisMessaging) + for responseQueue in responseQueues: + if float(responseQueue.split('-')[5]) > sendTime: + inboundResponseList = self.redisMessaging.getMessage(queue=responseQueue) + if 
len(inboundResponseList) > 0:
+                                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Found inbound response: {inboundResponseList[0]}", redisClient=self.redisMessaging)
+                                    return json.loads(inboundResponseList[0]).get('diameter-inbound', '')
+                        time.sleep(0.02)
+                    else:
+                        responseQueues = self.redisMessaging.getQueues(pattern=f"diameter-inbound-{peerIp.replace('.', '*')}-{peerPort}-{responseType}*")
+                        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] responseQueues({sessionId}): {responseQueues} responseType: {responseType}", redisClient=self.redisMessaging)
+                        for responseQueue in responseQueues:
+                            if float(responseQueue.split('-')[5]) > sendTime:
+                                inboundResponseList = self.redisMessaging.getList(key=responseQueue)
+                                if len(inboundResponseList) > 0:
+                                    for inboundResponse in inboundResponseList:
+                                        responseHex = json.loads(inboundResponse)['diameter-inbound']
+                                        packetVars, avps = self.decode_diameter_packet(responseHex)
+                                        responseSessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii')
+                                        if responseSessionId == sessionId:
+                                            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Matched on Session Id: {sessionId}", redisClient=self.redisMessaging)
+                                            return responseHex
+                        time.sleep(0.02)
+                else:
+                    return ''
+            except Exception as e:
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Traceback: {traceback.format_exc()}", redisClient=self.redisMessaging)
+                return ''
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Error generating diameter outbound request: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            return ''
+
     def generateDiameterResponse(self, binaryData: str) -> str:
         try:
             packet_vars, avps = self.decode_diameter_packet(binaryData)
@@ -2251,72 +2333,87 @@ def Answer_16777236_265(self, packet_vars, avps):
                     # At this point, we know the AAR is indicating a call setup, so we'll send get the serving pgw information, then send a
                     # RAR to the PGW over Gx, asking it to setup the dedicated bearer.

-                    subscriberId = subscriberDetails.get('subscriber_id', None)
-                    apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
-                    servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
-                    servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0]
-                    servingPgw = servingApn.get('serving_pgw', None)
-                    servingPgwRealm = servingApn.get('serving_pgw_realm', None)
-                    pcrfSessionId = servingApn.get('pcrf_session_id', None)
-                    ueIp = servingApn.get('subscriber_routing', None)
-
-                    """
-                    The below charging rule needs to be replaced by the following logic:
-                    1. Grab the Flow Rules and bitrates from the PCSCF in the AAR,
-                    2. Compare it to a given backup rule
-                    - If the flowrates are greater than the backup rule (UE is asking for more than allowed), use the backup rule
-                    - If the flowrates are lesser than the backup rule, use the requested flowrates. This will allow for better utilization of radio resources.
-                    3. Maybe something to do with the TFT's
-                    4. Send the winning rule.
-                    """
-
-                    chargingRule = {
-                        "charging_rule_id": 1000,
-                        "qci": 1,
-                        "arp_preemption_capability": True,
-                        "mbr_dl": 128000,
-                        "mbr_ul": 128000,
-                        "gbr_ul": 128000,
-                        "precedence": 100,
-                        "arp_priority": 2,
-                        "rule_name": "GBR-Voice",
-                        "arp_preemption_vulnerability": False,
-                        "gbr_dl": 128000,
-                        "tft_group_id": 1,
-                        "rating_group": None,
-                        "tft": [
-                            {
-                                "tft_group_id": 1,
-                                "direction": 1,
-                                "tft_id": 1,
-                                "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535",
-                                "last_modified": "2023-09-29T05:09:26Z"
-                            },
-                            {
+                    try:
+                        subscriberId = subscriberDetails.get('subscriber_id', None)
+                        apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
+                        servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
+                        servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0]
+                        servingPgw = servingApn.get('serving_pgw', None)
+                        servingPgwRealm = servingApn.get('serving_pgw_realm', None)
+                        pcrfSessionId = servingApn.get('pcrf_session_id', None)
+                        ueIp = servingApn.get('subscriber_routing', None)
+
+                        """
+                        The below charging rule needs to be replaced by the following logic:
+                        1. Grab the Flow Rules and bitrates from the PCSCF in the AAR,
+                        2. Compare it to a given backup rule
+                        - If the flowrates are greater than the backup rule (UE is asking for more than allowed), use the backup rule
+                        - If the flowrates are lesser than the backup rule, use the requested flowrates. This will allow for better utilization of radio resources.
+                        3. Maybe something to do with the TFT's
+                        4. Send the winning rule.
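+                        A minimal sketch of step 2, assuming the requested bitrates have
+                        already been parsed from the AAR (all names illustrative):
+                            mbr_dl = min(requestedMbrDl, backupRule['mbr_dl'])
+                            mbr_ul = min(requestedMbrUl, backupRule['mbr_ul'])
+                            gbr_dl = min(requestedGbrDl, backupRule['gbr_dl'])
+                            gbr_ul = min(requestedGbrUl, backupRule['gbr_ul'])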
- """ - - chargingRule = { - "charging_rule_id": 1000, - "qci": 1, - "arp_preemption_capability": True, - "mbr_dl": 128000, - "mbr_ul": 128000, - "gbr_ul": 128000, - "precedence": 100, - "arp_priority": 2, - "rule_name": "GBR-Voice", - "arp_preemption_vulnerability": False, - "gbr_dl": 128000, - "tft_group_id": 1, - "rating_group": None, - "tft": [ - { - "tft_group_id": 1, - "direction": 1, - "tft_id": 1, - "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535", - "last_modified": "2023-09-29T05:09:26Z" - }, - { + try: + subscriberId = subscriberDetails.get('subscriber_id', None) + apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) + servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) + servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0] + servingPgw = servingApn.get('serving_pgw', None) + servingPgwRealm = servingApn.get('serving_pgw_realm', None) + pcrfSessionId = servingApn.get('pcrf_session_id', None) + ueIp = servingApn.get('subscriber_routing', None) + + """ + The below charging rule needs to be replaced by the following logic: + 1. Grab the Flow Rules and bitrates from the PCSCF in the AAR, + 2. Compare it to a given backup rule + - If the flowrates are greater than the backup rule (UE is asking for more than allowed), use the backup rule + - If the flowrates are lesser than the backup rule, use the requested flowrates. This will allow for better utilization of radio resources. + 3. Maybe something to do with the TFT's + 4. Send the winning rule. + """ + + chargingRule = { + "charging_rule_id": 1000, + "qci": 1, + "arp_preemption_capability": True, + "mbr_dl": 128000, + "mbr_ul": 128000, + "gbr_ul": 128000, + "precedence": 100, + "arp_priority": 2, + "rule_name": "GBR-Voice", + "arp_preemption_vulnerability": False, + "gbr_dl": 128000, "tft_group_id": 1, - "direction": 2, - "tft_id": 2, - "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535", - "last_modified": "2023-09-29T05:09:26Z" + "rating_group": None, + "tft": [ + { + "tft_group_id": 1, + "direction": 1, + "tft_id": 1, + "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535" + }, + { + "tft_group_id": 1, + "direction": 2, + "tft_id": 2, + "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535" + } + ] } - ] - } - self.sendDiameterRequest( - requestType='RAR', - hostname=servingPgwPeer, - sessionId=pcrfSessionId, - chargingRules=chargingRule, - ueIp=ueIp, - servingPgw=servingPgw, - servingRealm=servingPgwRealm - ) + reAuthAnswer = self.awaitDiameterRequestAndResponse( + requestType='RAR', + hostname=servingPgwPeer, + sessionId=pcrfSessionId, + chargingRules=chargingRule, + ueIp=ueIp, + servingPgw=servingPgw, + servingRealm=servingPgwRealm + ) + + if not len(reAuthAnswer) > 0: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging) + assert() + + raaPacketVars, raaAvps = self.decode_diameter_packet(reAuthAnswer) + raaResultCode = int(self.get_avp_data(raaAvps, 268)[0], 16) + + if raaResultCode == 2001: + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAA returned Successfully, authorizing request", redisClient=self.redisMessaging) + else: + avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4)) + self.logTool.log(service='HSS', level='debug', 
message=f"[diameter.py] [Answer_16777236_265] [AAA] RAA returned Unauthorized, declining request", redisClient=self.redisMessaging) + + except Exception as e: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Error processing RAR / RAA, Authorizing request: {traceback.format_exc()}", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) except Exception as e: - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Error sending Gx RAR: {traceback.format_exc()}", redisClient=self.redisMessaging) + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) pass - - avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) else: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request unauthorized", redisClient=self.redisMessaging) avp += self.generate_avp(268, 40, self.int_to_hex(4001, 4)) @@ -2410,6 +2507,7 @@ def Answer_16777236_275(self, packet_vars, avps): avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response except Exception as e: @@ -2427,12 +2525,26 @@ def Answer_16777236_274(self, packet_vars, avps): avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) response = self.generate_diameter_packet("01", "40", 274, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response except Exception as e: self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_274] [STA] Error generating STA: {traceback.format_exc()}", redisClient=self.redisMessaging) + # Re Auth Answer + def Answer_16777238_258(self, packet_vars, avps): + try: + avp = '' + session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID + avp += self.generate_avp(263, 40, session_id) #Set session ID to received session ID + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) + response = self.generate_diameter_packet("01", "40", 274, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_274] [RAA] Error generating RAA: {traceback.format_exc()}", redisClient=self.redisMessaging) #3GPP S13 - ME-Identity-Check Answer def Answer_16777252_324(self, packet_vars, avps): diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index 8a1b128..f8c5be9 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -19,9 +19,9 @@ def __init__(self, logTool): {"commandCode": 306, "applicationId": 16777217, "responseMethod": self.Answer_16777217_306, "failureResultCode": 5001 ,"requestAcronym": "UDR", 
"responseAcronym": "UDA", "requestName": "User Data Request", "responseName": "User Data Answer"}, {"commandCode": 307, "applicationId": 16777217, "responseMethod": self.Answer_16777217_307, "failureResultCode": 5001 ,"requestAcronym": "PRUR", "responseAcronym": "PRUA", "requestName": "Profile Update Request", "responseName": "Profile Update Answer"}, {"commandCode": 265, "applicationId": 16777236, "responseMethod": self.Answer_16777236_265, "failureResultCode": 4100 ,"requestAcronym": "AAR", "responseAcronym": "AAA", "requestName": "AA Request", "responseName": "AA Answer"}, - {"commandCode": 258, "applicationId": 16777236, "responseMethod": self.Answer_16777236_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, {"commandCode": 275, "applicationId": 16777236, "responseMethod": self.Answer_16777236_275, "failureResultCode": 4100 ,"requestAcronym": "STR", "responseAcronym": "STA", "requestName": "Session Termination Request", "responseName": "Session Termination Answer"}, {"commandCode": 274, "applicationId": 16777236, "responseMethod": self.Answer_16777236_274, "failureResultCode": 4100 ,"requestAcronym": "ASR", "responseAcronym": "ASA", "requestName": "Abort Session Request", "responseName": "Abort Session Answer"}, + {"commandCode": 258, "applicationId": 16777238, "responseMethod": self.Answer_16777238_258, "failureResultCode": 4100 ,"requestAcronym": "RAR", "responseAcronym": "RAA", "requestName": "Re Auth Request", "responseName": "Re Auth Answer"}, {"commandCode": 272, "applicationId": 16777238, "responseMethod": self.Answer_16777238_272, "failureResultCode": 5012 ,"requestAcronym": "CCR", "responseAcronym": "CCA", "requestName": "Credit Control Request", "responseName": "Credit Control Answer"}, {"commandCode": 318, "applicationId": 16777251, "flags": "c0", "responseMethod": self.Answer_16777251_318, "failureResultCode": 4100 ,"requestAcronym": "AIR", "responseAcronym": "AIA", "requestName": "Authentication Information Request", "responseName": "Authentication Information Answer"}, {"commandCode": 316, "applicationId": 16777251, "responseMethod": self.Answer_16777251_316, "failureResultCode": 4100 ,"requestAcronym": "ULR", "responseAcronym": "ULA", "requestName": "Update Location Request", "responseName": "Update Location Answer"}, @@ -229,7 +229,7 @@ async def decodeAvpPacket(self, data): async def getPeerType(self, originHost: str) -> str: try: - peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] + peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] for peer in peerTypes: if peer in originHost.lower(): @@ -241,7 +241,7 @@ async def getPeerType(self, originHost: str) -> str: async def getConnectedPeersByType(self, peerType: str) -> list: try: peerType = peerType.lower() - peerTypes = ['mme', 'pgw', 'icscf', 'scscf', 'hss', 'ocs'] + peerTypes = ['mme', 'pgw', 'pcscf', 'icscf', 'scscf', 'hss', 'ocs', 'dra'] if peerType not in peerTypes: return [] @@ -257,8 +257,10 @@ async def getConnectedPeersByType(self, peerType: str) -> list: except Exception as e: return [] - async def getDiameterMessageType(self, binaryData: str) -> dict: + """ + Determines whether a message is a request or a response, and the appropriate acronyms for each type. 
+ """ packet_vars, avps = await(self.decodeDiameterPacket(binaryData)) response = {} @@ -266,8 +268,12 @@ async def getDiameterMessageType(self, binaryData: str) -> dict: try: assert(packet_vars["command_code"] == diameterApplication["commandCode"]) assert(packet_vars["ApplicationId"] == diameterApplication["applicationId"]) - response['inbound'] = diameterApplication["requestAcronym"] - response['outbound'] = diameterApplication["responseAcronym"] + if packet_vars["flags_bin"][0:1] == "1": + response['inbound'] = diameterApplication["requestAcronym"] + response['outbound'] = diameterApplication["responseAcronym"] + else: + response['inbound'] = diameterApplication["responseAcronym"] + response['outbound'] = diameterApplication["requestAcronym"] except Exception as e: continue @@ -344,11 +350,11 @@ async def Answer_16777291_8388622(self): async def Answer_16777236_265(self): pass - async def Answer_16777236_258(self): - pass - async def Answer_16777236_275(self): pass async def Answer_16777236_274(self): + pass + + async def Answer_16777238_258(self): pass \ No newline at end of file diff --git a/lib/messaging.py b/lib/messaging.py index b5b9762..62b4a78 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -1,5 +1,5 @@ from redis import Redis -import time, json, uuid +import time, json, uuid, traceback class RedisMessaging: """ @@ -88,11 +88,11 @@ def getQueues(self, pattern: str='*') -> list: Returns all Queues (Keys) in the database. """ try: - allQueues = self.redisClient.keys(pattern) + allQueues = self.redisClient.scan_iter(match=pattern) return [x.decode() for x in allQueues] except Exception as e: - return [] - + return f"{traceback.format_exc()}" + def getNextQueue(self, pattern: str='*') -> dict: """ Returns the next Queue (Key) in the list. @@ -138,6 +138,19 @@ def getValue(self, key: str) -> str: except Exception as e: return '' + def getList(self, key: str) -> list: + """ + Gets the list stored under a given key. + """ + try: + allResults = self.redisClient.lrange(key, 0, -1) + if allResults is None: + result = [] + else: + return [result.decode() for result in allResults] + except Exception as e: + return [] + def RedisHGetAll(self, key: str): """ Wrapper for Redis HGETALL diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index aa3c003..baed706 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -99,7 +99,7 @@ async def getQueues(self, pattern: str='*') -> list: Returns all Queues (Keys) in the database, asynchronously. 
""" try: - allQueuesBinary = await(self.redisClient.keys(pattern)) + allQueuesBinary = await(self.redisClient.scan_iter(match=pattern)) allQueues = [x.decode() for x in allQueuesBinary] return allQueues except Exception as e: diff --git a/services/apiService.py b/services/apiService.py index db59956..3a92ce8 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -248,6 +248,7 @@ def page_not_found(e): @apiService.after_request def apply_caching(response): response.headers["HSS"] = str(config['hss']['OriginHost']) + response.headers["Access-Control-Allow-Origin"] = "*" return response @ns_apn.route('/') diff --git a/services/diameterService.py b/services/diameterService.py index c7946c4..0969170 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -144,7 +144,7 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=inboundData)) diameterMessageType = diameterMessageType.get('inbound', '') - inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{time.time_ns()}" + inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{diameterMessageType}-{time.time_ns()}" inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()}) await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}")) await(self.redisReaderMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) diff --git a/services/hssService.py b/services/hssService.py index b475fe7..dc8027d 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -49,7 +49,8 @@ def handleQueue(self): inboundSplit = str(inboundQueue).split('-') inboundHost = inboundSplit[2] inboundPort = inboundSplit[3] - inboundTimestamp = inboundSplit[4] + inboundMessageType = inboundSplit[4] + inboundTimestamp = inboundSplit[5] try: diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) From 5260314f95167cec8dc41983b230ff8fef468a8d Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 3 Oct 2023 20:42:39 +1000 Subject: [PATCH 35/43] Successful Rx Call, with dedicated bearer setup and teardown --- CHANGELOG.md | 2 + lib/database.py | 27 +- lib/diameter.py | 248 +++++++++++------- services/apiService.py | 1 - tools/databaseUpgrade/alembic/env.py | 2 +- .../2ad87e0c0c76_service_overhaul_revision.py | 34 --- 6 files changed, 185 insertions(+), 129 deletions(-) delete mode 100644 tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3de74be..792f08e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Dependency on Redis for inter-service messaging - Significant performance improvements under load - Basic Rx support for RAA, AAA, ASA and STA + - Rx MO call flow support (AAR -> RAR -> RAA -> AAA) + - Dedicated bearer setup and teardown on Rx call - Asymmetric geored support - Configurable redis connection (Unix socket or TCP) - Basic database upgrade support in tools/databaseUpgrade diff --git a/lib/database.py b/lib/database.py index 3656dc4..41f4823 100755 --- a/lib/database.py +++ b/lib/database.py @@ -193,6 +193,7 @@ class IMS_SUBSCRIBER(Base): ifc_path = Column(String(18), doc='Path to template file for the Initial 
     ifc_path = Column(String(18), doc='Path to template file for the Initial Filter Criteria')
     pcscf = Column(String(512), doc='Proxy-CSCF serving this subscriber')
     pcscf_realm = Column(String(512), doc='Realm of PCSCF')
+    pcscf_active_session = Column(String(512), doc='Session Id for the PCSCF when in a call')
     pcscf_timestamp = Column(DateTime, doc='Timestamp of last ue attach to PCSCF')
     pcscf_peer = Column(String(512), doc='Diameter peer used to reach PCSCF')
     sh_profile = Column(Text(12000), doc='Sh Subscriber Profile')
@@ -1626,8 +1627,8 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_

         self.safe_close(session)

-    def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, pcscf_timestamp=None, propagate=True):
-        self.logTool.log(service='Database', level='debug', message="Update_Proxy_CSCF for sub " + str(imsi) + " to pcscf " + str(proxy_cscf) + " with realm " + str(pcscf_realm) + " and peer " + str(pcscf_peer), redisClient=self.redisMessaging)
+    def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, pcscf_timestamp=None, pcscf_active_session=None, propagate=True):
+        self.logTool.log(service='Database', level='debug', message="Update_Proxy_CSCF for sub " + str(imsi) + " to pcscf " + str(proxy_cscf) + " with realm " + str(pcscf_realm) + " and peer " + str(pcscf_peer) + " for session id " + str(pcscf_active_session), redisClient=self.redisMessaging)

         Session = sessionmaker(bind = self.engine)
         session = Session()
@@ -1640,6 +1641,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None,
                 #Strip duplicate SIP prefix before storing
                 proxy_cscf = proxy_cscf.replace("sip:sip:", "sip:")
                 result.pcscf = proxy_cscf
+                result.pcscf_active_session = pcscf_active_session
                 try:
                     if pcscf_timestamp is not None and pcscf_timestamp is not 'None':
                         result.pcscf_timestamp = datetime.strptime(pcscf_timestamp, '%Y-%m-%dT%H:%M:%SZ')
@@ -1660,6 +1662,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None,
                 result.pcscf_timestamp = None
                 result.pcscf_realm = None
                 result.pcscf_peer = None
+                result.pcscf_active_session = None
                 pcscf_timestamp_string = None

             session.commit()
@@ -1670,7 +1673,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None,
             if propagate == True:
                 if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True:
                     self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging)
-                    self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": str(result.pcscf_realm), "pcscf_timestamp": pcscf_timestamp_string, "pcscf_peer": str(result.pcscf_peer)})
+                    self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": str(result.pcscf_realm), "pcscf_timestamp": pcscf_timestamp_string, "pcscf_peer": str(result.pcscf_peer), "pcscf_active_session": str(pcscf_active_session)})
                 else:
                     self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging)
         except Exception as E:
@@ -1988,8 +1991,22 @@ def Get_UE_by_IP(self, subscriber_routing):
         result.pop('_sa_instance_state')
         result = self.Sanitize_Datetime(result)
         return result
-        #Get Subscriber ID from IMSI
-        subscriber_details = Get_Subscriber(imsi=str(imsi))
+
+    def Get_IMS_Subscriber_By_Session_Id(self, sessionId):
+        self.logTool.log(service='Database', level='debug', message="Called Get_IMS_Subscriber_By_Session_Id() for Session " + str(sessionId), redisClient=self.redisMessaging)
+
+        Session = sessionmaker(bind = self.engine)
+        session = Session()
+
+        try:
+            result = session.query(IMS_SUBSCRIBER).filter_by(pcscf_active_session=sessionId).one()
+        except Exception as E:
+            self.safe_close(session)
+            raise ValueError(E)
+        result = result.__dict__
+        result.pop('_sa_instance_state')
+        result = self.Sanitize_Datetime(result)
+        return result

     def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=True):
         #IMSI 14-15 Digits
diff --git a/lib/diameter.py b/lib/diameter.py
index dbd4fb5..c592492 100644
--- a/lib/diameter.py
+++ b/lib/diameter.py
@@ -913,84 +913,101 @@ def AVP_278_Origin_State_Incriment(self, avps):
         origin_state_incriment_hex = format(origin_state_incriment_int,"x").zfill(8)
         return origin_state_incriment_hex

-    def Charging_Rule_Generator(self, ChargingRules, ue_ip):
-        self.logTool.log(service='HSS', level='debug', message="Called Charging_Rule_Generator", redisClient=self.redisMessaging)
-        #Install Charging Rules
-        self.logTool.log(service='HSS', level='debug', message="Naming Charging Rule", redisClient=self.redisMessaging)
-        Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii'))
-        self.logTool.log(service='HSS', level='debug', message="Named Charging Rule", redisClient=self.redisMessaging)
-
-        #Populate all Flow Information AVPs
-        Flow_Information = ''
-        for tft in ChargingRules['tft']:
-            self.logTool.log(service='HSS', level='debug', message=tft, redisClient=self.redisMessaging)
-            #If {{ UE_IP }} in TFT splice in the real UE IP Value
-            try:
-                tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip))
-                tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip))
-                self.logTool.log(service='HSS', level='debug', message="Spliced in UE IP into TFT: " + str(tft['tft_string']), redisClient=self.redisMessaging)
-            except Exception as E:
-                self.logTool.log(service='HSS', level='error', message="Failed to splice in UE IP into flow description", redisClient=self.redisMessaging)
-
-            #Valid Values for Flow_Direction: 0- Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional
-            Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4))
-            Flow_Description = self.generate_vendor_avp(507, "c0", 10415, str(binascii.hexlify(str.encode(tft['tft_string'])),'ascii'))
-            Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description)
-
-        Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4))
-        self.logTool.log(service='HSS', level='debug', message="Defined Flow_Status: " + str(Flow_Status), redisClient=self.redisMessaging)
-
-        self.logTool.log(service='HSS', level='debug', message="Defining QoS information", redisClient=self.redisMessaging)
-        #QCI
-        QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4))
-
-        #ARP
-        self.logTool.log(service='HSS', level='debug', message="Defining ARP information", redisClient=self.redisMessaging)
-        AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4))
-        AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_capability']), 4))
-        AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_vulnerability']), 4))
-        ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
-
-        self.logTool.log(service='HSS', level='debug', message="Defining MBR information", redisClient=self.redisMessaging)
-        #Max Requested Bandwidth
-        Bandwidth_info = ''
-        Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4))
-        Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4))
-
-        self.logTool.log(service='HSS', level='debug', message="Defining GBR information", redisClient=self.redisMessaging)
-        #GBR
-        if int(ChargingRules['gbr_ul']) != 0:
-            Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4))
-        if int(ChargingRules['gbr_dl']) != 0:
-            Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4))
-        self.logTool.log(service='HSS', level='debug', message="Defined Bandwith Info: " + str(Bandwidth_info), redisClient=self.redisMessaging)
-
-        #Populate QoS Information
-        QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info)
-        self.logTool.log(service='HSS', level='debug', message="Defined QoS_Information: " + str(QoS_Information), redisClient=self.redisMessaging)
+    def Charging_Rule_Generator(self, ChargingRules=None, ue_ip=None, chargingRuleName=None, action="install"):
+        self.logTool.log(service='HSS', level='debug', message=f"Called Charging_Rule_Generator with action: {action}", redisClient=self.redisMessaging)
+        if action not in ['install', 'remove']:
+            self.logTool.log(service='HSS', level='debug', message="Invalid action supplied to Charging_Rule_Generator", redisClient=self.redisMessaging)
+            return None

-        #Precedence
-        self.logTool.log(service='HSS', level='debug', message="Defining Precedence information", redisClient=self.redisMessaging)
-        Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4))
-        self.logTool.log(service='HSS', level='debug', message="Defined Precedence " + str(Precedence), redisClient=self.redisMessaging)
-
-        #Rating Group
-        self.logTool.log(service='HSS', level='debug', message="Defining Rating Group information", redisClient=self.redisMessaging)
-        if ChargingRules['rating_group'] != None:
-            RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID
-        else:
-            RatingGroup = ''
-        self.logTool.log(service='HSS', level='debug', message="Defined Rating Group " + str(ChargingRules['rating_group']), redisClient=self.redisMessaging)
+        if action == 'remove':
+            if chargingRuleName is None:
+                self.logTool.log(service='HSS', level='error', message="chargingRuleName must be defined when removing a charging rule", redisClient=self.redisMessaging)
+                return None
+            Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(chargingRuleName))),'ascii'))
+            ChargingRuleDef = Charging_Rule_Name
+            return self.generate_vendor_avp(1002, "c0", 10415, ChargingRuleDef)
+        else:
+            if ChargingRules is None or ue_ip is None:
+                self.logTool.log(service='HSS', level='error', message="ChargingRules and ue_ip must be defined when installing a charging rule", redisClient=self.redisMessaging)
+                return None
+
+            #Install Charging Rules
+            self.logTool.log(service='HSS', level='debug', message="Naming Charging Rule", redisClient=self.redisMessaging)
+            Charging_Rule_Name = self.generate_vendor_avp(1005, "c0", 10415, str(binascii.hexlify(str.encode(str(ChargingRules['rule_name']))),'ascii'))
+            self.logTool.log(service='HSS', level='debug', message="Named Charging Rule", redisClient=self.redisMessaging)
+
+            #Populate all Flow Information AVPs
+            Flow_Information = ''
+            for tft in ChargingRules['tft']:
+                self.logTool.log(service='HSS', level='debug', message=tft, redisClient=self.redisMessaging)
+                #If {{ UE_IP }} in TFT splice in the real UE IP Value
+                try:
+                    tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip))
+                    tft['tft_string'] = tft['tft_string'].replace('{{UE_IP}}', str(ue_ip))
+                    self.logTool.log(service='HSS', level='debug', message="Spliced in UE IP into TFT: " + str(tft['tft_string']), redisClient=self.redisMessaging)
+                except Exception as E:
+                    self.logTool.log(service='HSS', level='error', message="Failed to splice in UE IP into flow description", redisClient=self.redisMessaging)
+
+                #Valid Values for Flow_Direction: 0 - Unspecified, 1 - Downlink, 2 - Uplink, 3 - Bidirectional
+                Flow_Direction = self.generate_vendor_avp(1080, "80", 10415, self.int_to_hex(tft['direction'], 4))
+                Flow_Description = self.generate_vendor_avp(507, "c0", 10415, str(binascii.hexlify(str.encode(tft['tft_string'])),'ascii'))
+                Flow_Information += self.generate_vendor_avp(1058, "80", 10415, Flow_Direction + Flow_Description)
+
+            Flow_Status = self.generate_vendor_avp(511, "c0", 10415, self.int_to_hex(2, 4))
+            self.logTool.log(service='HSS', level='debug', message="Defined Flow_Status: " + str(Flow_Status), redisClient=self.redisMessaging)
+
+            self.logTool.log(service='HSS', level='debug', message="Defining QoS information", redisClient=self.redisMessaging)
+            #QCI
+            QCI = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(ChargingRules['qci'], 4))
+
+            #ARP
+            self.logTool.log(service='HSS', level='debug', message="Defining ARP information", redisClient=self.redisMessaging)
+            AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(ChargingRules['arp_priority']), 4))
+            AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_capability']), 4))
+            AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not ChargingRules['arp_preemption_vulnerability']), 4))
+            ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
+
+            self.logTool.log(service='HSS', level='debug', message="Defining MBR information", redisClient=self.redisMessaging)
+            #Max Requested Bandwidth
+            Bandwidth_info = ''
+            Bandwidth_info += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_ul']), 4))
+            Bandwidth_info += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(int(ChargingRules['mbr_dl']), 4))
+
+            self.logTool.log(service='HSS', level='debug', message="Defining GBR information", redisClient=self.redisMessaging)
+            #GBR
+            if int(ChargingRules['gbr_ul']) != 0:
+                Bandwidth_info += self.generate_vendor_avp(1026, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_ul']), 4))
+            if int(ChargingRules['gbr_dl']) != 0:
+                Bandwidth_info += self.generate_vendor_avp(1025, "c0", 10415, self.int_to_hex(int(ChargingRules['gbr_dl']), 4))
+            self.logTool.log(service='HSS', level='debug', message="Defined Bandwidth Info: " + str(Bandwidth_info), redisClient=self.redisMessaging)
+
+            #Populate QoS Information
+            QoS_Information = self.generate_vendor_avp(1016, "c0", 10415, QCI + ARP + Bandwidth_info)
+            self.logTool.log(service='HSS', level='debug', message="Defined QoS_Information: " + str(QoS_Information), redisClient=self.redisMessaging)
+
+            #Precedence
+            self.logTool.log(service='HSS', level='debug', message="Defining Precedence information", redisClient=self.redisMessaging)
+            Precedence = self.generate_vendor_avp(1010, "c0", 10415, self.int_to_hex(ChargingRules['precedence'], 4))
+            self.logTool.log(service='HSS', level='debug', message="Defined Precedence " + str(Precedence), redisClient=self.redisMessaging)
+
+            #Rating Group
+            self.logTool.log(service='HSS', level='debug', message="Defining Rating Group information", redisClient=self.redisMessaging)
+            if ChargingRules['rating_group'] != None:
+                RatingGroup = self.generate_avp(432, 40, format(int(ChargingRules['rating_group']),"x").zfill(8)) #Rating-Group-ID
+            else:
+                RatingGroup = ''
+            self.logTool.log(service='HSS', level='debug', message="Defined Rating Group " + str(ChargingRules['rating_group']), redisClient=self.redisMessaging)
+

-        #Complete Charging Rule Defintion
-        self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging)
-        ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup
-        ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef)
+            #Complete Charging Rule Definition
+            self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging)
+            ChargingRuleDef = Charging_Rule_Name + Flow_Information + Flow_Status + QoS_Information + Precedence + RatingGroup
+            ChargingRuleDef = self.generate_vendor_avp(1003, "c0", 10415, ChargingRuleDef)

-        #Charging Rule Install
-        self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging)
-        return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef)
+            #Charging Rule Install
+            self.logTool.log(service='HSS', level='debug', message="Collating ChargingRuleDef", redisClient=self.redisMessaging)
+            return self.generate_vendor_avp(1001, "c0", 10415, ChargingRuleDef)

     def Get_IMS_Subscriber_Details_from_AVP(self, username):
         #Feed the Username AVP with Tel URI, SIP URI and either MSISDN or IMSI and this returns user data
@@ -2261,8 +2278,8 @@ def Answer_16777236_265(self, packet_vars, avps):
         The response is determined by whether or not the subscriber is enabled, and has a matching ims_subscriber entry.
         """
         avp = ''
-        session_id = self.get_avp_data(avps, 263)[0]                                                     #Get Session-ID
-        avp += self.generate_avp(263, 40, session_id)                                                    #Set session ID to received session ID
+        sessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii')                       #Get Session-ID
+        avp += self.generate_avp(263, 40, self.string_to_hex(sessionId))                                 #Set session ID to received session ID
         avp += self.generate_avp(258, 40, format(int(16777236),"x").zfill(8))
         avp += self.generate_avp(264, 40, self.OriginHost)                                               #Origin Host
         avp += self.generate_avp(296, 40, self.OriginRealm)                                              #Origin Realm
@@ -2318,8 +2335,8 @@ def Answer_16777236_265(self, packet_vars, avps):
                     remotePeer = aarOriginHost
                 remotePeer = f"{remotePeer};{self.config['hss']['OriginHost']}"
-
-                self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer)
+
+                self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=None)
                 """
                 Check for AVP's 504 (AF-Application-Identifier) and 520 (Media-Type), which indicates the UE is making a call.
                 Media-Type: 0 = Audio, 4 = Control
                 """
@@ -2383,6 +2400,8 @@ def Answer_16777236_265(self, packet_vars, avps):
                             ]
                         }

+                        self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=sessionId)
+
                         reAuthAnswer = self.awaitDiameterRequestAndResponse(
                             requestType='RAR',
                             hostname=servingPgwPeer,
@@ -2499,19 +2518,69 @@ def Answer_16777236_258(self, packet_vars, avps):

     def Answer_16777236_275(self, packet_vars, avps):
         try:
             """
-            Generates a response to a provided STR.
-            Returns Result-Code 2001.
+            Triggers a Re-Auth-Request to the PGW, then returns a Session Termination Answer.
             """
             avp = ''
-            session_id = self.get_avp_data(avps, 263)[0]                                                     #Get Session-ID
-            avp += self.generate_avp(263, 40, session_id)                                                    #Set session ID to received session ID
+            sessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii')                       #Get Session-ID
+            avp += self.generate_avp(263, 40, self.string_to_hex(sessionId))                                 #Set session ID to received session ID
             avp += self.generate_avp(264, 40, self.OriginHost)                                               #Origin Host
             avp += self.generate_avp(296, 40, self.OriginRealm)                                              #Origin Realm
+            imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId)
+            imsi = imsSubscriber.get('imsi', None)
+            pcscf = imsSubscriber.get('pcscf', None)
+            pcscf_realm = imsSubscriber.get('pcscf_realm', None)
+            pcscf_peer = imsSubscriber.get('pcscf_peer', None)
+            subscriber = self.database.Get_Subscriber(imsi=imsi)
+            subscriberId = subscriber.get('subscriber_id', None)
+            apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
+            servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
+            self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None)
+
+            if servingApn is not None:
+                servingPgw = servingApn.get('serving_pgw', '')
+                servingPgwRealm = servingApn.get('serving_pgw_realm', '')
+                servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0]
+                pcrfSessionId = servingApn.get('pcrf_session_id', None)
+                reAuthAnswer = self.awaitDiameterRequestAndResponse(
+                    requestType='RAR',
+                    hostname=servingPgwPeer,
+                    sessionId=pcrfSessionId,
+                    servingPgw=servingPgw,
+                    servingRealm=servingPgwRealm,
+                    chargingRuleName='GBR-Voice',
+                    chargingRuleAction='remove'
+                )
+
+                if not len(reAuthAnswer) > 0:
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging)
+                    assert()
+
+                raaPacketVars, raaAvps = self.decode_diameter_packet(reAuthAnswer)
+                raaResultCode = int(self.get_avp_data(raaAvps, 268)[0], 16)
+
+                if raaResultCode == 2001:
+                    avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] RAA returned Successfully, authorizing request", redisClient=self.redisMessaging)
+                else:
+                    avp += self.generate_avp(268, 40, self.int_to_hex(5001, 4))
+                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] RAA returned Unauthorized, returning Result-Code 5001", redisClient=self.redisMessaging)
+
+            else:
+                self.logTool.log(service='HSS', level='info', message=f"[diameter.py] [Answer_16777236_275] [STA] Unable to find serving APN for RAR, returning Result-Code 2001", redisClient=self.redisMessaging)
+                avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))

             response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)            #Generate Diameter packet
             return response
         except Exception as e:
-            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA, returning 2001: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            avp = ''
+            sessionId = self.get_avp_data(avps, 263)[0]                                                      #Get Session-ID
+            avp += self.generate_avp(263, 40, sessionId)                                                     #Set session ID to received session ID
+            avp += self.generate_avp(264, 40, self.OriginHost)                                               #Origin Host
+            avp += self.generate_avp(296, 40, self.OriginRealm)                                              #Origin Realm
+            avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))
+            response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)            #Generate Diameter packet
+            return response

     #3GPP Rx - Abort Session Answer (ASA)
     def Answer_16777236_274(self, packet_vars, avps):
@@ -3326,18 +3395,21 @@ def Request_16777238_272(self, imsi, apn, ccr_type, destinationHost, destination
         return response

     #3GPP Gx - Re Auth Request
-    def Request_16777238_258(self, sessionId, chargingRules, ueIp, servingPgw, servingRealm):
+    def Request_16777238_258(self, sessionId, servingPgw, servingRealm, chargingRules=None, ueIp=None, chargingRuleAction='install', chargingRuleName=None):
         avp = ''
         self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Creating Re Auth Request", redisClient=self.redisMessaging)
-        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Charging Rules: {chargingRules}", redisClient=self.redisMessaging)
-
         avp += self.generate_avp(263, 40, str(binascii.hexlify(str.encode(sessionId)),'ascii'))            #Session-Id set AVP

         #Setup Charging Rule
         self.logTool.log(service='HSS', level='debug', message=chargingRules, redisClient=self.redisMessaging)
-        avp += self.Charging_Rule_Generator(ChargingRules=chargingRules, ue_ip=ueIp)
-        self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Generated Charging Rules", redisClient=self.redisMessaging)
+        if chargingRules is not None and ueIp is not None:
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Charging Rules: {chargingRules}", redisClient=self.redisMessaging)
+            avp += self.Charging_Rule_Generator(ChargingRules=chargingRules, ue_ip=ueIp)
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Generated Charging Rules", redisClient=self.redisMessaging)
+        elif chargingRuleName is not None and chargingRuleAction == 'remove':
+            avp += self.Charging_Rule_Generator(action=chargingRuleAction, chargingRuleName=chargingRuleName)
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Request_16777238_258] [RAR] Removing Charging Rule: {chargingRuleName}", redisClient=self.redisMessaging)

         avp += self.generate_avp(264, 40, self.OriginHost)                                                 #Origin Host
         avp += self.generate_avp(296, 40, self.OriginRealm)                                                #Origin Realm
diff --git a/services/apiService.py b/services/apiService.py
index 3a92ce8..db59956 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -248,7 +248,6 @@ def page_not_found(e):
 @apiService.after_request
 def apply_caching(response):
     response.headers["HSS"] = str(config['hss']['OriginHost'])
-    response.headers["Access-Control-Allow-Origin"] = "*"
     return response

 @ns_apn.route('/')
diff --git a/tools/databaseUpgrade/alembic/env.py b/tools/databaseUpgrade/alembic/env.py
index b0dfb0f..4cf83b4 100644
--- a/tools/databaseUpgrade/alembic/env.py
+++ b/tools/databaseUpgrade/alembic/env.py
@@ -31,7 +31,7 @@ def get_url_from_config() -> str:
     """
     Reads config.yaml and returns the database url.
     """
-    with open("config.yaml", 'r') as stream:
+    with open("../../config.yaml", 'r') as stream:
         try:
             config = yaml.safe_load(stream)
             db_string = 'mysql://' + str(config['database']['username']) + ':' + str(config['database']['password']) + '@' + str(config['database']['server']) + '/' + str(config['database']['database'])
diff --git a/tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py b/tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py
deleted file mode 100644
index d92fa9f..0000000
--- a/tools/databaseUpgrade/alembic/versions/2ad87e0c0c76_service_overhaul_revision.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Service Overhaul revision
-
-Revision ID: 2ad87e0c0c76
-Revises:
-Create Date: 2023-09-29 04:28:33.635508
-
-"""
-from alembic import op
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision = '2ad87e0c0c76'
-down_revision = None
-branch_labels = None
-depends_on = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    op.add_column('ims_subscriber', sa.Column('pcscf', sa.String(length=512), nullable=True))
-    op.add_column('ims_subscriber', sa.Column('pcscf_realm', sa.String(length=512), nullable=True))
-    op.add_column('ims_subscriber', sa.Column('pcscf_timestamp', sa.DateTime(), nullable=True))
-    op.add_column('ims_subscriber', sa.Column('pcscf_peer', sa.String(length=512), nullable=True))
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    op.drop_column('ims_subscriber', 'pcscf_peer')
-    op.drop_column('ims_subscriber', 'pcscf_timestamp')
-    op.drop_column('ims_subscriber', 'pcscf_realm')
-    op.drop_column('ims_subscriber', 'pcscf')
-    # ### end Alembic commands ###

From 7b355b51fce1411d8b6d690cd286a338d1e093c4 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Tue, 3 Oct 2023 21:03:50 +1000
Subject: [PATCH 36/43] Update geored schema

---
 services/apiService.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/services/apiService.py b/services/apiService.py
index db59956..b6d2152 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -150,6 +150,11 @@
     'serving_mme_timestamp' : fields.String(description=SUBSCRIBER.serving_mme_timestamp.doc),
     'serving_apn' : fields.String(description='Access Point Name of APN'),
     'pcrf_session_id' : fields.String(description=Serving_APN.pcrf_session_id.doc),
+    'pcscf' : fields.String(description=IMS_SUBSCRIBER.pcscf.doc),
+    'pcscf_realm' : fields.String(description=IMS_SUBSCRIBER.pcscf_realm.doc),
+    'pcscf_peer' : fields.String(description=IMS_SUBSCRIBER.pcscf_peer.doc),
+    'pcscf_timestamp' : fields.String(description=IMS_SUBSCRIBER.pcscf_timestamp.doc),
+    'pcscf_active_session' : fields.String(description=IMS_SUBSCRIBER.pcscf_active_session.doc),
     'subscriber_routing' : fields.String(description=Serving_APN.subscriber_routing.doc),
     'serving_pgw' : fields.String(description=Serving_APN.serving_pgw.doc),
     'serving_pgw_realm' : fields.String(description=Serving_APN.serving_pgw_realm.doc),
@@ -1552,7 +1557,9 @@ def patch(self):
                     json_data['pcscf_peer'] = None
                 if 'pcscf_timestamp' not in json_data:
                     json_data['pcscf_timestamp'] = None
-                response_data.append(databaseClient.Update_Proxy_CSCF(imsi=str(json_data['imsi']), proxy_cscf=json_data['pcscf'], pcscf_realm=str(json_data['pcscf_realm']), pcscf_peer=str(json_data['pcscf_peer']), pcscf_timestamp=json_data['pcscf_timestamp'], propagate=False))
+                if 'pcscf_active_session' not in json_data:
+                    json_data['pcscf_active_session'] = None
+                response_data.append(databaseClient.Update_Proxy_CSCF(imsi=str(json_data['imsi']), proxy_cscf=json_data['pcscf'], pcscf_realm=str(json_data['pcscf_realm']), pcscf_peer=str(json_data['pcscf_peer']), pcscf_timestamp=json_data['pcscf_timestamp'], pcscf_active_session=str(json_data['pcscf_active_session']), propagate=False))

                 redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints',
                                     metricType='counter', metricAction='inc',
                                     metricValue=1.0, metricHelp='Number of Geored Pushes Received',

From 612d9fac73d14c7ecc9f5511f20d673bf2d95199 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Wed, 4 Oct 2023 16:48:40 +1000
Subject: [PATCH 37/43] All CPU overutilization fixed, add benchmarking to
 config.yaml, minor fixes

---
 CHANGELOG.md                |   3 +
 config.yaml                 |   9 ++-
 lib/database.py             |  17 ++---
 lib/diameter.py             |  16 ++---
 lib/messaging.py            |  19 +++++-
 lib/messagingAsync.py       |  18 +++++-
 services/apiService.py      |  29 ++++++---
 services/diameterService.py | 120 ++++++++++++++++--------------------
 services/georedService.py   |  58 +++++------------
 services/hssService.py      |  26 +++-----
 services/logService.py      |  21 ++-----
 services/metricService.py   |   7 +--
 12 files changed, 163 insertions(+), 180 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 792f08e..aadc806 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
  - Configurable redis connection (Unix socket or TCP)
  - Basic database upgrade support in tools/databaseUpgrade
  - PCSCF state storage in ims_subscriber
+ - (Experimental) Working horizontal scalability

 ### Changed

@@ -30,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Logtool no longer handles metric processing
 - Updated config.yaml
 - Gx CCR-T now flushes PGW / IMS data, depending on Called-Station-Id
+- Benchmarked lossless at ~100 diameter requests per second, per hssService.

 ### Fixed

@@ -37,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Gx CCA now supports apn inside a plmn based uri
 - AVP_Preemption_Capability and AVP_Preemption_Vulnerability now presents correctly in all diameter messages
 - Crash when webhook or geored endpoints enabled and no peers defined
+ - CPU overutilization on all services

 ### Removed

diff --git a/config.yaml b/config.yaml
index 5f01224..a12d857 100644
--- a/config.yaml
+++ b/config.yaml
@@ -33,9 +33,6 @@ hss:
   #The maximum time to wait, in seconds, before disconnecting a client when no data is received.
   client_socket_timeout: 120

-  #Enable benchmarking log output for response times - set to False in production.
-  enable_benchmarking: False
-
   #The maximum time to wait, in seconds, before disconnecting a client when no data is received.
   client_socket_timeout: 300

@@ -71,6 +68,12 @@ hss:
 api:
   page_size: 200

+benchmarking:
+  # Whether to enable benchmark logging
+  enabled: True
+  # How often to report, in seconds. Not all benchmarking supports interval reporting.
+  reporting_interval: 3600
+
 eir:
   imsi_imei_logging: True    #Store current IMEI / IMSI pair in backend
   no_match_response: 2       #Greylist

diff --git a/lib/database.py b/lib/database.py
index 41f4823..3fa3569 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -864,13 +864,15 @@ def handleGeored(self, jsonData, operation: str="PATCH", asymmetric: bool=False,
             if len(self.config.get('geored', {}).get('endpoints', [])) > 0:
                 georedDict['body'] = jsonData
                 georedDict['operation'] = operation
-                self.redisMessaging.sendMessage(queue=f'geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120)
+                georedDict['timestamp'] = time.time_ns()
+                self.redisMessaging.sendMessage(queue=f'geored', message=json.dumps(georedDict), queueExpiry=120)
             if asymmetric:
                 if len(asymmetricUrls) > 0:
                     georedDict['body'] = jsonData
                     georedDict['operation'] = operation
+                    georedDict['timestamp'] = time.time_ns()
                     georedDict['urls'] = asymmetricUrls
-                    self.redisMessaging.sendMessage(queue=f'asymmetric-geored-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(georedDict), queueExpiry=120)
+                    self.redisMessaging.sendMessage(queue=f'asymmetric-geored', message=json.dumps(georedDict), queueExpiry=120)
             return True

         except Exception as E:
@@ -897,7 +899,8 @@ def handleWebhook(self, objectData, operation: str="PATCH"):
         webhook['body'] = self.Sanitize_Datetime(objectData)
         webhook['headers'] = webhookHeaders
         webhook['operation'] = operation
-        self.redisMessaging.sendMessage(queue=f'webhook-{uuid.uuid4()}-{time.time_ns()}', message=json.dumps(webhook), queueExpiry=120)
+        webhook['timestamp'] = time.time_ns()
+        self.redisMessaging.sendMessage(queue=f'webhook', message=json.dumps(webhook), queueExpiry=120)
         return True

     def Sanitize_Datetime(self, result):
@@ -1580,7 +1583,7 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_
                 self.logTool.log(service='Database', level='debug', message="Updating serving MME & Timestamp", redisClient=self.redisMessaging)
                 result.serving_mme = serving_mme
                try:
-                    if serving_mme_timestamp is not None and serving_mme_timestamp is not 'None':
+                    if serving_mme_timestamp != None and serving_mme_timestamp != 'None':
                         result.serving_mme_timestamp = datetime.strptime(serving_mme_timestamp, '%Y-%m-%dT%H:%M:%SZ')
                         result.serving_mme_timestamp = result.serving_mme_timestamp.replace(tzinfo=timezone.utc)
                         serving_mme_timestamp_string = result.serving_mme_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -1643,7 +1646,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None,
                 result.pcscf = proxy_cscf
                 result.pcscf_active_session = pcscf_active_session
                 try:
-                    if pcscf_timestamp is not None and pcscf_timestamp is not 'None':
+                    if pcscf_timestamp != None and pcscf_timestamp != 'None':
                         result.pcscf_timestamp = datetime.strptime(pcscf_timestamp, '%Y-%m-%dT%H:%M:%SZ')
                         result.pcscf_timestamp = result.pcscf_timestamp.replace(tzinfo=timezone.utc)
                         pcscf_timestamp_string = result.pcscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -1698,7 +1701,7 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N
                 serving_cscf = serving_cscf.replace("sip:sip:", "sip:")
                 result.scscf = serving_cscf
                 try:
-                    if scscf_timestamp is not None and scscf_timestamp is not 'None':
+                    if scscf_timestamp != None and scscf_timestamp != 'None':
                         result.scscf_timestamp = datetime.strptime(scscf_timestamp, '%Y-%m-%dT%H:%M:%SZ')
                         result.scscf_timestamp = result.scscf_timestamp.replace(tzinfo=timezone.utc)
                         scscf_timestamp_string = result.scscf_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -1770,7 +1773,7 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber
         self.logTool.log(service='Database', level='debug', message="APN ID is " + str(apn_id), redisClient=self.redisMessaging)

         try:
-            if serving_pgw_timestamp is not None and serving_pgw_timestamp is not 'None':
+            if serving_pgw_timestamp != None and serving_pgw_timestamp != 'None':
                 serving_pgw_timestamp = datetime.strptime(serving_pgw_timestamp, '%Y-%m-%dT%H:%M:%SZ')
                 serving_pgw_timestamp = serving_pgw_timestamp.replace(tzinfo=timezone.utc)
                 serving_pgw_timestamp_string = serving_pgw_timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
diff --git a/lib/diameter.py b/lib/diameter.py
index c592492..603ba4a 100644
--- a/lib/diameter.py
+++ b/lib/diameter.py
@@ -607,8 +607,9 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str:
                 peerPort = connectedPeer['port']
                 request = diameterApplication["requestMethod"](**kwargs)
                 self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging)
-                outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{time.time_ns()}"
-                outboundMessage = json.dumps({'diameter-outbound': request})
+                outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
+                sendTime = time.time_ns()
+                outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
                 self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
                 self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
                 return request
@@ -636,8 +637,9 @@ def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) ->
                 peerPort = connectedPeer['port']
                 request = diameterApplication["requestMethod"](**kwargs)
                 self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging)
-                outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{time.time_ns()}"
-                outboundMessage = json.dumps({'diameter-outbound': request})
+                outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
+                sendTime = time.time_ns()
+                outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
                 self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
                 self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Queueing for peer type: {peerType} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
             return connectedPeerList
@@ -677,8 +679,8 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo
             sessionId = kwargs.get('sessionId', None)
             self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging)
             sendTime = time.time_ns()
-            outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}-{sendTime}"
-            outboundMessage = json.dumps({'diameter-outbound': request})
+            outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
+            outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
             self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
             self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
             startTimer = time.time()
@@ -2572,7 +2574,7 @@ def Answer_16777236_275(self, packet_vars, avps):
             response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp)            #Generate Diameter packet
             return response
         except Exception as e:
-            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA, returning 2001: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA, returning 2001: {traceback.format_exc()}", redisClient=self.redisMessaging)
             avp = ''
             sessionId = self.get_avp_data(avps, 263)[0]                                                     #Get Session-ID
             avp += self.generate_avp(263, 40, sessionId)                                                    #Set session ID to received session ID
diff --git a/lib/messaging.py b/lib/messaging.py
index 62b4a78..cc78835 100644
--- a/lib/messaging.py
+++ b/lib/messaging.py
@@ -33,6 +33,8 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA
             return 'Invalid Argument: metricValue must be a digit'
         metricValue = float(metricValue)
         prometheusMetricBody = json.dumps([{
+        'serviceName': serviceName,
+        'timestamp': metricTimestamp,
         'NAME': metricName,
         'TYPE': metricType,
         'HELP': metricHelp,
@@ -42,7 +44,7 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA
         }
         ])

-        metricQueueName = f"metric-{serviceName}-{metricTimestamp}-{uuid.uuid4()}"
+        metricQueueName = f"metric"

         try:
             self.redisClient.rpush(metricQueueName, prometheusMetricBody)
@@ -57,8 +59,8 @@ def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, mes
         Stores a message in a given Queue (Key).
         """
         try:
-            logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}-{uuid.uuid4()}"
-            logMessage = json.dumps({"message": message})
+            logQueueName = f"log"
+            logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp})
             self.redisClient.rpush(logQueueName, logMessage)
             if logExpiry is not None:
                 self.redisClient.expire(logQueueName, logExpiry)
@@ -103,6 +105,17 @@ def getNextQueue(self, pattern: str='*') -> dict:
         except Exception as e:
             return {}

+    def awaitMessage(self, key: str):
+        """
+        Blocks until a message is received at the given key, then returns the message.
+        """
+        try:
+            message = self.redisClient.blpop(key)
+            return tuple(data.decode() for data in message)
+        except Exception as e:
+            return ''
+
+
     def deleteQueue(self, queue: str) -> bool:
         """
         Deletes the given Queue (Key)
diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py
index baed706..78f2ca7 100644
--- a/lib/messagingAsync.py
+++ b/lib/messagingAsync.py
@@ -37,6 +37,8 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m
             return 'Invalid Argument: metricValue must be a digit'
         metricValue = float(metricValue)
         prometheusMetricBody = json.dumps([{
+        'serviceName': serviceName,
+        'timestamp': metricTimestamp,
         'NAME': metricName,
         'TYPE': metricType,
         'HELP': metricHelp,
@@ -46,7 +48,7 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m
         }
         ])

-        metricQueueName = f"metric-{serviceName}-{metricTimestamp}-{uuid.uuid4()}"
+        metricQueueName = f"metric"

         try:
             async with self.redisClient.pipeline(transaction=True) as redisPipe:
@@ -63,8 +65,8 @@ async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: in
         Stores a log message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided.
         """
         try:
-            logQueueName = f"log-{serviceName}-{logLevel}-{logTimestamp}-{uuid.uuid4()}"
-            logMessage = json.dumps({"message": message})
+            logQueueName = f"log"
+            logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp})
             async with self.redisClient.pipeline(transaction=True) as redisPipe:
                 await redisPipe.rpush(logQueueName, logMessage)
                 if logExpiry is not None:
@@ -117,6 +119,16 @@ async def getNextQueue(self, pattern: str='*') -> str:
             print(e)
         return ''

+    async def awaitMessage(self, key: str):
+        """
+        Asynchronously blocks until a message is received at the given key, then returns the message.
+        """
+        try:
+            message = (await(self.redisClient.blpop(key)))
+            return tuple(data.decode() for data in message)
+        except Exception as e:
+            return ''
+
     async def deleteQueue(self, queue: str) -> bool:
         """
         Deletes the given Queue (Key) asynchronously.
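The awaitMessage() helpers above replace the getNextQueue()/getMessage() polling pattern with a blocking Redis BLPOP, which parks the consumer until a message arrives instead of waking on a sleep timer to scan for keys. A minimal sketch of the producer/consumer pattern, assuming redis-py >= 4.2 and a reachable local Redis instance; the key name and payload below are illustrative only, not PyHSS's actual queue contents:

    import asyncio
    import json

    from redis.asyncio import Redis

    QUEUE_KEY = "diameter-outbound-10.0.0.1-3868"  # hypothetical per-peer key

    async def producer(redisClient: Redis):
        # Mirrors sendMessage(): RPUSH a JSON body onto a fixed, per-peer key.
        outboundMessage = json.dumps({"diameter-outbound": "0101009c", "inbound-received-timestamp": 0})
        await redisClient.rpush(QUEUE_KEY, outboundMessage)

    async def consumer(redisClient: Redis):
        # BLPOP blocks until an item lands on the key, so no poll loop burns CPU.
        # It returns a (key, value) tuple of bytes, which is why awaitMessage()
        # decodes each element before handing it back to the caller.
        key, value = await redisClient.blpop(QUEUE_KEY)
        return json.loads(value)

    async def main():
        redisClient = Redis()
        await producer(redisClient)
        print(await consumer(redisClient))

    asyncio.run(main())

Because each writer coroutine blocks on its own diameter-outbound-{ip}-{port} key, a message can only be picked up by the peer connection it was addressed to, which is what lets the writeOutboundData() rewrite later in this patch drop its queue-name matching logic.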
diff --git a/services/apiService.py b/services/apiService.py
index b6d2152..2c195d2 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -1444,22 +1444,33 @@ def put(self):
             print("subscriber_data: " + str(subscriber_data))

             #Get PCRF Session
-            pcrf_session_data = databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=json_data['apn_id'])
-            print("pcrf_session_data: " + str(pcrf_session_data))
+            servingApn = databaseClient.Get_Serving_APN(subscriber_id=subscriber_data['subscriber_id'], apn_id=json_data['apn_id'])
+            print("pcrf_session_data: " + str(servingApn))

             #Get Charging Rules
             ChargingRule = databaseClient.Get_Charging_Rule(json_data['charging_rule_id'])
             ChargingRule['apn_data'] = databaseClient.Get_APN(json_data['apn_id'])
             print("Got ChargingRule: " + str(ChargingRule))

-            diameterRequest = diameterClient.Request_16777238_258(pcrf_session_data['pcrf_session_id'], ChargingRule, pcrf_session_data['subscriber_routing'], pcrf_session_data['serving_pgw'], 'ServingRealm.com')
-            connectedPgws = diameterClient.getConnectedPeersByType('pgw')
-            for connectedPgw in connectedPgws:
-                outboundQueue = f"diameter-outbound-{connectedPgw.get('ipAddress')}-{connectedPgw.get('port')}-{time.time_ns()}"
-                outboundMessage = json.dumps({"diameter-outbound": diameterRequest})
-                redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60)
+            subscriberId = subscriber_data.get('subscriber_id', None)
+            apnId = (databaseClient.Get_APN_by_Name(apn="ims")).get('apn_id', None)
+            servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0]
+            servingPgw = servingApn.get('serving_pgw', None)
+            servingPgwRealm = servingApn.get('serving_pgw_realm', None)
+            pcrfSessionId = servingApn.get('pcrf_session_id', None)
+            ueIp = servingApn.get('subscriber_routing', None)
+
+            diameterResponse = diameterClient.sendDiameterRequest(
+                requestType='RAR',
+                hostname=servingPgwPeer,
+                sessionId=pcrfSessionId,
+                chargingRules=ChargingRule,
+                ueIp=ueIp,
+                servingPgw=servingPgw,
+                servingRealm=servingPgwRealm
+            )

-            result = {"request": diameterRequest, "destinationClients": connectedPgws}
+            result = {"Result": "Successfully sent Gx RAR", "destinationClients": str(servingPgw)}
             return result, 200

 @ns_pcrf.route('/')
diff --git a/services/diameterService.py b/services/diameterService.py
index 0969170..dc2e9f8 100644
--- a/services/diameterService.py
+++ b/services/diameterService.py
@@ -36,7 +36,10 @@ def __init__(self):
         self.diameterLibrary = DiameterAsync(logTool=self.logTool)
         self.activePeers = {}
         self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10))
-        self.benchmarking = self.config.get('hss').get('enable_benchmarking', False)
+        self.benchmarking = self.config.get('benchmarking', {}).get('enabled', False)
+        self.benchmarkingInterval = self.config.get('benchmarking', {}).get('reporting_interval', 3600)
+        self.diameterRequests = 0
+        self.diameterResponses = 0

     async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool:
         """
@@ -44,28 +47,17 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb
         """
         try:
             packetVars, avps = await(self.diameterLibrary.decodeDiameterPacket(inboundData))
-            messageType = await(self.diameterLibrary.getDiameterMessageType(inboundData))
             originHost = (await(self.diameterLibrary.getAvpData(avps, 264)))[0]
             originHost = bytes.fromhex(originHost).decode("utf-8")
             peerType = await(self.diameterLibrary.getPeerType(originHost))
-            self.activePeers[f"{clientAddress}-{clientPort}"].update({'lastDwrTimestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S") if messageType['inbound'] == 'DWR' else self.activePeers[f"{clientAddress}-{clientPort}"]['lastDwrTimestamp'],
-                                                                    'diameterHostname': originHost,
-                                                                    'peerType': peerType,
-                                                                    })
-            await(self.redisReaderMessaging.sendMetric(serviceName='diameter', metricName='prom_diam_inbound_count',
-                                    metricType='counter', metricAction='inc',
-                                    metricValue=1.0, metricHelp='Number of Diameter Inbounds',
-                                    metricLabels={
-                                        "diameter_application_id": str(packetVars["ApplicationId"]),
-                                        "diameter_cmd_code": str(packetVars["command_code"]),
-                                        "endpoint": originHost,
-                                        "type": "inbound"},
-                                    metricExpiry=60))
+            self.activePeers[f"{clientAddress}-{clientPort}"].update({'diameterHostname': originHost,
+                                                                    'peerType': peerType,
+                                                                    })
+            return True
         except Exception as e:
             await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] Exception: {e}\n{traceback.format_exc()}"))
             await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [validateDiameterInbound] AVPs: {avps}\nPacketVars: {packetVars}"))
             return False
-        return True

     async def handleActiveDiameterPeers(self):
         """
@@ -110,20 +102,34 @@ async def logActivePeers(self):
             activePeers = ''
         await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logActivePeers] {len(self.activePeers)} Active Peers {activePeers}"))

+    async def logProcessedMessages(self):
+        """
+        Logs the number of processed messages on a rolling basis.
+        """
+        if not self.benchmarking:
+            return False
+
+        benchmarkInterval = int(self.benchmarkingInterval)
+
+        while True:
+            await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logProcessedMessages] Processed {self.diameterRequests} inbound diameter messages in the last {self.benchmarkingInterval} second(s)"))
+            await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [logProcessedMessages] Processed {self.diameterResponses} outbound in the last {self.benchmarkingInterval} second(s)"))
+            self.diameterRequests = 0
+            self.diameterResponses = 0
+            await(asyncio.sleep(benchmarkInterval))
+
     async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool:
         """
         Reads and parses incoming data from a connected client. Validated diameter messages are sent to the redis queue for processing.
         Terminates the connection if diameter traffic is not received, or if the client disconnects.
         """
         await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}"))
+        peerIsValidated = False
         while True:
             try:
                 inboundData = await(asyncio.wait_for(reader.read(8192), timeout=socketTimeout))

-                if self.benchmarking:
-                    startTime = time.perf_counter()
-
                 if reader.at_eof():
                     await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection."))
                     return False
@@ -131,26 +137,22 @@ async def readInboundData(self, reader, clientAddress: str, clientPort: str, soc
                 if len(inboundData) > 0:
                     await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}"))

-                    if self.benchmarking:
-                        diamteterValidationStartTime = time.perf_counter()
-                    if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)):
-                        await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data."))
-                        await(asyncio.sleep(0.001))
-                        continue
-                    if self.benchmarking:
-                        await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to validate diameter request: {round(((time.perf_counter() - diamteterValidationStartTime)*1000), 3)} ms"))
-
-
-                    diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=inboundData))
-                    diameterMessageType = diameterMessageType.get('inbound', '')
+                    if not peerIsValidated:
+                        if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)):
+                            await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data."))
+                            await(asyncio.sleep(0))
+                            continue
+                        else:
+                            await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Validated peer: {clientAddress} on port {clientPort}"))
+                            peerIsValidated = True

-                    inboundQueueName = f"diameter-inbound-{clientAddress}-{clientPort}-{diameterMessageType}-{time.time_ns()}"
-                    inboundHexString = json.dumps({f"diameter-inbound": inboundData.hex()})
-                    await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] [{diameterMessageType}] Queueing {inboundHexString}"))
+                    inboundQueueName = f"diameter-inbound"
+                    inboundHexString = json.dumps({"diameter-inbound": inboundData.hex(), "inbound-received-timestamp": time.time_ns(), "clientAddress": clientAddress, "clientPort": clientPort})
+                    await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Queueing {inboundHexString}"))
                     await(self.redisReaderMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout))
                     if self.benchmarking:
-                        await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))
-                    await(asyncio.sleep(0.001))
+                        self.diameterRequests += 1
+                    await(asyncio.sleep(0))

             except Exception as e:
                await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}"))
@@ -158,44 +158,24 @@
                 return False

     async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool:
         """
-        Continually polls the Redis queue for outbound messages. Received messages from the queue are validated against the connected client, and sent.
+        Waits for a message to be received from Redis, then sends to the connected client.
         """
         await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] writeOutboundData with host {clientAddress} on port {clientPort}"))
-        while True:
+        while not writer.transport.is_closing():
             try:
+                await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Waiting for messages for host {clientAddress} on port {clientPort}"))
+                pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}")))[1])
+                await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Received message: {pendingOutboundMessage} for host {clientAddress} on port {clientPort}"))
+                diameterOutboundBinary = bytes.fromhex(pendingOutboundMessage.get('diameter-outbound', ''))
+                await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Sending: {diameterOutboundBinary.hex()} to {clientAddress} on {clientPort}."))
+                writer.write(diameterOutboundBinary)
+                await(writer.drain())
                 if self.benchmarking:
-                    startTime = time.perf_counter()
-
-                if writer.transport.is_closing():
-                    return False
-
-                pendingOutboundQueue = await(self.redisWriterMessaging.getNextQueue(pattern=f'diameter-outbound-{clientAddress.replace(".", "*")}-{clientPort}-*'))
-                if not len(pendingOutboundQueue) > 0:
-                    await(asyncio.sleep(0.01))
-                    continue
-                pendingOutboundQueue = pendingOutboundQueue
-                outboundQueueSplit = str(pendingOutboundQueue).split('-')
-                queuedMessageType = outboundQueueSplit[1]
-                diameterOutboundHost = outboundQueueSplit[2]
-                diameterOutboundPort = outboundQueueSplit[3]
-
-                if str(diameterOutboundHost) == str(clientAddress) and str(diameterOutboundPort) == str(clientPort) and queuedMessageType == 'outbound':
-                    await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Matched {pendingOutboundQueue} to host {clientAddress} on port {clientPort}"))
-                    diameterOutbound = json.loads(await(self.redisWriterMessaging.getMessage(queue=pendingOutboundQueue)))
-                    diameterOutboundBinary = bytes.fromhex(next(iter(diameterOutbound.values())))
-                    diameterMessageType = await(self.diameterLibrary.getDiameterMessageType(binaryData=diameterOutboundBinary))
-                    diameterMessageType = diameterMessageType.get('outbound', '')
-                    await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] [{diameterMessageType}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}."))
-                    writer.write(diameterOutboundBinary)
-                    await(writer.drain())
-                await(asyncio.sleep(0.001))
-                if self.benchmarking:
-                    await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Time taken to write response: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))
-
+                    self.diameterResponses += 1
+                await(asyncio.sleep(0))
             except Exception as e:
                 await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer."))
                 return False
-            await(asyncio.sleep(0.001))

     async def handleConnection(self, reader, writer):
         """
@@ -262,7 +244,7 @@ async def handleConnection(self, reader, writer):
     async def startServer(self, host: str=None, port: int=None, type: str=None):
         """
         Start a server with the given parameters and handle new clients with self.handleConnection.
-        Also create a single instance of self.handleActiveDiameterPeers.
+        Also create a single instance of self.handleActiveDiameterPeers and self.logProcessedMessages.
         """

         if host is None:
@@ -283,6 +265,8 @@ async def startServer(self, host: str=None, port: int=None, type: str=None):
         servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets)
         await(self.logTool.logAsync(service='Diameter', level='info', message=f"{self.banners.diameterService()}\n[Diameter] Serving on {servingAddresses}"))
         handleActiveDiameterPeerTask = asyncio.create_task(self.handleActiveDiameterPeers())
+        if self.benchmarking:
+            logProcessedMessagesTask = asyncio.create_task(self.logProcessedMessages())

         async with server:
             await(server.serve_forever())
diff --git a/services/georedService.py b/services/georedService.py
index 01336fe..861e8b8 100644
--- a/services/georedService.py
+++ b/services/georedService.py
@@ -256,21 +256,12 @@ async def handleAsymmetricGeoredQueue(self):
             try:
                 if self.benchmarking:
                     startTime = time.perf_counter()
-                asymmetricGeoredQueue = await(self.redisGeoredMessaging.getNextQueue(pattern='asymmetric-geored-*'))
-                if not len(asymmetricGeoredQueue) > 0:
-                    await(asyncio.sleep(0.01))
-                    continue
-                georedMessage = await(self.redisGeoredMessaging.getMessage(queue=asymmetricGeoredQueue))
-                if not len(georedMessage) > 0:
-                    await(asyncio.sleep(0.01))
-                    continue
-                await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Queue: {asymmetricGeoredQueue}"))
+                georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored')))[1])
                 await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}"))

-                georedDict = json.loads(georedMessage)
-                georedOperation = georedDict['operation']
-                georedBody = georedDict['body']
-                georedUrls = georedDict['urls']
+                georedOperation = georedMessage['operation']
+                georedBody = georedMessage['body']
+                georedUrls = georedMessage['urls']
                 georedTasks = []

                 for georedEndpoint in georedUrls:
@@ -279,11 +270,11 @@ async def handleAsymmetricGeoredQueue(self):

                 if self.benchmarking:
                     await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleAsymmetricGeoredQueue] Time taken to send asymmetric geored message to specified peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms"))

-                await(asyncio.sleep(0.001))
+                await(asyncio.sleep(0))

             except Exception as e:
                 await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Error handling asymmetric geored queue: {e}"))
-                await(asyncio.sleep(0.001))
+                await(asyncio.sleep(0))
                 continue

     async def handleGeoredQueue(self):
@@ -295,20 +286,11 @@ async def handleGeoredQueue(self):
             try:
                 if self.benchmarking:
                     startTime = time.perf_counter()
-
georedQueue = await(self.redisGeoredMessaging.getNextQueue(pattern='geored-*')) - if not len(georedQueue) > 0: - await(asyncio.sleep(0.01)) - continue - georedMessage = await(self.redisGeoredMessaging.getMessage(queue=georedQueue)) - if not len(georedMessage) > 0: - await(asyncio.sleep(0.01)) - continue - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Queue: {georedQueue}")) + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}")) - georedDict = json.loads(georedMessage) - georedOperation = georedDict['operation'] - georedBody = georedDict['body'] + georedOperation = georedMessage['operation'] + georedBody = georedMessage['body'] georedTasks = [] for remotePeer in self.georedPeers: @@ -317,11 +299,11 @@ async def handleGeoredQueue(self): if self.benchmarking: await(self.logTool.logAsync(service='Geored', level='info', message=f"[Geored] [handleGeoredQueue] Time taken to send geored message to all geored peers: {round(((time.perf_counter() - startTime)*1000), 3)} ms")) - await(asyncio.sleep(0.001)) + await(asyncio.sleep(0)) except Exception as e: await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Error handling geored queue: {e}")) - await(asyncio.sleep(0.001)) + await(asyncio.sleep(0)) continue async def handleWebhookQueue(self): @@ -333,21 +315,13 @@ async def handleWebhookQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - webhookQueue = await(self.redisWebhookMessaging.getNextQueue(pattern='webhook-*')) - if not len(webhookQueue) > 0: - await(asyncio.sleep(0.01)) - continue - webhookMessage = await(self.redisWebhookMessaging.getMessage(queue=webhookQueue)) - if not len(webhookMessage) > 0: - await(asyncio.sleep(0.001)) - continue - await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Queue: {webhookQueue}")) + webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook')))[1]) + await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) - webhookDict = json.loads(webhookMessage) - webhookHeaders = webhookDict['headers'] - webhookOperation = webhookDict['operation'] - webhookBody = webhookDict['body'] + webhookHeaders = webhookMessage['headers'] + webhookOperation = webhookMessage['operation'] + webhookBody = webhookMessage['body'] webhookTasks = [] for remotePeer in self.webhookPeers: diff --git a/services/hssService.py b/services/hssService.py index dc8027d..41e2a95 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -1,4 +1,4 @@ -import os, sys, json, yaml, time +import os, sys, json, yaml, time, traceback sys.path.append(os.path.realpath('../lib')) from messaging import RedisMessaging from diameter import Diameter @@ -31,7 +31,6 @@ def __init__(self): self.diameterLibrary = Diameter(logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) - def handleQueue(self): """ Gets and parses inbound diameter requests, processes them and queues the response. 
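Note on the pattern in the geored/webhook hunks above and the hssService hunk below: queue-name polling is replaced with a blocking wait on a fixed, well-known key, and metadata that used to be packed into per-message queue names now travels inside the JSON payload. A minimal consumer-side sketch of that pattern follows — it assumes `RedisMessaging.awaitMessage` wraps redis-py's `blpop` (consistent with the `[1]` tuple indexing used throughout these services, though the implementation is not shown in this patch), and the connection details are illustrative:

```python
import json
import redis

class RedisMessaging:
    # Sketch of the relevant slice of lib/messaging.py; assumed BLPOP-backed.
    def __init__(self, host: str='127.0.0.1', port: int=6379):
        self.redisClient = redis.Redis(host=host, port=port)

    def awaitMessage(self, key: str):
        # Blocks until a message lands on `key`, then returns a
        # (key, message) tuple -- hence the [1] indexing by callers.
        return self.redisClient.blpop(key, timeout=0)

# Consumer side, mirroring handleGeoredQueue(): the key is fixed, and all
# routing metadata travels inside the JSON body instead of the key name.
redisMessaging = RedisMessaging()
georedMessage = json.loads(redisMessaging.awaitMessage(key='geored')[1])
georedOperation = georedMessage['operation']
georedBody = georedMessage['body']
```

This removes the pattern-matching queue discovery (`getNextQueue`) from the hot path and lets consumers block inside the Redis client instead of spinning on short `asyncio.sleep` calls.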
@@ -40,17 +39,13 @@ def handleQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - inboundQueue = self.redisMessaging.getNextQueue(pattern='diameter-inbound*') - inboundMessage = self.redisMessaging.getMessage(queue=inboundQueue) - assert(len(inboundMessage)) - inboundDict = json.loads(inboundMessage) - inboundBinary = bytes.fromhex(next(iter(inboundDict.values()))) - inboundSplit = str(inboundQueue).split('-') - inboundHost = inboundSplit[2] - inboundPort = inboundSplit[3] - inboundMessageType = inboundSplit[4] - inboundTimestamp = inboundSplit[5] + inboundMessage = json.loads(self.redisMessaging.awaitMessage(key='diameter-inbound')[1]) + + inboundBinary = bytes.fromhex(inboundMessage.get('diameter-inbound', None)) + inboundHost = inboundMessage.get('clientAddress', None) + inboundPort = inboundMessage.get('clientPort', None) + inboundTimestamp = inboundMessage.get('inbound-received-timestamp', None) try: diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) @@ -61,14 +56,13 @@ def handleQueue(self): self.logTool.log(service='HSS', level='warning', message=f"[HSS] [handleQueue] Failed to generate diameter outbound: {e}", redisClient=self.redisMessaging) continue - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound Queue: {inboundQueue}", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) if not len(diameterOutbound) > 0: continue - outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}-{inboundTimestamp}" - outboundMessage = json.dumps({"diameter-outbound": diameterOutbound}) + outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}" + outboundMessage = json.dumps({"diameter-outbound": diameterOutbound, "inbound-received-timestamp": inboundTimestamp}) self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) @@ -79,7 +73,7 @@ def handleQueue(self): self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) except Exception as e: - time.sleep(0.001) + self.logTool.log(service='HSS', level='error', message=f"[HSS] [handleQueue] Exception: {traceback.format_exc()}", redisClient=self.redisMessaging) continue diff --git a/services/logService.py b/services/logService.py index 4828195..34e7ae0 100644 --- a/services/logService.py +++ b/services/logService.py @@ -39,7 +39,6 @@ def __init__(self): } print(f"{self.banners.logService()}") - def handleLogs(self): """ Continually polls the Redis DB for queued log files. Parses and writes log files to disk, using LogTool. 
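The hunk below applies the same restructure to logging: the service, level and timestamp previously encoded in the queue name (`log-<service>-<level>-<timestamp>`) move into the JSON payload on the single `log` key. A producer-side sketch of the new message shape follows — the field names are inferred from what `handleLogs()` reads back out, and `sendLogMessage` in lib/logtool.py is assumed to serialize them this way:

```python
import json
import time
import redis

redisClient = redis.Redis(host='127.0.0.1', port=6379)

# Producer-side sketch of the new log payload.
logMessage = json.dumps({
    "message": "[HSS] [handleQueue] Processed inbound request",  # illustrative text
    "service": "hss",         # lowercased by handleLogs() to pick the log file
    "level": "debug",         # mapped to a logging level by LogTool
    "timestamp": time.time(),
})
redisClient.rpush('log', logMessage)
redisClient.expire('log', 60)  # matches logExpiry=60 passed by logtool.logAsync
```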
@@ -47,24 +46,14 @@ def handleLogs(self): activeLoggers = {} while True: try: - logQueue = self.redisMessaging.getNextQueue(pattern='log-*') - logMessage = self.redisMessaging.getMessage(queue=logQueue) - - if not len(logMessage) > 0: - time.sleep(0.001) - continue + logMessage = json.loads(self.redisMessaging.awaitMessage(key='log')[1]) - print(f"[Log] Queue: {logQueue}") print(f"[Log] Message: {logMessage}") - logSplit = logQueue.split('-') - logService = logSplit[1].lower() - logLevel = logSplit[2].upper() - logTimestamp = logSplit[3] - - logDict = json.loads(logMessage) - logFileMessage = logDict['message'] - + logFileMessage = logMessage['message'] + logService = logMessage.get('service').lower() + logLevel = logMessage.get('level').lower() + logTimestamp = logMessage.get('timestamp') if f"{logService}_logging_file" not in self.logFilePaths: continue diff --git a/services/metricService.py b/services/metricService.py index d75902d..12d51c1 100644 --- a/services/metricService.py +++ b/services/metricService.py @@ -34,13 +34,8 @@ def handleMetrics(self): actions = {'inc': 'inc', 'dec': 'dec', 'set':'set'} prometheusTypes = {'counter': Counter, 'gauge': Gauge, 'histogram': Histogram, 'summary': Summary} - metricQueue = self.redisMessaging.getNextQueue(pattern='metric-*') - metric = self.redisMessaging.getMessage(queue=metricQueue) + metric = self.redisMessaging.awaitMessage(key='metric')[1] - if not (len(metric) > 0): - time.sleep(0.001) - return - self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] Received Metric: {metric}", redisClient=self.redisMessaging) prometheusJsonList = json.loads(metric) From 7c7af2e4f272d9de2482e2a8a311c8563e983a8d Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 4 Oct 2023 19:53:57 +1000 Subject: [PATCH 38/43] Fixes after queue restructure --- lib/diameter.py | 75 ++++++++++++++++++++++++------------- services/diameterService.py | 1 - services/hssService.py | 17 +++++++-- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/lib/diameter.py b/lib/diameter.py index 603ba4a..4692925 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -603,8 +603,11 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: except Exception as e: continue connectedPeer = self.getPeerByHostname(hostname=hostname) - peerIp = connectedPeer['ipAddress'] - peerPort = connectedPeer['port'] + try: + peerIp = connectedPeer['ipAddress'] + peerPort = connectedPeer['port'] + except Exception as e: + return '' request = diameterApplication["requestMethod"](**kwargs) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" @@ -633,8 +636,11 @@ def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> continue connectedPeerList = self.getConnectedPeersByType(peerType=peerType) for connectedPeer in connectedPeerList: - peerIp = connectedPeer['ipAddress'] - peerPort = connectedPeer['port'] + try: + peerIp = connectedPeer['ipAddress'] + peerPort = connectedPeer['port'] + except Exception as e: + return '' request = diameterApplication["requestMethod"](**kwargs) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" @@ -672,8 
+678,11 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo except Exception as e: continue connectedPeer = self.getPeerByHostname(hostname=hostname) - peerIp = connectedPeer['ipAddress'] - peerPort = connectedPeer['port'] + try: + peerIp = connectedPeer['ipAddress'] + peerPort = connectedPeer['port'] + except Exception as e: + return '' request = diameterApplication["requestMethod"](**kwargs) responseType = diameterApplication["responseAcronym"] sessionId = kwargs.get('sessionId', None) @@ -688,29 +697,41 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo try: if not time.time() >= startTimer + timeout: if sessionId is None: - responseQueues = self.redisMessaging.getQueues(pattern=f"diameter-inbound-{peerIp.replace('.', '*')}-{peerPort}-{responseType}*") - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] responseQueues(NoSessionId): {responseQueues}", redisClient=self.redisMessaging) - for responseQueue in responseQueues: - if float(responseQueue.split('-')[5]) > sendTime: - inboundResponseList = self.redisMessaging.getMessage(queue=responseQueue) - if len(inboundResponseList) > 0: + queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound") + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages(NoSessionId): {queuedMessages}", redisClient=self.redisMessaging) + for queuedMessage in queuedMessages: + queuedMessage = json.loads(queuedMessage) + clientAddress = queuedMessage.get('clientAddress', None) + clientPort = queuedMessage.get('clientPort', None) + if clientAddress != peerIp or clientPort != peerPort: + continue + messageReceiveTime = queuedMessage.get('inbound-received-timestamp', None) + if float(messageReceiveTime) > sendTime: + messageHex = queuedMessage.get('diameter-inbound') + messageType = self.getDiameterMessageType(messageHex) + if messageType['inbound'].upper() == responseType.upper(): self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Found inbound response: {inboundResponse}", redisClient=self.redisMessaging) - return json.loads(inboundResponseList[0]).get('diameter-inbound', '') + return messageHex time.sleep(0.02) else: - responseQueues = self.redisMessaging.getQueues(pattern=f"diameter-inbound-{peerIp.replace('.', '*')}-{peerPort}-{responseType}*") - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] responseQueues({sessionId}): {responseQueues} responseType: {responseType}", redisClient=self.redisMessaging) - for responseQueue in responseQueues: - if float(responseQueue.split('-')[5]) > sendTime: - inboundResponseList = self.redisMessaging.getList(key=responseQueue) - if len(inboundResponseList) > 0: - for inboundResponse in inboundResponseList: - responseHex = json.loads(inboundResponse)['diameter-inbound'] - packetVars, avps = self.decode_diameter_packet(responseHex) - responseSessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii') - if responseSessionId == sessionId: - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Matched on Session Id: {sessionId}", redisClient=self.redisMessaging) - return json.loads(inboundResponseList[0]).get('diameter-inbound', '') + queuedMessages = 
self.redisMessaging.getList(key=f"diameter-inbound") + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages({sessionId}): {queuedMessages} responseType: {responseType}", redisClient=self.redisMessaging) + for queuedMessage in queuedMessages: + queuedMessage = json.loads(queuedMessage) + clientAddress = queuedMessage.get('clientAddress', None) + clientPort = queuedMessage.get('clientPort', None) + if clientAddress != peerIp or clientPort != peerPort: + continue + messageReceiveTime = queuedMessage.get('inbound-received-timestamp', None) + if float(messageReceiveTime) > sendTime: + messageHex = queuedMessage.get('diameter-inbound') + messageType = self.getDiameterMessageType(messageHex) + if messageType['inbound'].upper() == responseType.upper(): + packetVars, avps = self.decode_diameter_packet(messageHex) + messageSessionId = bytes.fromhex(self.get_avp_data(avps, 263)[0]).decode('ascii') + if messageSessionId == sessionId: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Matched on Session Id: {sessionId}", redisClient=self.redisMessaging) + return messageHex time.sleep(0.02) else: return '' @@ -2574,7 +2595,7 @@ def Answer_16777236_275(self, packet_vars, avps): response = self.generate_diameter_packet("01", "40", 275, 16777236, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response except Exception as e: - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA, returning 2001: {traceback.format_exc()}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] Error generating STA, returning 2001", redisClient=self.redisMessaging) avp = '' sessionId = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, sessionId) #Set session ID to received session ID diff --git a/services/diameterService.py b/services/diameterService.py index dc2e9f8..cebf4cf 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -196,7 +196,6 @@ async def handleConnection(self, reader, writer): "ipAddress":'', "port":'', "connectionStatus": '', - "lastDwrTimestamp": '', "diameterHostname": '', "peerType": '', } diff --git a/services/hssService.py b/services/hssService.py index 41e2a95..b1d9da5 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -43,13 +43,27 @@ def handleQueue(self): inboundMessage = json.loads(self.redisMessaging.awaitMessage(key='diameter-inbound')[1]) inboundBinary = bytes.fromhex(inboundMessage.get('diameter-inbound', None)) + if inboundBinary == None: + continue inboundHost = inboundMessage.get('clientAddress', None) inboundPort = inboundMessage.get('clientPort', None) inboundTimestamp = inboundMessage.get('inbound-received-timestamp', None) try: diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) + + if diameterOutbound == None: + continue + if not len(diameterOutbound) > 0: + continue + diameterMessageTypeDict = self.diameterLibrary.getDiameterMessageType(binaryData=inboundBinary) + + if diameterMessageTypeDict == None: + continue + if not len(diameterMessageTypeDict) > 0: + continue + diameterMessageTypeInbound = diameterMessageTypeDict.get('inbound', '') diameterMessageTypeOutbound = diameterMessageTypeDict.get('outbound', '') 
except Exception as e: @@ -57,9 +71,6 @@ def handleQueue(self): continue self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) - - if not len(diameterOutbound) > 0: - continue outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}" outboundMessage = json.dumps({"diameter-outbound": diameterOutbound, "inbound-received-timestamp": inboundTimestamp}) From cca7fef8e68dd7fee7923b2a319a13c5aba23309 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 4 Oct 2023 20:23:08 +1000 Subject: [PATCH 39/43] Allow null on geored patch --- services/apiService.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/services/apiService.py b/services/apiService.py index 2c195d2..83aa034 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -1551,7 +1551,7 @@ def patch(self): json_data['scscf_peer'] = None if 'scscf_timestamp' not in json_data: json_data['scscf_timestamp'] = None - response_data.append(databaseClient.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=str(json_data['scscf_realm']), scscf_peer=str(json_data['scscf_peer']), scscf_timestamp=json_data['scscf_timestamp'], propagate=False)) + response_data.append(databaseClient.Update_Serving_CSCF(imsi=str(json_data['imsi']), serving_cscf=json_data['scscf'], scscf_realm=json_data['scscf_realm'], scscf_peer=json_data['scscf_peer'], scscf_timestamp=json_data['scscf_timestamp'], propagate=False)) redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes Received', @@ -1570,7 +1570,7 @@ def patch(self): json_data['pcscf_timestamp'] = None if 'pcscf_active_session' not in json_data: json_data['pcscf_active_session'] = None - response_data.append(databaseClient.Update_Proxy_CSCF(imsi=str(json_data['imsi']), proxy_cscf=json_data['pcscf'], pcscf_realm=str(json_data['pcscf_realm']), pcscf_peer=str(json_data['pcscf_peer']), pcscf_timestamp=json_data['pcscf_timestamp'], pcscf_active_session=str(json_data['pcscf_active_session']), propagate=False)) + response_data.append(databaseClient.Update_Proxy_CSCF(imsi=str(json_data['imsi']), proxy_cscf=json_data['pcscf'], pcscf_realm=json_data['pcscf_realm'], pcscf_peer=json_data['pcscf_peer'], pcscf_timestamp=json_data['pcscf_timestamp'], pcscf_active_session=json_data['pcscf_active_session'], propagate=False)) redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', metricType='counter', metricAction='inc', metricValue=1.0, metricHelp='Number of Geored Pushes Received', From 09f461a6ef54111ea3e12dab82ea7c0d309a21c1 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Wed, 4 Oct 2023 20:29:28 +1000 Subject: [PATCH 40/43] Remove forced string typing in database geored --- lib/database.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/database.py b/lib/database.py index 3fa3569..7def195 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1618,8 +1618,8 @@ def Update_Serving_MME(self, imsi, serving_mme, serving_mme_realm=None, serving_ self.handleGeored({ "imsi": str(imsi), "serving_mme": result.serving_mme, - "serving_mme_realm": str(result.serving_mme_realm), - "serving_mme_peer": str(result.serving_mme_peer), + "serving_mme_realm": result.serving_mme_realm, + "serving_mme_peer": result.serving_mme_peer, 
"serving_mme_timestamp": serving_mme_timestamp_string }) else: @@ -1676,7 +1676,7 @@ def Update_Proxy_CSCF(self, imsi, proxy_cscf, pcscf_realm=None, pcscf_peer=None, if propagate == True: if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging) - self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": str(result.pcscf_realm), "pcscf_timestamp": pcscf_timestamp_string, "pcscf_peer": str(result.pcscf_peer), "pcscf_active_session": str(pcscf_active_session)}) + self.handleGeored({"imsi": str(imsi), "pcscf": result.pcscf, "pcscf_realm": result.pcscf_realm, "pcscf_timestamp": pcscf_timestamp_string, "pcscf_peer": result.pcscf_peer, "pcscf_active_session": pcscf_active_session}) else: self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging) except Exception as E: @@ -1730,7 +1730,7 @@ def Update_Serving_CSCF(self, imsi, serving_cscf, scscf_realm=None, scscf_peer=N if propagate == True: if 'IMS' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: self.logTool.log(service='Database', level='debug', message="Propagate IMS changes to Geographic PyHSS instances", redisClient=self.redisMessaging) - self.handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": str(result.scscf_realm), "scscf_timestamp": scscf_timestamp_string, "scscf_peer": str(result.scscf_peer)}) + self.handleGeored({"imsi": str(imsi), "scscf": result.scscf, "scscf_realm": result.scscf_realm, "scscf_timestamp": scscf_timestamp_string, "scscf_peer": result.scscf_peer}) else: self.logTool.log(service='Database', level='debug', message="Config does not allow sync of IMS events", redisClient=self.redisMessaging) except Exception as E: @@ -1839,13 +1839,13 @@ def Update_Serving_APN(self, imsi, apn, pcrf_session_id, serving_pgw, subscriber if 'PCRF' in self.config['geored']['sync_actions'] and self.config['geored']['enabled'] == True: self.logTool.log(service='Database', level='debug', message="Propagate PCRF changes to Geographic PyHSS instances", redisClient=self.redisMessaging) self.handleGeored({"imsi": str(imsi), - 'serving_apn' : str(apn), - 'pcrf_session_id': str(pcrf_session_id), - 'serving_pgw': str(serving_pgw), - 'serving_pgw_realm': str(serving_pgw_realm), - 'serving_pgw_peer': str(serving_pgw_peer), + 'serving_apn' : apn, + 'pcrf_session_id': pcrf_session_id, + 'serving_pgw': serving_pgw, + 'serving_pgw_realm': serving_pgw_realm, + 'serving_pgw_peer': serving_pgw_peer, 'serving_pgw_timestamp': serving_pgw_timestamp_string, - 'subscriber_routing': str(subscriber_routing) + 'subscriber_routing': subscriber_routing }) else: self.logTool.log(service='Database', level='debug', message="Config does not allow sync of PCRF events", redisClient=self.redisMessaging) From 946b20e0394f2f93ce27befacf2ef9090929a831 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 6 Oct 2023 15:34:44 +1000 Subject: [PATCH 41/43] Further performance improvements, worker pool --- CHANGELOG.md | 4 +- README.md | 2 + lib/logtool.py | 2 +- lib/messaging.py | 10 ++++ lib/messagingAsync.py | 27 ++++++++-- services/diameterService.py | 101 +++++++++++++++++++++++------------- services/hssService.py | 73 ++++++++++++++------------ 7 files changed, 141 insertions(+), 78 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
aadc806..2879763 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - /oam/deregister/{imsi} endpoint - /geored/peers endpoint - /geored/webhooks endpoint - - Dependency on Redis for inter-service messaging + - Dependency on Redis 7 for inter-service messaging - Significant performance improvements under load - Basic Rx support for RAA, AAA, ASA and STA - Rx MO call flow support (AAR -> RAR -> RAA -> AAA) @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Logtool no longer handles metric processing - Updated config.yaml - Gx CCR-T now flushes PGW / IMS data, depending on Called-Station-Id -- Benchmarked lossless at ~100 diameter requests per second, per hssService. +- Benchmarked at ~500 or more diameter requests per second, with response times under 2 seconds, on a local network. ### Fixed diff --git a/README.md b/README.md index e8972ee..b3ee019 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,8 @@ Dependencies can be installed using Pip3: pip3 install -r requirements.txt ``` +PyHSS also requires [Redis 7.0.0](https://redis.io/docs/getting-started/installation/install-redis-on-linux/) or above. + Then after setting up the config, you can fire up the necessary PyHSS services by running: ```shell python3 diameterService.py diff --git a/lib/logtool.py b/lib/logtool.py index caf20a2..8506113 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -55,7 +55,7 @@ async def logAsync(self, service: str, level: str, message: str, redisClient=Non timestamp = time.time() dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip() print(f"[{dateTimeString}] [{level.upper()}] {message}") - asyncio.ensure_future(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)) + await(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)) return True def log(self, service: str, level: str, message: str, redisClient=None) -> bool: diff --git a/lib/messaging.py b/lib/messaging.py index cc78835..7b376a3 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -115,6 +115,16 @@ def awaitMessage(self, key: str): except Exception as e: return '' + def awaitBulkMessage(self, key: str, count: int=100): + """ + Blocks until one or more messages are received at the given key, then returns up to `count` messages. + """ + try: + message = self.redisClient.blmpop(0, 1, key, direction='RIGHT', count=count) + return message + except Exception as e: + print(traceback.format_exc()) + return '' def deleteQueue(self, queue: str) -> bool: """ diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 78f2ca7..6c33e0a 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -20,15 +20,32 @@ async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> Stores a message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. 
""" try: - async with self.redisClient.pipeline(transaction=True) as redisPipe: - await redisPipe.rpush(queue, message) - if queueExpiry is not None: - await redisPipe.expire(queue, queueExpiry) - sendMessageResult, expireKeyResult = await redisPipe.execute() + await(self.redisClient.rpush(queue, message)) + if queueExpiry is not None: + await(self.redisClient.expire(queue, queueExpiry)) return f'{message} stored in {queue} successfully.' except Exception as e: return '' + async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int=None) -> str: + """ + Empties a given asyncio queue into a redis pipeline, then sends to redis. + """ + try: + redisPipe = self.redisClient.pipeline() + + for message in messageList: + redisPipe.rpush(queue, message) + if queueExpiry is not None: + redisPipe.expire(queue, queueExpiry) + + await(redisPipe.execute()) + + return f'Messages stored in {queue} successfully.' + + except Exception as e: + return '' + async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None) -> str: """ Stores a prometheus metric in a format readable by the metric service, asynchronously. diff --git a/services/diameterService.py b/services/diameterService.py index cebf4cf..958ec72 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -13,7 +13,7 @@ class DiameterService: """ PyHSS Diameter Service A class for handling diameter inbounds and replies on Port 3868, via TCP. - Functions in this class are high-performance, please edit with care. Last profiled on 20-09-2023. + Functions in this class are high-performance, please edit with care. Last profiled October 6th, 2023. """ def __init__(self): @@ -40,6 +40,7 @@ def __init__(self): self.benchmarkingInterval = self.config.get('benchmarking', {}).get('reporting_interval', 3600) self.diameterRequests = 0 self.diameterResponses = 0 + self.workerPoolSize = int(self.config.get('hss', {}).get('diameter_service_workers', 10)) async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool: """ @@ -51,7 +52,7 @@ async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inb originHost = bytes.fromhex(originHost).decode("utf-8") peerType = await(self.diameterLibrary.getPeerType(originHost)) self.activePeers[f"{clientAddress}-{clientPort}"].update({'diameterHostname': originHost, - 'peerType': peerType, + 'peerType': (peerType if peerType != None else 'Unknown'), }) return True except Exception as e: @@ -120,44 +121,66 @@ async def logProcessedMessages(self): async def readInboundData(self, reader, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ - Reads and parses incoming data from a connected client. Validated diameter messages are sent to the redis queue for processing. - Terminates the connection if diameter traffic is not received, or if the client disconnects. + Reads incoming data from a connected client. Data is sent to a shared memory-based queue, to be polled and processed by a worker coroutine. + Terminates the connection if the client disconnects, the queue fills or another exception occurs. 
""" await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] New connection from {clientAddress} on port {clientPort}")) - peerIsValidated = False + clientConnection = f"{clientAddress}-{clientPort}" while True: try: inboundData = await(asyncio.wait_for(reader.read(8192), timeout=socketTimeout)) if reader.at_eof(): - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Timeout for {clientAddress} on port {clientPort}, closing connection.")) return False - + if len(inboundData) > 0: - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Received data from {clientAddress} on port {clientPort}")) - - if not peerIsValidated: - if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundData)): - await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.")) - await(asyncio.sleep(0)) - continue - else: - await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Validated peer: {clientAddress} on port {clientPort}")) - peerIsValidated = True - - inboundQueueName = f"diameter-inbound" - inboundHexString = json.dumps({"diameter-inbound": inboundData.hex(), "inbound-received-timestamp": time.time_ns(), "clientAddress": clientAddress, "clientPort": clientPort}) - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Queueing {inboundHexString}")) - await(self.redisReaderMessaging.sendMessage(queue=inboundQueueName, message=inboundHexString, queueExpiry=self.diameterRequestTimeout)) - if self.benchmarking: - self.diameterRequests += 1 - await(asyncio.sleep(0)) - + self.sharedQueue.put_nowait({"diameter-inbound": inboundData, "inbound-received-timestamp": time.time(), "clientAddress": clientAddress, "clientPort": clientPort}) + except Exception as e: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [readInboundData] [{coroutineUuid}] Socket Exception for {clientAddress} on port {clientPort}, closing connection.\n{e}")) return False + async def inboundDataWorker(self, coroutineUuid: str) -> bool: + """ + Collects messages from the memory queue, performs peer validation and fires off to redis every 0.1 seconds. 
+ """ + batchInterval = 0.1 + inboundQueueName = f"diameter-inbound" + while True: + try: + nextSendTime = time.time() + batchInterval + messageList = [] + while time.time() < nextSendTime: + try: + inboundData = await(asyncio.wait_for(self.sharedQueue.get(), timeout=nextSendTime - time.time())) + inboundHex = inboundData.get('diameter-inbound', '').hex() + inboundData['diameter-inbound'] = inboundHex + clientAddress = inboundData.get('clientAddress', '') + clientPort = inboundData.get('clientPort', '') + + if len(self.activePeers.get(f'{clientAddress}-{clientPort}', {}).get('peerType', '')) == 0: + if not await(self.validateDiameterInbound(clientAddress, clientPort, inboundHex)): + await(self.logTool.logAsync(service='Diameter', level='warning', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Invalid Diameter Inbound, discarding data.")) + continue + else: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Validated peer: {clientAddress} on port {clientPort}")) + + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Queueing to redis: {inboundData}")) + messageList.append(json.dumps(inboundData)) + if self.benchmarking: + self.diameterRequests += 1 + except asyncio.TimeoutError: + break + + if messageList: + await self.redisReaderMessaging.sendBulkMessage(queue=inboundQueueName, messageList=messageList, queueExpiry=self.diameterRequestTimeout) + messageList = [] + + except Exception as e: + await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [inboundDataWorker] [{coroutineUuid}] Exception for inboundDataWorker, continuing.\n{e}")) + pass + async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, socketTimeout: int, coroutineUuid: str) -> bool: """ Waits for a message to be received from Redis, then sends to the connected client. 
@@ -167,14 +190,13 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s try: await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Waiting for messages for host {clientAddress} on port {clientPort}")) pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}")))[1]) - await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Received message: {pendingOutboundMessage} for host {clientAddress} on port {clientPort}")) diameterOutboundBinary = bytes.fromhex(pendingOutboundMessage.get('diameter-outbound', '')) await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Sending: {diameterOutboundBinary.hex()} to {clientAddress} on {clientPort}.")) + writer.write(diameterOutboundBinary) await(writer.drain()) if self.benchmarking: self.diameterResponses += 1 - await(asyncio.sleep(0)) except Exception as e: await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Connection closed for {clientAddress} on port {clientPort}, closing writer.")) return False @@ -190,15 +212,13 @@ async def handleConnection(self, reader, writer): await(self.logTool.logAsync(service='Diameter', level='info', message=f"[Diameter] New Connection from: {clientAddress} on port {clientPort}")) if f"{clientAddress}-{clientPort}" not in self.activePeers: self.activePeers[f"{clientAddress}-{clientPort}"] = { - "connectTimestamp": '', - "disconnectTimestamp": '', - "reconnectionCount": 0, - "ipAddress":'', - "port":'', - "connectionStatus": '', - "diameterHostname": '', - "peerType": '', - } + "connectTimestamp": '', + "disconnectTimestamp": '', + "reconnectionCount": 0, + "ipAddress":'', + "port":'', + "connectionStatus": '', + "diameterHostname": '', + "peerType": '', + } else: reconnectionCount = self.activePeers.get(f"{clientAddress}-{clientPort}", {}).get('reconnectionCount', 0) reconnectionCount += 1 @@ -246,6 +268,11 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): Also create a single instance of self.handleActiveDiameterPeers and self.logProcessedMessages. 
""" + self.sharedQueue = asyncio.Queue(maxsize=1024) + + for i in range(self.workerPoolSize): + asyncio.create_task(self.inboundDataWorker(coroutineUuid=f'inboundDataWorker-{i}')) + if host is None: host=str(self.config.get('hss', {}).get('bind_ip', '0.0.0.0')[0]) diff --git a/services/hssService.py b/services/hssService.py index b1d9da5..46abcbc 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -40,53 +40,60 @@ def handleQueue(self): if self.benchmarking: startTime = time.perf_counter() - inboundMessage = json.loads(self.redisMessaging.awaitMessage(key='diameter-inbound')[1]) + inboundMessageList = self.redisMessaging.awaitBulkMessage(key='diameter-inbound') - inboundBinary = bytes.fromhex(inboundMessage.get('diameter-inbound', None)) - if inboundBinary == None: + if inboundMessageList == None: continue - inboundHost = inboundMessage.get('clientAddress', None) - inboundPort = inboundMessage.get('clientPort', None) - inboundTimestamp = inboundMessage.get('inbound-received-timestamp', None) + for inboundMessage in inboundMessageList[1]: + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] Message: {inboundMessage}", redisClient=self.redisMessaging) - try: - diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) + inboundMessage = json.loads(inboundMessage.decode('ascii')) + inboundBinary = bytes.fromhex(inboundMessage.get('diameter-inbound', None)) - if diameterOutbound == None: - continue - if not len(diameterOutbound) > 0: + if inboundBinary == None: continue + inboundHost = inboundMessage.get('clientAddress', None) + inboundPort = inboundMessage.get('clientPort', None) + inboundTimestamp = inboundMessage.get('inbound-received-timestamp', None) - diameterMessageTypeDict = self.diameterLibrary.getDiameterMessageType(binaryData=inboundBinary) - - if diameterMessageTypeDict == None: - continue - if not len(diameterMessageTypeDict) > 0: - continue + try: + diameterOutbound = self.diameterLibrary.generateDiameterResponse(binaryData=inboundBinary) - diameterMessageTypeInbound = diameterMessageTypeDict.get('inbound', '') - diameterMessageTypeOutbound = diameterMessageTypeDict.get('outbound', '') - except Exception as e: - self.logTool.log(service='HSS', level='warning', message=f"[HSS] [handleQueue] Failed to generate diameter outbound: {e}", redisClient=self.redisMessaging) - continue + if diameterOutbound == None: + continue + if not len(diameterOutbound) > 0: + continue - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) - - outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}" - outboundMessage = json.dumps({"diameter-outbound": diameterOutbound, "inbound-received-timestamp": inboundTimestamp}) + diameterMessageTypeDict = self.diameterLibrary.getDiameterMessageType(binaryData=inboundBinary) + + if diameterMessageTypeDict == None: + continue + if not len(diameterMessageTypeDict) > 0: + continue - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) - self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] 
[{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) + diameterMessageTypeInbound = diameterMessageTypeDict.get('inbound', '') + diameterMessageTypeOutbound = diameterMessageTypeDict.get('outbound', '') + except Exception as e: + self.logTool.log(service='HSS', level='warning', message=f"[HSS] [handleQueue] Failed to generate diameter outbound: {e}", redisClient=self.redisMessaging) + continue - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) - if self.benchmarking: - self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Inbound Diameter Inbound: {inboundMessage}", redisClient=self.redisMessaging) + + outboundQueue = f"diameter-outbound-{inboundHost}-{inboundPort}" + outboundMessage = json.dumps({"diameter-outbound": diameterOutbound, "inbound-received-timestamp": inboundTimestamp}) + + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Generated Diameter Outbound: {diameterOutbound}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) + + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) + if self.benchmarking: + self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) except Exception as e: self.logTool.log(service='HSS', level='error', message=f"[HSS] [handleQueue] Exception: {traceback.format_exc()}", redisClient=self.redisMessaging) continue - + if __name__ == '__main__': From 09253836c397d4386a7a43bbc31866aaf2d3c900 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Fri, 6 Oct 2023 15:36:57 +1000 Subject: [PATCH 42/43] Disable webhooks by default in config.yaml --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index a12d857..ed9b869 100644 --- a/config.yaml +++ b/config.yaml @@ -102,7 +102,7 @@ database: ## External Webhook Notifications webhooks: - enabled: True + enabled: False endpoints: - http://10.5.5.66:8080 From 06d0d8e010ff77e06dc72e44274da0a4f5b80ec9 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Mon, 9 Oct 2023 11:13:41 +1000 Subject: [PATCH 43/43] Update default ip in webhooks --- config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index ed9b869..e0e2a84 100644 --- a/config.yaml +++ b/config.yaml @@ -104,7 +104,7 @@ database: webhooks: enabled: False endpoints: - - http://10.5.5.66:8080 + - http://127.0.0.1:8181 ## Geographic Redundancy Parameters geored:
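A closing note on the Redis dependency bump in PATCH 41: `awaitBulkMessage` is built on BLMPOP, a command introduced in Redis 7.0.0 (and exposed in redis-py 4.4+), which is why the README and CHANGELOG now call for Redis 7 or above. A sketch of the bulk consume path as hssService drives it follows; the seeded payload and connection details are illustrative:

```python
import json
import redis

redisClient = redis.Redis(host='127.0.0.1', port=6379)

# Seed one message so the blocking pop below returns immediately
# (payload shape taken from diameterService's inboundDataWorker).
redisClient.rpush('diameter-inbound', json.dumps({
    "diameter-inbound": "010000ac",  # truncated hex, illustrative only
    "inbound-received-timestamp": 0,
    "clientAddress": "127.0.0.1",
    "clientPort": "3868",
}))

# BLMPOP blocks until a key is non-empty, then pops up to `count` elements in
# one call. It returns [key, [messages]], which is why hssService iterates
# over inboundMessageList[1].
inboundMessageList = redisClient.blmpop(0, 1, 'diameter-inbound', direction='RIGHT', count=100)
for inboundMessage in inboundMessageList[1]:
    inboundMessage = json.loads(inboundMessage.decode('ascii'))
    print(inboundMessage.get('clientAddress'), inboundMessage.get('clientPort'))
```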