Test_CSR_SparseMatrix.pf Source File


This file depends on

Test_CSR_SparseMatrix.pf → CSR_SparseMatrix.F90 → KeywordEnforcer.F90 (Test_CSR_SparseMatrix.pf uses CSR_SparseMatrix.F90, which in turn uses KeywordEnforcer.F90.)

Source Code

module Test_CSR_SparseMatrix
   use mapl3g_CSR_SparseMatrix
   use funit
   use, intrinsic :: iso_fortran_env
   implicit none

   ! The fixture builder is an implementation detail of this test module;
   ! only the @test procedures need to be visible to the generated driver.
   private :: make_multi_column_dp

   ! NOTE: all expected values below are small integers, which are exactly
   ! representable in every real kind, so exact equality via equal_to is safe.

contains

   ! Single 2x3 single-precision matrix applied to x = [1, 1, 1]:
   !   [ 1.  1.  0. ]
   !   [ 0.  1.  0. ]
   ! Expected y = [2, 1].
   @test
   subroutine test_simple()
      integer, parameter :: M = 2, N = 3
      type(CSR_SparseMatrix_sp) :: mat
      real :: x(N), y(M)

      mat = CSR_SparseMatrix_sp(M, N, nnz=3)
      call add_row(mat, 1, 1, [1., 1.])
      call add_row(mat, 2, 2, [1.])

      x = 1.
      y = matmul(mat, x)

      @assert_that(y, is(equal_to([2., 1.])))

   end subroutine test_simple

   ! Batched product with an array of three 2x3 single-precision matrices.
   ! Judging from the expected values, y(k,:) is the product of mat(k)
   ! with x(k,:); here every x(k,:) = [1, 1, 1].
   !   mat(1): [ 1.  1.  0. ]    mat(2): [ 0.  1.  1. ]    mat(3): [ 1.  1.  1. ]
   !           [ 0.  1.  0. ]            [ 0.  0.  2. ]            [ 0.  1.  2. ]
   ! Expected rows: [2, 1], [2, 2], [3, 3].
   @test
   subroutine test_multi_column()
      integer, parameter :: M = 2, N = 3
      type(CSR_SparseMatrix_sp) :: mat(3)
      real :: x(3, N), y_found(3, M), y_expected(3, M)

      mat(1) = CSR_SparseMatrix_sp(M, N, nnz=3)
      call add_row(mat(1), 1, 1, [1., 1.])
      call add_row(mat(1), 2, 2, [1.])

      mat(2) = CSR_SparseMatrix_sp(M, N, nnz=3)
      call add_row(mat(2), 1, 2, [1., 1.])
      call add_row(mat(2), 2, 3, [2.])

      mat(3) = CSR_SparseMatrix_sp(M, N, nnz=5)
      call add_row(mat(3), 1, 1, [1., 1., 1.])
      call add_row(mat(3), 2, 2, [1., 2.])

      x = 1.
      y_found = matmul(mat, x)

      y_expected(1, :) = [2., 1.]
      y_expected(2, :) = [2., 2.]
      y_expected(3, :) = [3., 3.]

      @assert_that(y_found, is(equal_to(y_expected)))

   end subroutine test_multi_column

   ! Same batched product as test_multi_column, but with double-precision
   ! matrices and a REAL64 input/output vector.
   @test
   subroutine test_multi_column_real64()
      integer, parameter :: M = 2, N = 3
      type(CSR_SparseMatrix_dp) :: mat(3)
      real(REAL64) :: x(3, N), y_found(3, M), y_expected(3, M)

      call make_multi_column_dp(mat)

      x = 1._REAL64
      y_found = matmul(mat, x)

      y_expected(1, :) = [2._REAL64, 1._REAL64]
      y_expected(2, :) = [2._REAL64, 2._REAL64]
      y_expected(3, :) = [3._REAL64, 3._REAL64]

      @assert_that(y_found, is(equal_to(y_expected)))

   end subroutine test_multi_column_real64

   ! Mixed precision: double-precision matrices applied to a REAL32
   ! input, with the result stored in REAL32.
   @test
   subroutine test_multi_column_mixed_prec()
      integer, parameter :: M = 2, N = 3
      type(CSR_SparseMatrix_dp) :: mat(3)
      real(REAL32) :: x(3, N), y_found(3, M), y_expected(3, M)

      call make_multi_column_dp(mat)

      x = 1._REAL32
      y_found = matmul(mat, x)

      y_expected(1, :) = [2._REAL32, 1._REAL32]
      y_expected(2, :) = [2._REAL32, 2._REAL32]
      y_expected(3, :) = [3._REAL32, 3._REAL32]

      @assert_that(y_found, is(equal_to(y_expected)))

   end subroutine test_multi_column_mixed_prec

   ! Build the three 2x3 double-precision fixture matrices shared by the
   ! real64 and mixed-precision multi-column tests:
   !   mat(1): [ 1.  1.  0. ]    mat(2): [ 0.  1.  1. ]    mat(3): [ 1.  1.  1. ]
   !           [ 0.  1.  0. ]            [ 0.  0.  2. ]            [ 0.  1.  2. ]
   ! Judging from the matrices above, add_row(mat, row, first_col, values)
   ! stores a contiguous run of values starting at column first_col.
   ! nnz is the total nonzero count of each matrix.
   subroutine make_multi_column_dp(mat)
      type(CSR_SparseMatrix_dp), intent(out) :: mat(3)
      integer, parameter :: M = 2, N = 3

      mat(1) = CSR_SparseMatrix_dp(M, N, nnz=3)
      call add_row(mat(1), 1, 1, [1._REAL64, 1._REAL64])
      call add_row(mat(1), 2, 2, [1._REAL64])

      mat(2) = CSR_SparseMatrix_dp(M, N, nnz=3)
      call add_row(mat(2), 1, 2, [1._REAL64, 1._REAL64])
      call add_row(mat(2), 2, 3, [2._REAL64])

      mat(3) = CSR_SparseMatrix_dp(M, N, nnz=5)
      call add_row(mat(3), 1, 1, [1._REAL64, 1._REAL64, 1._REAL64])
      call add_row(mat(3), 2, 2, [1._REAL64, 2._REAL64])

   end subroutine make_multi_column_dp

end module Test_CSR_SparseMatrix