Skip to content

Commit cf5f030

Browse files
committed
add error handling in a better way
1 parent e709f83 commit cf5f030

File tree

2 files changed

+87
-12
lines changed

2 files changed

+87
-12
lines changed

src/stdlib_intrinsics.fypp

+21-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module stdlib_intrinsics
88
!!Alternative implementations of some Fortran intrinsic functions offering either faster and/or more accurate evaluation.
99
!! ([Specification](../page/specs/stdlib_intrinsics.html))
1010
use stdlib_kinds
11+
use stdlib_linalg_state, only: linalg_state_type
1112
implicit none
1213
private
1314

@@ -162,14 +163,33 @@ module stdlib_intrinsics
162163
!!
163164
!! Note: The matrices must be of compatible shapes to be multiplied
164165
#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
165-
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
166+
pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
166167
${t}$, intent(in) :: m1(:,:), m2(:,:)
167168
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
168169
${t}$, allocatable :: r(:,:)
170+
end function stdlib_matmul_pure_${s}$
171+
172+
module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
173+
${t}$, intent(in) :: m1(:,:), m2(:,:)
174+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
175+
type(linalg_state_type), intent(out) :: err
176+
${t}$, allocatable :: r(:,:)
169177
end function stdlib_matmul_${s}$
170178
#:endfor
171179
end interface stdlib_matmul
172180
public :: stdlib_matmul
181+
182+
! internal interface
183+
interface stdlib_matmul_sub
184+
#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
185+
pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
186+
${t}$, intent(out), allocatable :: res(:,:)
187+
${t}$, intent(in) :: m1(:,:), m2(:,:)
188+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
189+
type(linalg_state_type), intent(out), optional :: err
190+
end subroutine stdlib_matmul_sub_${s}$
191+
#:endfor
192+
end interface stdlib_matmul_sub
173193

174194
contains
175195

src/stdlib_intrinsics_matmul.fypp

+66-11
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
77
use stdlib_linalg_blas, only: gemm
8+
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR
89
use stdlib_constants
910
implicit none
1011

12+
character(len=*), parameter :: this = "stdlib_matmul"
13+
1114
contains
1215

1316
! Algorithm for the optimal parenthesization of matrices
@@ -122,41 +125,76 @@ contains
122125

123126
end function matmul_chain_mult_${s}$_4
124127

125-
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
128+
pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
129+
${t}$, intent(out), allocatable :: res(:,:)
126130
${t}$, intent(in) :: m1(:,:), m2(:,:)
127131
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
128-
${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
132+
type(linalg_state_type), intent(out), optional :: err
133+
${t}$, allocatable :: temp(:,:), temp1(:,:)
129134
integer :: p(6), num_present, m, n, k
130135
integer, allocatable :: s(:,:)
131136

137+
type(linalg_state_type) :: err0
138+
132139
p(1) = size(m1, 1)
133140
p(2) = size(m2, 1)
134141
p(3) = size(m2, 2)
135142

143+
if (size(m1, 2) /= p(2)) then
144+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1, m2 not of compatible sizes')
145+
call linalg_error_handling(err0, err)
146+
allocate(res(0, 0))
147+
return
148+
end if
149+
136150
num_present = 2
137151
if (present(m3)) then
152+
153+
if (size(m3, 1) /= p(3)) then
154+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2, m3 not of compatible sizes')
155+
call linalg_error_handling(err0, err)
156+
allocate(res(0, 0))
157+
return
158+
end if
159+
138160
p(3) = size(m3, 1)
139161
p(4) = size(m3, 2)
140162
num_present = num_present + 1
141163
end if
142164
if (present(m4)) then
165+
166+
if (size(m4, 1) /= p(4)) then
167+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3, m4 not of compatible sizes')
168+
call linalg_error_handling(err0, err)
169+
allocate(res(0, 0))
170+
return
171+
end if
172+
143173
p(4) = size(m4, 1)
144174
p(5) = size(m4, 2)
145175
num_present = num_present + 1
146176
end if
147177
if (present(m5)) then
178+
179+
if (size(m5, 1) /= p(5)) then
180+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4, m5 not of compatible sizes')
181+
call linalg_error_handling(err0, err)
182+
allocate(res(0, 0))
183+
return
184+
end if
185+
148186
p(5) = size(m5, 1)
149187
p(6) = size(m5, 2)
150188
num_present = num_present + 1
151189
end if
152190

153-
allocate(r(p(1), p(num_present + 1)))
191+
allocate(res(p(1), p(num_present + 1)))
154192

155193
if (num_present == 2) then
156194
m = p(1)
157195
n = p(3)
158196
k = p(2)
159-
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m)
197+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res, m)
160198
return
161199
end if
162200

@@ -166,10 +204,10 @@ contains
166204
s = matmul_chain_order(p(1: num_present + 1))
167205

168206
if (num_present == 3) then
169-
r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
207+
res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
170208
return
171209
else if (num_present == 4) then
172-
r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
210+
res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
173211
return
174212
end if
175213

@@ -182,7 +220,7 @@ contains
182220
m = p(1)
183221
n = p(6)
184222
k = p(2)
185-
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
223+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res, m)
186224
case (2)
187225
! (m1*m2)*(m3*m4*m5)
188226
m = p(1)
@@ -195,7 +233,7 @@ contains
195233

196234
k = n
197235
n = p(6)
198-
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
236+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m)
199237
case (3)
200238
! (m1*m2*m3)*(m4*m5)
201239
temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p)
@@ -208,18 +246,35 @@ contains
208246

209247
k = m
210248
m = p(1)
211-
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
249+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m)
212250
case (4)
213251
! (m1*m2*m3*m4)*m5
214252
temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
215253
m = p(1)
216254
n = p(6)
217255
k = p(5)
218-
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
256+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res, m)
219257
case default
220-
error stop "stdlib_matmul: error: unexpected s(i,j)"
258+
error stop "stdlib_matmul: internal error: unexpected s(i,j)"
221259
end select
222260

261+
end subroutine stdlib_matmul_sub_${s}$
262+
263+
pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
264+
${t}$, intent(in) :: m1(:,:), m2(:,:)
265+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
266+
${t}$, allocatable :: r(:,:)
267+
268+
call stdlib_matmul_sub(r, m1, m2, m3, m4, m5)
269+
end function stdlib_matmul_pure_${s}$
270+
271+
module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
272+
${t}$, intent(in) :: m1(:,:), m2(:,:)
273+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
274+
type(linalg_state_type), intent(out) :: err
275+
${t}$, allocatable :: r(:,:)
276+
277+
call stdlib_matmul_sub(r, m1, m2, m3, m4, m5, err=err)
223278
end function stdlib_matmul_${s}$
224279

225280
#:endfor

0 commit comments

Comments
 (0)