From 5d27ec72722a377f782ded54a6e607221009bc7f Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 1 Nov 2023 09:15:05 -0400 Subject: [PATCH] arch: revamp compiler init for more robust custom compiler --- devito/arch/compiler.py | 100 +++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 0dc1c968585..594ff9e3d76 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -172,19 +172,20 @@ def __init__(self): """ fields = {'cc', 'ld'} + _cpp = False def __init__(self, **kwargs): - super(Compiler, self).__init__(**kwargs) + super().__init__(**kwargs) self.__lookup_cmds__() self.suffix = kwargs.get('suffix') if not kwargs.get('mpi'): - self.cc = self.CC if kwargs.get('cpp', False) is False else self.CXX + self.cc = self.CC if self._cpp is False else self.CXX self.cc = self.cc if self.suffix is None else ('%s-%s' % (self.cc, self.suffix)) else: - self.cc = self.MPICC if kwargs.get('cpp', False) is False else self.MPICXX + self.cc = self.MPICC if self._cpp is False else self.MPICXX self.ld = self.cc # Wanted by the superclass self.cflags = ['-O3', '-g', '-fPIC', '-Wall', '-std=c99'] @@ -196,7 +197,7 @@ def __init__(self, **kwargs): self.defines = [] self.undefines = [] - self.src_ext = 'c' if kwargs.get('cpp', False) is False else 'cpp' + self.src_ext = 'c' if self._cpp is False else 'cpp' if platform.system() == "Linux": self.so_ext = '.so' @@ -216,6 +217,11 @@ def __init__(self, **kwargs): # Knowing the version may still be useful to pick supported flags self.version = sniff_compiler_version(self.CC) + self.__init_finalize__(**kwargs) + + def __init_finalize__(self, **kwargs): + pass + def __new_with__(self, **kwargs): """ Create a new Compiler from an existing one, inherenting from it @@ -394,9 +400,7 @@ def add_ldflags(self, flags): class GNUCompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - + def __init_finalize__(self, **kwargs): platform = kwargs.pop('platform', configuration['platform']) self.cflags += ['-Wno-unused-result', @@ -443,9 +447,7 @@ def __lookup_cmds__(self): class ArmCompiler(GNUCompiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - + def __init_finalize__(self, **kwargs): platform = kwargs.pop('platform', configuration['platform']) # Graviton flag @@ -455,8 +457,7 @@ def __init__(self, *args, **kwargs): class ClangCompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init_finalize__(self, **kwargs): self.cflags += ['-Wno-unused-result', '-Wno-unused-variable'] if not configuration['safe-math']: @@ -522,8 +523,7 @@ class AOMPCompiler(Compiler): """AMD's fork of Clang for OpenMP offloading on both AMD and NVidia cards.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) @@ -556,8 +556,7 @@ def __lookup_cmds__(self): class DPCPPCompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init_finalize__(self, **kwargs): self.cflags += ['-qopenmp', '-fopenmp-targets=spir64'] @@ -572,8 +571,7 @@ def __lookup_cmds__(self): class PGICompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, cpp=True, **kwargs) + def __init_finalize__(self, **kwargs): self.cflags.remove('-std=c99') self.cflags.remove('-O3') @@ -618,8 +616,9 @@ def __lookup_cmds__(self): class CudaCompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, cpp=True, **kwargs) + _cpp = True + + def __init_finalize__(self, **kwargs): self.cflags.remove('-std=c99') self.cflags.remove('-Wall') @@ -683,8 +682,9 @@ def __lookup_cmds__(self): class HipCompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, cpp=True, **kwargs) + _cpp = True + + def __init_finalize__(self, **kwargs): self.cflags.remove('-std=c99') self.cflags.remove('-Wall') @@ -712,8 +712,7 @@ def __lookup_cmds__(self): class IntelCompiler(Compiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init_finalize__(self, **kwargs): platform = kwargs.pop('platform', configuration['platform']) language = kwargs.pop('language', configuration['language']) @@ -771,8 +770,7 @@ def __lookup_cmds__(self): class IntelKNLCompiler(IntelCompiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init_finalize__(self, **kwargs): language = kwargs.pop('language', configuration['language']) @@ -784,8 +782,7 @@ def __init__(self, *args, **kwargs): class OneapiCompiler(IntelCompiler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init_finalize__(self, **kwargs): platform = kwargs.pop('platform', configuration['platform']) language = kwargs.pop('language', configuration['language']) @@ -841,38 +838,37 @@ def __new__(cls, *args, **kwargs): platform = kwargs.pop('platform', configuration['platform']) language = kwargs.pop('language', configuration['language']) - if any(i in environ for i in ['CC', 'CXX', 'CFLAGS', 'LDFLAGS']): - obj = super().__new__(cls) - obj.__init__(*args, **kwargs) - return obj - elif platform is M1: - return ClangCompiler(*args, **kwargs) + if platform is M1: + _base = ClangCompiler elif platform is INTELGPUX: - return OneapiCompiler(*args, **kwargs) + _base = OneapiCompiler elif platform is NVIDIAX: if language == 'cuda': - return CudaCompiler(*args, **kwargs) + _base = CudaCompiler else: - return NvidiaCompiler(*args, **kwargs) + _base = NvidiaCompiler elif platform is AMDGPUX: if language == 'hip': - return HipCompiler(*args, **kwargs) + _base = HipCompiler else: - return AOMPCompiler(*args, **kwargs) + _base = AOMPCompiler else: - return GNUCompiler(*args, **kwargs) - - def __init__(self, *args, **kwargs): - super(CustomCompiler, self).__init__(*args, **kwargs) - - default = '-O3 -g -march=native -fPIC -Wall -std=c99' - self.cflags = environ.get('CFLAGS', default).split(' ') - self.ldflags = environ.get('LDFLAGS', '-shared').split(' ') - - language = kwargs.pop('language', configuration['language']) - - if language == 'openmp': - self.ldflags += environ.get('OMP_LDFLAGS', '-fopenmp').split(' ') + _base = GNUCompiler + + obj = super().__new__(cls) + # Keep base to initialize accordingly + obj._base = _base + + return obj + + def __init_finalize__(self, **kwargs): + self._base.__init_finalize__(self, **kwargs) + # Update cflags + extrac = environ.get('CFLAGS', '').split(' ') + self.cflags = filter_ordered(self.cflags + extrac) + # Update ldflags + extrald = environ.get('LDFLAGS', '').split(' ') + self.ldflags = filter_ordered(self.ldflags + extrald) def __lookup_cmds__(self): self.CC = environ.get('CC', 'gcc')