Skip to content

Commit

Permalink
Merge pull request #5130 from chrisburr/backport-e764f2039c
Browse files Browse the repository at this point in the history
[v7r2] Backport: Factorise recurseImport to DIRAC.Core.Utilities.Extensions
  • Loading branch information
Andrei Tsaregorodtsev committed May 19, 2021
2 parents a979a26 + 0ad764c commit 6cb54d6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 149 deletions.
46 changes: 9 additions & 37 deletions src/DIRAC/Core/Base/private/ModuleLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from __future__ import division
from __future__ import print_function

import six
import os
import imp
from DIRAC.Core.Utilities import List
from DIRAC import gConfig, S_ERROR, S_OK, gLogger
from DIRAC.ConfigurationSystem.Client.Helpers import getInstalledExtensions
from DIRAC.ConfigurationSystem.Client import PathFinder
from DIRAC.Core.Utilities.Extensions import recurseImport


class ModuleLoader(object):
Expand Down Expand Up @@ -65,12 +64,11 @@ def loadModules(self, modulesList, hideExceptions=False):
# Look what is installed
parentModule = None
for rootModule in getInstalledExtensions():
if system.find("System") != len(system) - 6:
parentImport = "%s.%sSystem.%s" % (rootModule, system, self.__csSuffix)
else:
parentImport = "%s.%s.%s" % (rootModule, system, self.__csSuffix)
if not system.endswith("System"):
system += "System"
parentImport = "%s.%s.%s" % (rootModule, system, self.__csSuffix)
# HERE!
result = self.__recurseImport(parentImport)
result = recurseImport(parentImport)
if not result['OK']:
return result
parentModule = result['Value']
Expand Down Expand Up @@ -113,7 +111,7 @@ def loadModule(self, modName, hideExceptions=False, parentModule=False):
for loadModName in loadGroup:
if loadModName.find("/") == -1:
loadModName = "%s/%s" % (modList[0], loadModName)
result = self.loadModule(loadModName, hideExceptions=hideExceptions, parentModule=False)
result = self.loadModule(loadModName, hideExceptions=hideExceptions)
if not result['OK']:
return result
return S_OK()
Expand Down Expand Up @@ -146,7 +144,7 @@ def loadModule(self, modName, hideExceptions=False, parentModule=False):
if handlerPath.find(".py", len(handlerPath) - 3) > -1:
handlerPath = handlerPath[:-3]
className = List.fromChar(handlerPath, ".")[-1]
result = self.__recurseImport(handlerPath)
result = recurseImport(handlerPath)
if not result['OK']:
return S_ERROR("Cannot load user defined handler %s: %s" % (handlerPath, result['Message']))
gLogger.verbose("Loaded %s" % handlerPath)
Expand All @@ -156,7 +154,7 @@ def loadModule(self, modName, hideExceptions=False, parentModule=False):
modImport = module
if self.__modSuffix:
modImport = "%s%s" % (modImport, self.__modSuffix)
result = self.__recurseImport(modImport, parentModule, hideExceptions=hideExceptions)
result = recurseImport(modImport, parentModule, hideExceptions=hideExceptions)
else:
# Check to see if the module exists in any of the root modules
gLogger.info("Trying to autodiscover %s" % loadName)
Expand All @@ -166,7 +164,7 @@ def loadModule(self, modName, hideExceptions=False, parentModule=False):
if self.__modSuffix:
importString = "%s%s" % (importString, self.__modSuffix)
gLogger.verbose("Trying to load %s" % importString)
result = self.__recurseImport(importString, hideExceptions=hideExceptions)
result = recurseImport(importString, hideExceptions=hideExceptions)
# Error while loading
if not result['OK']:
return result
Expand Down Expand Up @@ -203,29 +201,3 @@ def loadModule(self, modName, hideExceptions=False, parentModule=False):
gLogger.notice("Loaded module %s" % modName)

return S_OK()

def __recurseImport(self, modName, parentModule=None, hideExceptions=False):
gLogger.debug("importing recursively %s, parentModule=%s, hideExceptions=%s" % (modName,
parentModule,
hideExceptions))
if isinstance(modName, six.string_types):
modName = List.fromChar(modName, ".")
try:
if parentModule:
impData = imp.find_module(modName[0], parentModule.__path__)
else:
impData = imp.find_module(modName[0])
impModule = imp.load_module(modName[0], *impData)
if impData[0]:
impData[0].close()
except ImportError as excp:
strExcp = str(excp)
if strExcp.find("No module named") == 0 and strExcp.find(modName[0]) == len(strExcp) - len(modName[0]):
return S_OK()
errMsg = "Can't load %s" % ".".join(modName)
if not hideExceptions:
gLogger.exception(errMsg)
return S_ERROR(errMsg)
if len(modName) == 1:
return S_OK(impModule)
return self.__recurseImport(modName[1:], impModule)
60 changes: 16 additions & 44 deletions src/DIRAC/Core/Utilities/DErrno.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

import six
import os
import imp
import importlib
import sys

# To avoid conflict, the error numbers should be greater than 1000
Expand Down Expand Up @@ -344,50 +344,22 @@ def includeExtensionErrors():
Should be called only at the initialization of DIRAC, so by the parseCommandLine,
dirac-agent.py, dirac-service.py, dirac-executor.py
"""

def __recurseImport(modName, parentModule=None, fullName=False):
""" Internal function to load modules
"""
if isinstance(modName, six.string_types):
modName = modName.split(".")
if not fullName:
fullName = ".".join(modName)
try:
if parentModule:
impData = imp.find_module(modName[0], parentModule.__path__)
else:
impData = imp.find_module(modName[0])
impModule = imp.load_module(modName[0], *impData)
if impData[0]:
impData[0].close()
except ImportError:
return None
if len(modName) == 1:
return impModule
return __recurseImport(modName[1:], impModule, fullName=fullName)

from DIRAC.ConfigurationSystem.Client.Helpers import CSGlobals
allExtensions = CSGlobals.getCSExtensions()

for extension in allExtensions:
ext_derrno = None
for extension in CSGlobals.getCSExtensions():
try:

ext_derrno = __recurseImport('%sDIRAC.Core.Utilities.DErrno' % extension)

if ext_derrno:
# The next 3 dictionary MUST be present for consistency

# Global name of errors
sys.modules[__name__].__dict__.update(ext_derrno.extra_dErrName)
# Dictionary with the error codes
sys.modules[__name__].dErrorCode.update(ext_derrno.extra_dErrorCode)
# Error description string
sys.modules[__name__].dStrError.update(ext_derrno.extra_dStrError)

# extra_compatErrorString is optional
for err in getattr(ext_derrno, 'extra_compatErrorString', []):
sys.modules[__name__].compatErrorString.setdefault(err, []).extend(ext_derrno.extra_compatErrorString[err])

except BaseException:
ext_derrno = importlib.import_module('%sDIRAC.Core.Utilities.DErrno' % extension)
except ImportError:
pass
else:
# The next 3 dictionary MUST be present for consistency
# Global name of errors
sys.modules[__name__].__dict__.update(ext_derrno.extra_dErrName)
# Dictionary with the error codes
sys.modules[__name__].dErrorCode.update(ext_derrno.extra_dErrorCode)
# Error description string
sys.modules[__name__].dStrError.update(ext_derrno.extra_dStrError)

# extra_compatErrorString is optional
for err in getattr(ext_derrno, 'extra_compatErrorString', []):
sys.modules[__name__].compatErrorString.setdefault(err, []).extend(ext_derrno.extra_compatErrorString[err])
51 changes: 9 additions & 42 deletions src/DIRAC/Core/Utilities/ObjectLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

import six
import re
import imp
import pkgutil
import collections

from DIRAC import gLogger, S_OK, S_ERROR
from DIRAC.Core.Utilities import DErrno
from DIRAC.Core.Utilities import List, DIRACSingleton
from DIRAC.Core.Utilities.Extensions import recurseImport
from DIRAC.ConfigurationSystem.Client.Helpers import CSGlobals


Expand All @@ -41,7 +41,6 @@ def __init__(self, baseModules=False):

def _init(self, baseModules):
""" Actually performs the initialization """

if not baseModules:
baseModules = ['DIRAC']
self.__rootModules = baseModules
Expand Down Expand Up @@ -70,7 +69,7 @@ def __rootImport(self, modName, hideExceptions=False):
if rootModule:
impName = "%s.%s" % (rootModule, impName)
gLogger.debug("Trying to load %s" % impName)
result = self.__recurseImport(impName, hideExceptions=hideExceptions)
result = recurseImport(impName, hideExceptions=hideExceptions)
# Error. Something cannot be imported. Return error
if not result['OK']:
return result
Expand All @@ -81,42 +80,12 @@ def __rootImport(self, modName, hideExceptions=False):
# Return nothing found
return S_OK()

def __recurseImport(self, modName, parentModule=None, hideExceptions=False, fullName=False):
""" Internal function to load modules
"""
if isinstance(modName, six.string_types):
modName = List.fromChar(modName, ".")
if not fullName:
fullName = ".".join(modName)
if fullName in self.__objs:
return S_OK(self.__objs[fullName])
try:
if parentModule:
impData = imp.find_module(modName[0], parentModule.__path__)
else:
impData = imp.find_module(modName[0])
impModule = imp.load_module(modName[0], *impData)
if impData[0]:
impData[0].close()
except Exception as excp:
if "No module named" in str(excp) and modName[0] in str(excp):
return S_OK(None)
errMsg = "Can't load %s in %s" % (".".join(modName), parentModule.__path__[0])
if not hideExceptions:
gLogger.exception(errMsg)
return S_ERROR(DErrno.EIMPERR, errMsg)
if len(modName) == 1:
self.__objs[fullName] = impModule
return S_OK(impModule)
return self.__recurseImport(modName[1:], impModule,
hideExceptions=hideExceptions, fullName=fullName)

def __generateRootModules(self, baseModules):
""" Iterate over all the possible root modules
"""
self.__rootModules = baseModules
for rootModule in reversed(CSGlobals.getCSExtensions()):
if rootModule[-5:] != "DIRAC" and rootModule not in self.__rootModules:
if not rootModule.endswith("DIRAC") and rootModule not in self.__rootModules:
self.__rootModules.append("%sDIRAC" % rootModule)
self.__rootModules.append("")

Expand All @@ -136,18 +105,16 @@ def loadModule(self, importString, hideExceptions=False):
def loadObject(self, importString, objName=False, hideExceptions=False):
""" Load an object from inside a module
"""
if not objName:
objName = importString.split(".")[-1]

result = self.loadModule(importString, hideExceptions=hideExceptions)
if not result['OK']:
return result
modObj = result['Value']
modFile = modObj.__file__

if not objName:
objName = List.fromChar(importString, ".")[-1]

try:
result = S_OK(getattr(modObj, objName))
result['ModuleFile'] = modFile
result['ModuleFile'] = modObj.__file__
return result
except AttributeError:
return S_ERROR(DErrno.EIMPERR, "%s does not contain a %s object" % (importString, objName))
Expand Down Expand Up @@ -179,7 +146,7 @@ def getObjects(self, modulePath, reFilter=None, parentClass=None, recurse=False,
impPath = modulePath
gLogger.debug("Trying to load %s" % impPath)

result = self.__recurseImport(impPath)
result = recurseImport(impPath)
if not result['OK']:
return result
if not result['Value']:
Expand All @@ -204,7 +171,7 @@ def getObjects(self, modulePath, reFilter=None, parentClass=None, recurse=False,
if modKeyName in modules:
continue
fullName = "%s.%s" % (impPath, modName)
result = self.__recurseImport(modName, parentModule=parentModule, fullName=fullName)
result = recurseImport(fullName)
if not result['OK']:
if continueOnError:
gLogger.error("Error loading module but continueOnError is true", "module %s error %s" % (fullName, result))
Expand Down
26 changes: 0 additions & 26 deletions src/DIRAC/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,32 +152,6 @@
__siteName = False


# # Update DErrno with the extensions errors
# from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader
# from DIRAC.ConfigurationSystem.Client.Helpers import CSGlobals
# allExtensions = CSGlobals.getCSExtensions()
#
# # Update for each extension. Careful to conflict :-)
# for extension in allExtensions:
# ol = ObjectLoader( baseModules = ["%sDIRAC" % extension] )
# extraErrorModule = ol.loadModule( 'Core.Utilities.DErrno' )
# if extraErrorModule['OK']:
# extraErrorModule = extraErrorModule['Value']
#
# # The next 3 dictionary MUST be present for consistency
#
# # Global name of errors
# DErrno.__dict__.update( extraErrorModule.extra_dErrName )
# # Dictionary with the error codes
# DErrno.dErrorCode.update( extraErrorModule.extra_dErrorCode )
# # Error description string
# DErrno.dStrError.update( extraErrorModule.extra_dStrError )
#
# # extra_compatErrorString is optional
# for err in getattr( extraErrorModule, 'extra_compatErrorString', [] ) :
# DErrno.compatErrorString.setdefault( err, [] ).extend( extraErrorModule.extra_compatErrorString[err] )


def siteName():
"""
Determine and return DIRAC name for current site
Expand Down

0 comments on commit 6cb54d6

Please sign in to comment.