diff --git a/NN_module/Makefile b/NN_module/Makefile index 869895d..a2f7a11 100644 --- a/NN_module/Makefile +++ b/NN_module/Makefile @@ -1,27 +1,31 @@ -# compiler -FC = gfortran - -# compile flags +# Set compiler and flags - uncomment for ifort or gfortran as appropriate # Set free length and include netcdf-fortran -# Try using nf-config tool if available, otherwise manually add path as per examples below -FCFLAGS = -g -ffree-line-length-none $(shell nf-config --fflags) -# CSD3: -# FCFLAGS = -g -ffree-line-length-none -I/usr/local/software/spack/spack-views/rocky8-icelake-20220710/netcdf-fortran-4.5.4/gcc-11.3.0/intel-oneapi-mpi-2021.6.0/mscqvjsc7bwypshmhnqfc2u3zxnims3r/include -# jwa34 local: -# FCFLAGS = -g -ffree-line-length-none -I/opt/netcdf-fortran/include +# Try using nf-config tool if available, otherwise need to manually add path as per examples below + +# ifort configuration +FC = ifort +FCFLAGS = $(shell nf-config --fflags) + +# gfortran configuration +# FC = gfortran +# FCFLAGS = -g -ffree-line-length-none $(shell nf-config --fflags) # link flags # Link to netcdf and netcdf-fortran # Try using nf-config tool if available, otherwise manually add path as per examples below LDFLAGS = $(shell nf-config --flibs) + +# If flags need adding manually this can be done here # CSD3: +# FCFLAGS = -g -ffree-line-length-none -I/usr/local/software/spack/spack-views/rocky8-icelake-20220710/netcdf-fortran-4.5.4/gcc-11.3.0/intel-oneapi-mpi-2021.6.0/mscqvjsc7bwypshmhnqfc2u3zxnims3r/include # LDFLAGS = -L/usr/local/software/spack/spack-views/rhel8-icelake-20211027_2/netcdf-fortran-4.5.3/gcc-11.2.0/intel-oneapi-mpi-2021.4.0/g4qjb23rucofcg5uitt4jwrkgyf7gba7/lib -lnetcdff -lnetcdf -lm # jwa34 local: +# FCFLAGS = -g -ffree-line-length-none -I/opt/netcdf-fortran/include # LDFLAGS = -L/opt/netcdf-c/lib -lnetcdf -L/opt/netcdf-fortran/lib -lnetcdff -lm PROGRAM = test -SRC = nn_cf_net.f90 nn_convection_flux.f90 nn_interface_SAM.f90 test.f90 +SRC = precision.f90 test_utils.f90 nn_cf_net.f90 nn_convection_flux.f90 nn_interface_SAM.f90 test.f90 OBJECTS = $(SRC:.f90=.o) diff --git a/NN_module/README.md b/NN_module/README.md index a6bca0d..c900cf0 100644 --- a/NN_module/README.md +++ b/NN_module/README.md @@ -14,3 +14,11 @@ Files: - `test.f90` - simple smoke tests for parameterisation routines - `NN_weights_YOG_convection.nc` - NetCDF file with weights for the neural net - Makefile - Makefile to compile these files + +## Running tests + +To run the tests (requires `ifort` or `gfortran` and `netcdf`): +Edit `Makefile` to select the appropriate compiler (`ifort` or `gfortran`) and then build with: +```bash +make test +``` diff --git a/NN_module/nn_convection_flux.f90 b/NN_module/nn_convection_flux.f90 index 72e3558..d23e72b 100644 --- a/NN_module/nn_convection_flux.f90 +++ b/NN_module/nn_convection_flux.f90 @@ -234,12 +234,12 @@ subroutine nn_convection_flux(tabs_i, q_i, y_in, & !! Vector of input features for the NN real(4), dimension(n_outputs) :: outputs !! vector of output features from the NN + real, dimension(nrf) :: t_flux_adv, q_flux_adv, q_tend_auto, & + q_sed_flux nx = size(tabs_i, 1) ny = size(tabs_i, 2) nzm = size(tabs_i, 3) ! NN outputs - real, dimension(nrf) :: t_flux_adv, q_flux_adv, q_tend_auto, & - q_sed_flux ! Output variable t_rad_rest_tend is also an output from the NN (defined above) diff --git a/NN_module/nn_ones.txt b/NN_module/nn_ones.txt new file mode 100644 index 0000000..ff18f75 --- /dev/null +++ b/NN_module/nn_ones.txt @@ -0,0 +1,148 @@ +-1.548654079437 +-1.019196391106 +-0.797004640102 +-0.510117173195 +-0.385661870241 +-0.495599985123 +-1.240013718605 +-1.892510175705 +-0.581315040588 +-0.359731495380 +-0.586395740509 +-0.714409291744 +-0.583598196507 +-0.282303869724 +-0.258652120829 +-0.378690242767 +-0.432278931141 +-0.338912516832 +-0.186107128859 +-0.017364665866 +0.139586955309 +0.260299444199 +0.331388920546 +0.386305928230 +0.428786695004 +0.434366226196 +0.435577899218 +0.455284386873 +0.526996791363 +0.610138535500 +-0.057859659195 +-0.219089031219 +-0.206303104758 +-0.094517558813 +0.067444443703 +0.094589412212 +0.043635398149 +0.371261000633 +0.423852056265 +0.272715002298 +0.252417802811 +0.259922802448 +0.237769454718 +0.239017575979 +0.227276831865 +0.199871718884 +0.176855623722 +0.182130187750 +0.181883573532 +0.180479094386 +0.180221661925 +0.184932529926 +0.185359880328 +0.193712815642 +0.204352051020 +0.205779701471 +0.209363222122 +0.215554103255 +0.213893324137 +-1.202421665192 +-1.255304813385 +-1.070469856262 +-0.605962812901 +-0.148154020309 +0.248896420002 +-0.074281871319 +-1.016677975655 +-1.852268099785 +-1.268948078156 +-0.997557282448 +-1.173674106598 +-1.406493425369 +-1.447221517563 +-1.352780103683 +-1.280928015709 +-1.267250061035 +-1.276774883270 +-1.280700206757 +-1.271878838539 +-1.266510248184 +-1.263604521751 +-1.259351134300 +-1.250584840775 +-1.236418008804 +-1.221935629845 +-1.209158539772 +-1.200800895691 +-1.197698116302 +0.547356724739 +0.541226983070 +0.486596643925 +0.325039625168 +0.369046747684 +0.247089684010 +-0.085400015116 +0.015755124390 +0.722970366478 +0.675642728806 +0.442468702793 +0.440599977970 +0.478552460670 +0.464282572269 +0.445607423782 +0.442082822323 +0.430344402790 +0.413671255112 +0.380288362503 +0.367297232151 +0.360955625772 +0.362972438335 +0.372948408127 +0.382159471512 +0.387978047132 +0.389867186546 +0.389176517725 +0.387510210276 +0.385305345058 +0.383541584015 +0.404772669077 +0.397490769625 +0.385760128498 +0.386608541012 +0.422613054514 +0.463099241257 +0.455268323421 +0.308662593365 +0.524466514587 +0.566737413406 +0.454918086529 +0.346627175808 +0.443202495575 +0.478479027748 +0.382643520832 +0.354028671980 +0.385809838772 +0.406863689423 +0.439339756966 +0.465253412724 +0.485793352127 +0.503615021706 +0.509339690208 +0.497745215893 +0.477612733841 +0.450194656849 +0.430692702532 +0.420731335878 +0.415764302015 +0.414340615273 diff --git a/NN_module/param_test.txt b/NN_module/param_test.txt new file mode 100644 index 0000000..7f9af8b --- /dev/null +++ b/NN_module/param_test.txt @@ -0,0 +1,31 @@ + -0.15258199E+03 -0.59331436E-01 -0.27831873E+03 0.11130974E+00 0.13943182E+02 -0.49222764E-02 -0.12378716E+03 + -0.14338339E+03 -0.59102684E-01 -0.77372841E+02 0.30944204E-01 -0.30485462E+02 0.10762096E-01 -0.99609421E+02 + -0.12325390E+03 -0.94135292E-01 0.97592537E+02 -0.39030794E-01 -0.47757668E+02 0.16859598E-01 -0.81354347E+02 + -0.10417913E+03 -0.18035039E+00 0.26003372E+03 -0.10399691E+00 -0.83319756E+02 0.29413866E-01 -0.76795265E+02 + -0.18704465E+03 -0.10708020E+00 0.45196835E+03 -0.18075854E+00 -0.30860632E+02 0.10894541E-01 -0.74161324E+02 + -0.13310071E+03 0.00000000E+00 0.70741974E+03 -0.28292280E+00 0.20239415E+03 -0.71449973E-01 -0.85394409E+02 + 0.61529816E+02 0.00000000E+00 0.11190435E+04 -0.44754606E+00 0.10683840E+02 -0.37716506E-02 -0.10791988E+03 + 0.14473958E+01 0.00000000E+00 0.12501992E+04 -0.50000000E+00 -0.10559273E+03 0.37276756E-01 -0.13053223E+03 + 0.51176235E+02 0.00000000E+00 0.12501992E+04 -0.50000000E+00 0.00000000E+00 0.00000000E+00 -0.16328114E+03 + 0.93860085E+02 0.00000000E+00 0.12501992E+04 -0.50000000E+00 0.00000000E+00 0.00000000E+00 -0.13449281E+03 + 0.14911021E+03 0.00000000E+00 0.12501992E+04 -0.50000000E+00 0.00000000E+00 0.00000000E+00 -0.87997696E+02 + 0.14079984E+03 0.00000000E+00 0.12501992E+04 -0.50000000E+00 0.00000000E+00 0.00000000E+00 -0.31148405E+02 + 0.21945201E+03 0.49292561E-01 0.13734503E+04 -0.54929256E+00 0.00000000E+00 0.00000000E+00 0.40992861E+01 + 0.27411746E+03 0.46004165E-01 0.13652280E+04 -0.54600418E+00 0.00000000E+00 0.00000000E+00 0.42614012E+01 + 0.33452908E+03 -0.63766092E-01 0.10907585E+04 -0.43623391E+00 0.00000000E+00 0.00000000E+00 -0.13481074E+02 + 0.25489995E+03 -0.31530634E-01 0.11713601E+04 -0.46846938E+00 -0.61843784E+02 0.21832334E-01 -0.70356183E+01 + 0.17641861E+02 0.00000000E+00 0.73373505E+03 -0.29344726E+00 -0.15497134E+03 0.54708585E-01 0.26573279E+01 + -0.17049492E+03 0.36195751E-01 0.42922144E+03 -0.17166121E+00 -0.68361343E+02 0.24133189E-01 0.79629316E+01 + -0.11305133E+03 0.10003121E+00 0.31256876E+03 -0.12500758E+00 -0.18950542E+03 0.66899940E-01 0.96100616E+01 + -0.48079750E+02 0.66902570E-01 0.19118565E+03 -0.76462075E-01 -0.12471986E+03 0.44029091E-01 0.62052822E+01 + -0.78505005E+02 0.65000609E-01 0.12925836E+03 -0.51695105E-01 0.12382239E+03 -0.43712262E-01 0.13433976E+01 + -0.40449608E+02 0.25606329E-01 0.90200867E+02 -0.36074597E-01 0.30558572E+03 -0.10787906E+00 -0.33138554E+01 + 0.34186165E+02 0.11162106E-01 0.74156761E+02 -0.29657977E-01 0.23453377E+03 -0.82796022E-01 -0.56750917E+01 + 0.84712524E+02 -0.32118667E-01 0.69470520E+02 -0.27783781E-01 0.74060410E+02 -0.26145093E-01 -0.32166588E+01 + 0.78100677E+02 -0.38102403E-01 0.70230988E+02 -0.28087918E-01 -0.15935593E+02 0.56256452E-02 -0.20918531E+01 + 0.48418488E+02 -0.28241172E-01 0.71406212E+02 -0.28557934E-01 -0.37185413E+02 0.13127340E-01 -0.19065496E+01 + -0.38718097E+01 -0.66498751E-02 0.71594452E+02 -0.28633216E-01 -0.38420971E+02 0.13563521E-01 0.37559479E-01 + -0.19452473E+02 0.16073607E-01 0.67622612E+02 -0.27044734E-01 -0.26526558E+02 0.93645090E-02 0.30559354E+01 + 0.10720540E+02 0.28101396E-01 0.60303963E+02 -0.24117742E-01 -0.22085796E+02 0.77968142E-02 0.63656316E+01 + -0.53725366E+03 0.25603855E+00 0.51792259E+02 -0.20713603E-01 0.15537941E+01 -0.54852647E-03 0.81452036E+01 + -3.0075550E-02 diff --git a/NN_module/precision.f90 b/NN_module/precision.f90 new file mode 100644 index 0000000..ff9a5e0 --- /dev/null +++ b/NN_module/precision.f90 @@ -0,0 +1,14 @@ +module precision + + use, intrinsic :: iso_fortran_env, only : sp=>real32, dp=>real64 + ! Imports primitives used to interface with C + use, intrinsic :: iso_c_binding, only: c_sp=>c_float, c_dp=>c_double + + implicit none + + public + integer, parameter :: c_wp = c_sp + integer, parameter :: wps = sp + integer, parameter :: wpc = dp + +end module precision diff --git a/NN_module/test.f90 b/NN_module/test.f90 index 3d90f17..19d01ec 100644 --- a/NN_module/test.f90 +++ b/NN_module/test.f90 @@ -1,146 +1,214 @@ -program run_tests - use nn_cf_net_mod, only: relu, nn_cf_net_init, nn_cf_net_finalize - - use nn_convection_flux_mod, only: nn_convection_flux, nn_convection_flux_init, nn_convection_flux_finalize - - implicit none - - real(4), dimension(4) :: test_array = (/ -1.0, 0.0, 0.5, 1.0 /) - integer :: nin, nout - - - - - ! Parameters from SAM that are used here - ! From domain.f90 - integer, parameter :: YES3D = 1 - !! Domain dimensionality: 1 - 3D, 0 - 2D - integer, parameter :: nx_gl = 1 - !! Number of grid points in X - Yani changed to 36 from 32 - integer, parameter :: ny_gl = 1 - !! Number of grid points in Y - integer, parameter :: nz_gl = 48 - !! Number of pressure (scalar) levels - integer, parameter :: nsubdomains_x = 1 - !! No of subdomains in x - integer, parameter :: nsubdomains_y = 1 - !! No of subdomains in y - - ! From grid.f90 - integer, parameter :: nx = nx_gl/nsubdomains_x - !! Number of x points in a subdomain - integer, parameter :: ny = ny_gl/nsubdomains_y - !! Number of y points in a subdomain - integer, parameter :: nz = nz_gl+1 - !! Number of z points in a subdomain - ! Store useful variations on these values - integer, parameter :: nzm = nz-1 - integer, parameter :: nxp3 = nx + 3 - integer, parameter :: nyp3 = ny + 3 * YES3D - integer, parameter :: dimx1_s = -2 - integer, parameter :: dimx2_s = nxp3 - integer, parameter :: dimy1_s = 1-3*YES3D - integer, parameter :: dimy2_s = nyp3 - - integer :: nrf - - != unit J :: t - real t(dimx1_s:dimx2_s, dimy1_s:dimy2_s, nzm) - !! moist static energy - real q(dimx1_s:dimx2_s, dimy1_s:dimy2_s, nzm) - !! total water - ! fluxes at the top and bottom of the domain: - real precsfc(nx,ny) - !! surface precip. rate - ! Horizontally varying stuff (as a function of xy) - real prec_xy(nx,ny) - !! surface precipitation rate - ! reference vertical profiles: - - != unit (kg / m**3) :: rho - real rho(nzm) - !! air density at pressure levels - - real :: tabs(nx, ny, nzm) - !! absolute temperature - - ! Input Variables - - ! Fields from beginning of time step used as NN inputs - real tabs_i(nx,ny,nzm) - !! Temperature - real q_i(nx,ny,nzm) - !! Non-precipitating water mixing ratio - real :: y_in(nx,ny) - !! Distance of column from equator (proxy for insolation and sfc albedo) - real adz(nzm) - !! ratio of the grid spacing to dz for pressure levels - != unit s :: dtn - real dtn - !! current dynamical timestep (can be smaller than dt) - real dz - !! current dynamical timestep (can be smaller than dt) - - - real, dimension(nx,ny,30) :: t_rad_rest_tend, & - t_delta_adv, q_delta_adv, & - t_delta_auto, q_delta_auto, & - t_delta_sed, q_delta_sed - real, dimension(nx,ny) :: prec_sed - - ! Test NN routines - - call relu(test_array) - - write (*,*) test_array - - call nn_cf_net_init("./NN_weights_YOG_convection.nc", nin, nout, 30) - call nn_cf_net_finalize() - - - ! Test convection flux routines - - - t = 0. - q = 0.4 - precsfc = 0. - prec_xy = 0. - rho = 1. - tabs = 1. - tabs_i = 287.15 - q_i = 0.2 - adz = 1. - y_in = 1. - dz = 1. - dtn = 1. - - nrf = 30 - - call nn_convection_flux_init("./NN_weights_YOG_convection.nc") - - call nn_convection_flux(tabs_i, q_i, y_in, & - tabs, & - t, q, & - rho, adz, dz, dtn, & - t_rad_rest_tend, & - t_delta_adv, q_delta_adv, & - t_delta_auto, q_delta_auto, & - t_delta_sed, q_delta_sed, prec_sed) - - q(:,:,1:nrf) = q(:,:,1:nrf) + q_delta_adv(:,:,:) & - + q_delta_auto(:,:,:) & - + q_delta_sed(:,:,:) - t(:,:,1:nrf) = t(:,:,1:nrf) + t_delta_adv(:,:,:) & - + t_delta_auto(:,:,:) & - + t_delta_sed(:,:,:) & - + t_rad_rest_tend(:,:,:)*dtn - - call nn_convection_flux_finalize() +module tests + !! Module containing individual tests for the CAM ML code + + !-------------------------------------------------------------------------- + ! Libraries to use + use netcdf + use nn_cf_net_mod, only: relu, net_forward, nn_cf_net_init, nn_cf_net_finalize + use nn_convection_flux_mod, only: nn_convection_flux, nn_convection_flux_init, nn_convection_flux_finalize + use test_utils, only: assert_array_equal + + implicit none + + character(len=15) :: pass = char(27)//'[32m'//'PASSED'//char(27)//'[0m' + character(len=15) :: fail = char(27)//'[31m'//'FAILED'//char(27)//'[0m' + integer, parameter :: nrf = 30 + integer, parameter :: n_nn_out = 148 + real(4), dimension(n_nn_out) :: nn_out_ones + + contains + + subroutine test_relu(test_name) + !! Test relu function is working as expected + + character(len=*), intent(in) :: test_name + real(4), dimension(4) :: test_array = (/ -1.0, 0.0, 0.5, 1.0 /) + real(4), dimension(4) :: res_array = (/ 0.0, 0.0, 0.5, 1.0 /) + + call relu(test_array) + + call assert_array_equal(res_array, test_array, test_name) + + end subroutine test_relu + + subroutine test_nn_cf_init(test_name) + !! Test NN initialisation is working as expected + !! Checks that nin and nout are read in as expected + !! Checking deeper requires interaction with nn_cf_net_mod module vars + + character(len=*), intent(in) :: test_name + character(len=1024) :: nn_filename + integer :: nin, nout + + nn_filename = "./NN_weights_YOG_convection.nc" + call nn_cf_net_init(nn_filename, nin, nout, nrf) + + call nn_cf_net_finalize() + + if (nin == 61) then + write(*, '(A, " :: [", A, " - nin]")') pass, trim(test_name) + if (nout == n_nn_out) then + write(*, '(A, " :: [", A, " - nout]")') pass, trim(test_name) + else + write(*, '(A, " :: [", A, "] with nout = ", I3)') fail, trim(test_name), nout + end if + else + write(*, '(A, " :: [", A, "] with nin = ", I3)') fail, trim(test_name), nin + end if - write (*,*) t(-2:2, 0, 1) - write (*,*) t(0, -2:2, 1) - write (*,*) t(-2, -2, 1:48) - ! write (*,*) t(-1, -1, 1:48) + end subroutine test_nn_cf_init + + subroutine test_nn(test_name) + !! Test NN is producing the expected results. + + integer :: i, nin, nout + character(len=*), intent(in) :: test_name + character(len=1024) :: nn_filename + real(4) :: nn_in(61) + real(4) :: nn_out(n_nn_out) + + nn_in = 1.0 + + nn_filename = "./NN_weights_YOG_convection.nc" + call nn_cf_net_init(nn_filename, nin, nout, nrf) + + call net_forward(nn_in, nn_out) + + call nn_cf_net_finalize() + + call assert_array_equal(nn_out, nn_out_ones, test_name, 1.0e-4) + + end subroutine test_nn + + subroutine test_param(test_name) + !! Test Parameterisation is producing the same results as it initially did. + !! Run for a single column with physically plausible parameters. + + integer :: i, io, stat + character(len=*), intent(in) :: test_name + character(len=1024) :: nn_filename + character(len=512) :: msg + + real :: tabs_i(1, 1, 48) = 293.15 + real :: q_i(1, 1, 48) = 0.5 + real :: y_in(1, 48) = 0.0 + real :: tabs(1, 1, 48) = 293.15 + real :: t_0(1, 1, 48) = 1.0e4 + real :: q_0(1, 1, 48) = 0.5 + real :: rho(48) = 1.2 + real :: adz(48) = 1.0 + real :: dz = 100.0 + real :: dtn = 2.0 + real, dimension(1,1,nrf) :: t_delta_adv, q_delta_adv, & + t_delta_auto, q_delta_auto, & + t_delta_sed, q_delta_sed + real :: t_rad_rest_tend(1,1,nrf) + real :: prec_sed(1,1) + + real, dimension(1,1,nrf) :: t_delta_adv_dat, q_delta_adv_dat, & + t_delta_auto_dat, q_delta_auto_dat, & + t_delta_sed_dat, q_delta_sed_dat + real :: t_rad_rest_tend_dat(1,1,nrf) + real :: prec_sed_dat(1,1) + + nn_filename = "./NN_weights_YOG_convection.nc" + call nn_convection_flux_init(nn_filename) + + call nn_convection_flux(tabs_i, q_i, y_in, & + tabs, & + t_0, q_0, & + rho, adz, dz, dtn, & + t_rad_rest_tend, & + t_delta_adv, q_delta_adv, & + t_delta_auto, q_delta_auto, & + t_delta_sed, q_delta_sed, prec_sed) + + call nn_convection_flux_finalize() + + nn_filename = "param_test.txt" + + ! Writing data out to file from original code runniing + ! open(newunit=io, file=trim(nn_filename), status="replace", action="write", & + ! iostat=stat, iomsg=msg) + ! if (stat /= 0) then + ! print *, trim(msg) + ! stop + ! end if + ! do i = 1,nrf + ! write(io, '(7E18.8)') t_delta_adv(1,1,i), q_delta_adv(1,1,i), & + ! t_delta_auto(1,1,i), q_delta_auto(1,1,i), & + ! t_delta_sed(1,1,i), q_delta_sed(1,1,i), t_rad_rest_tend(1,1,i) + ! enddo + ! write(io, '(E18.8)') prec_sed(1,1) + ! close(io) + + open(newunit=io, file=trim(nn_filename), status="old", action="read", & + iostat=stat, iomsg=msg) + if (stat /= 0) then + print *, trim(msg) + stop + end if + do i = 1,nrf + read(io, '(7E18.8)') t_delta_adv_dat(1,1,i), q_delta_adv_dat(1,1,i), & + t_delta_auto_dat(1,1,i), q_delta_auto_dat(1,1,i), & + t_delta_sed_dat(1,1,i), q_delta_sed_dat(1,1,i), & + t_rad_rest_tend_dat(1,1,i) + enddo + read(io, '(E18.8)') prec_sed_dat(1,1) + close(io) + + call assert_array_equal(t_delta_adv, t_delta_adv_dat, test_name//" t adv", 1.0e-6) + call assert_array_equal(q_delta_adv, q_delta_adv_dat, test_name//" q adv", 1.0e-6) + call assert_array_equal(t_delta_auto, t_delta_auto_dat, test_name//" t auto", 1.0e-6) + call assert_array_equal(q_delta_auto, q_delta_auto_dat, test_name//" q auto", 1.0e-6) + call assert_array_equal(t_delta_sed, t_delta_sed_dat, test_name//" t sed", 1.0e-6) + call assert_array_equal(q_delta_sed, q_delta_sed_dat, test_name//" q sed", 1.0e-6) + call assert_array_equal(t_rad_rest_tend, t_rad_rest_tend_dat, test_name//" t rad", 1.0e-6) + call assert_array_equal(prec_sed, prec_sed_dat, test_name//" prec", 1.0e-6) + + end subroutine test_param + + subroutine load_nn_out_ones(nn_ones_file) + !! Load the result of running the NN with ones + + integer :: io, stat, i + character(len=512) :: msg + character(len=*) :: nn_ones_file + + open(newunit=io, file=trim(nn_ones_file), status="old", action="read", & + iostat=stat, iomsg=msg) + if (stat /= 0) then + print *, trim(msg) + stop + end if + do i = 1,n_nn_out + read(io, *) nn_out_ones(i) + enddo + close(io) + + end subroutine load_nn_out_ones + +end module tests + + + +program run_tests + + use tests + + implicit none + + character(len=1024) :: nn_filename + + ! Test NN routines + + call test_relu("test_relu") + call test_nn_cf_init("test_nn_cf_init") + + call load_nn_out_ones("nn_ones.txt") + call test_nn("Test NN ones") + + call test_param("Test param simple") end program run_tests diff --git a/NN_module/test_utils.f90 b/NN_module/test_utils.f90 new file mode 100644 index 0000000..a8c92b5 --- /dev/null +++ b/NN_module/test_utils.f90 @@ -0,0 +1,160 @@ +module test_utils + + use :: precision, only: sp, dp + + implicit none + + character(len=15) :: pass = char(27)//'[32m'//'PASSED'//char(27)//'[0m' + character(len=15) :: fail = char(27)//'[31m'//'FAILED'//char(27)//'[0m' + + interface assert_array_equal + module procedure & + assert_array_equal_1d_sp, assert_array_equal_2d_sp, assert_array_equal_3d_sp, & + assert_array_equal_1d_dp, assert_array_equal_2d_dp, assert_array_equal_3d_dp + end interface + + interface print_assert + module procedure print_assert_sp, print_assert_dp + end interface + + contains + + subroutine print_assert_sp(test_name, is_close, relative_error) + + character(len=*), intent(in) :: test_name + logical, intent(in) :: is_close + real(sp), intent(in) :: relative_error + + if (is_close) then + write(*, '(A, " :: [", A, "] maximum relative error = ", E11.4)') pass, trim(test_name), relative_error + else + write(*, '(A, " :: [", A, "] maximum relative error = ", E11.4)') fail, trim(test_name), relative_error + end if + + end subroutine print_assert_sp + + subroutine print_assert_dp(test_name, is_close, relative_error) + + character(len=*), intent(in) :: test_name + logical, intent(in) :: is_close + real(dp), intent(in) :: relative_error + + if (is_close) then + write(*, '(A, " :: [", A, "] maximum relative error = ", E11.4)') pass, trim(test_name), relative_error + else + write(*, '(A, " :: [", A, "] maximum relative error = ", E11.4)') fail, trim(test_name), relative_error + end if + + end subroutine print_assert_dp + + subroutine assert_array_equal_1d_sp(a, b, test_name, rtol_opt) + + character(len=*), intent(in) :: test_name + real(sp), intent(in), dimension(:) :: a, b + real(sp), intent(in), optional :: rtol_opt + real(sp) :: relative_error, rtol + + if (.not. present(rtol_opt)) then + rtol = 1.0e-5 + else + rtol = rtol_opt + end if + + relative_error = maxval(abs(a/b - 1.0)) + + call print_assert(test_name, (rtol > relative_error), relative_error) + + end subroutine assert_array_equal_1d_sp + + subroutine assert_array_equal_2d_sp(a, b, test_name, rtol_opt) + + character(len=*), intent(in) :: test_name + real(sp), intent(in), dimension(:,:) :: a, b + real(sp), intent(in), optional :: rtol_opt + real(sp) :: relative_error, rtol + + if (.not. present(rtol_opt)) then + rtol = 1.0e-5 + else + rtol = rtol_opt + end if + + relative_error = maxval(abs(a/b - 1.0)) + call print_assert(test_name, (rtol > relative_error), relative_error) + + end subroutine assert_array_equal_2d_sp + + subroutine assert_array_equal_3d_sp(a, b, test_name, rtol_opt) + + character(len=*), intent(in) :: test_name + real(sp), intent(in), dimension(:,:,:) :: a, b + real(sp), intent(in), optional :: rtol_opt + real(sp) :: relative_error, rtol + + if (.not. present(rtol_opt)) then + rtol = 1.0e-5 + else + rtol = rtol_opt + end if + + relative_error = maxval(abs(a/b - 1.0)) + call print_assert(test_name, (rtol > relative_error), relative_error) + + end subroutine assert_array_equal_3d_sp + + subroutine assert_array_equal_1d_dp(a, b, test_name, rtol_opt) + + character(len=*), intent(in) :: test_name + real(dp), intent(in), dimension(:) :: a, b + real(dp), intent(in), optional :: rtol_opt + real(dp) :: relative_error, rtol + + if (.not. present(rtol_opt)) then + rtol = 1.0e-5 + else + rtol = rtol_opt + end if + + relative_error = maxval(abs(a/b - 1.0)) + + call print_assert(test_name, (rtol > relative_error), relative_error) + + end subroutine assert_array_equal_1d_dp + + subroutine assert_array_equal_2d_dp(a, b, test_name, rtol_opt) + + character(len=*), intent(in) :: test_name + real(dp), intent(in), dimension(:,:) :: a, b + real(dp), intent(in), optional :: rtol_opt + real(dp) :: relative_error, rtol + + if (.not. present(rtol_opt)) then + rtol = 1.0e-5 + else + rtol = rtol_opt + end if + + relative_error = maxval(abs(a/b - 1.0)) + call print_assert(test_name, (rtol > relative_error), relative_error) + + end subroutine assert_array_equal_2d_dp + + subroutine assert_array_equal_3d_dp(a, b, test_name, rtol_opt) + + character(len=*), intent(in) :: test_name + real(dp), intent(in), dimension(:,:,:) :: a, b + real(dp), intent(in), optional :: rtol_opt + real(dp) :: relative_error, rtol + + if (.not. present(rtol_opt)) then + rtol = 1.0e-5 + else + rtol = rtol_opt + end if + + relative_error = maxval(abs(a/b - 1.0)) + call print_assert(test_name, (rtol > relative_error), relative_error) + + end subroutine assert_array_equal_3d_dp + +end module test_utils