Skip to content

Commit 425d1d0

Browse files
committed
add error handling in case matrices are incompatible
1 parent e709f83 commit 425d1d0

File tree

2 files changed

+165
-5
lines changed

2 files changed

+165
-5
lines changed

Diff for: src/stdlib_intrinsics.fypp

+9-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module stdlib_intrinsics
88
!!Alternative implementations of some Fortran intrinsic functions offering either faster and/or more accurate evaluation.
99
!! ([Specification](../page/specs/stdlib_intrinsics.html))
1010
use stdlib_kinds
11+
use stdlib_linalg_state, only: linalg_state_type
1112
implicit none
1213
private
1314

@@ -162,10 +163,17 @@ module stdlib_intrinsics
162163
!!
163164
!! Note: The matrices must be of compatible shapes to be multiplied
164165
#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
165-
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
166+
pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
166167
${t}$, intent(in) :: m1(:,:), m2(:,:)
167168
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
168169
${t}$, allocatable :: r(:,:)
170+
end function stdlib_matmul_pure_${s}$
171+
172+
module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
173+
${t}$, intent(in) :: m1(:,:), m2(:,:)
174+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
175+
type(linalg_state_type), intent(out) :: err
176+
${t}$, allocatable :: r(:,:)
169177
end function stdlib_matmul_${s}$
170178
#:endfor
171179
end interface stdlib_matmul

Diff for: src/stdlib_intrinsics_matmul.fypp

+156-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
77
use stdlib_linalg_blas, only: gemm
88
use stdlib_constants
9+
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR
910
implicit none
1011

12+
character(len=*), parameter :: this = "stdlib_matmul"
13+
1114
contains
1215

1316
! Algorithm for the optimal parenthesization of matrices
@@ -71,7 +74,7 @@ contains
7174
k = p(start + 1)
7275
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
7376
else
74-
error stop "stdlib_matmul: error: unexpected s(i,j)"
77+
error stop "stdlib_matmul: internal error: unexpected s(i,j)"
7578
end if
7679

7780
end function matmul_chain_mult_${s}$_3
@@ -117,34 +120,64 @@ contains
117120
k = p(start + 3)
118121
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
119122
else
120-
error stop "stdlib_matmul: error: unexpected s(i,j)"
123+
error stop "stdlib_matmul: internal error: unexpected s(i,j)"
121124
end if
122125

123126
end function matmul_chain_mult_${s}$_4
124127

125-
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
128+
module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
126129
${t}$, intent(in) :: m1(:,:), m2(:,:)
127130
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
131+
type(linalg_state_type), intent(out) :: err
128132
${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
129133
integer :: p(6), num_present, m, n, k
130134
integer, allocatable :: s(:,:)
131135

136+
type(linalg_state_type) :: err0
137+
132138
p(1) = size(m1, 1)
133139
p(2) = size(m2, 1)
134140
p(3) = size(m2, 2)
135141

142+
if (size(m1, 2) /= p(2)) then
143+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1, m2 not of compatible sizes')
144+
call linalg_error_handling(err0, err)
145+
return
146+
end if
147+
136148
num_present = 2
137149
if (present(m3)) then
150+
151+
if (size(m3, 1) /= p(3)) then
152+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2, m3 not of compatible sizes')
153+
call linalg_error_handling(err0, err)
154+
return
155+
end if
156+
138157
p(3) = size(m3, 1)
139158
p(4) = size(m3, 2)
140159
num_present = num_present + 1
141160
end if
142161
if (present(m4)) then
162+
163+
if (size(m4, 1) /= p(4)) then
164+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3, m4 not of compatible sizes')
165+
call linalg_error_handling(err0, err)
166+
return
167+
end if
168+
143169
p(4) = size(m4, 1)
144170
p(5) = size(m4, 2)
145171
num_present = num_present + 1
146172
end if
147173
if (present(m5)) then
174+
175+
if (size(m5, 1) /= p(5)) then
176+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4, m5 not of compatible sizes')
177+
call linalg_error_handling(err0, err)
178+
return
179+
end if
180+
148181
p(5) = size(m5, 1)
149182
p(6) = size(m5, 2)
150183
num_present = num_present + 1
@@ -217,10 +250,129 @@ contains
217250
k = p(5)
218251
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
219252
case default
220-
error stop "stdlib_matmul: error: unexpected s(i,j)"
253+
error stop "stdlib_matmul: internal error: unexpected s(i,j)"
221254
end select
222255

223256
end function stdlib_matmul_${s}$
224257

258+
pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
259+
${t}$, intent(in) :: m1(:,:), m2(:,:)
260+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
261+
${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
262+
integer :: p(6), num_present, m, n, k
263+
integer, allocatable :: s(:,:)
264+
265+
p(1) = size(m1, 1)
266+
p(2) = size(m2, 1)
267+
p(3) = size(m2, 2)
268+
269+
if (size(m1, 2) /= p(2)) then
270+
error stop 'matrices m1, m2 not of compatible sizes'
271+
end if
272+
273+
num_present = 2
274+
if (present(m3)) then
275+
276+
if (size(m3, 1) /= p(3)) then
277+
error stop 'matrices m2, m3 not of compatible sizes'
278+
end if
279+
280+
p(3) = size(m3, 1)
281+
p(4) = size(m3, 2)
282+
num_present = num_present + 1
283+
end if
284+
if (present(m4)) then
285+
286+
if (size(m4, 1) /= p(4)) then
287+
error stop 'matrices m3, m4 not of compatible sizes'
288+
end if
289+
290+
p(4) = size(m4, 1)
291+
p(5) = size(m4, 2)
292+
num_present = num_present + 1
293+
end if
294+
if (present(m5)) then
295+
296+
if (size(m5, 1) /= p(5)) then
297+
error stop 'matrices m4, m5 not of compatible sizes'
298+
end if
299+
300+
p(5) = size(m5, 1)
301+
p(6) = size(m5, 2)
302+
num_present = num_present + 1
303+
end if
304+
305+
allocate(r(p(1), p(num_present + 1)))
306+
307+
if (num_present == 2) then
308+
m = p(1)
309+
n = p(3)
310+
k = p(2)
311+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m)
312+
return
313+
end if
314+
315+
! Now num_present >= 3
316+
allocate(s(1:num_present - 1, 2:num_present))
317+
318+
s = matmul_chain_order(p(1: num_present + 1))
319+
320+
if (num_present == 3) then
321+
r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
322+
return
323+
else if (num_present == 4) then
324+
r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
325+
return
326+
end if
327+
328+
! Now num_present is 5
329+
330+
select case (s(1, 5))
331+
case (1)
332+
! m1*(m2*m3*m4*m5)
333+
temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p)
334+
m = p(1)
335+
n = p(6)
336+
k = p(2)
337+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
338+
case (2)
339+
! (m1*m2)*(m3*m4*m5)
340+
m = p(1)
341+
n = p(3)
342+
k = p(2)
343+
allocate(temp(m,n))
344+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
345+
346+
temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p)
347+
348+
k = n
349+
n = p(6)
350+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
351+
case (3)
352+
! (m1*m2*m3)*(m4*m5)
353+
temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p)
354+
355+
m = p(4)
356+
n = p(6)
357+
k = p(5)
358+
allocate(temp1(m,n))
359+
call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m)
360+
361+
k = m
362+
m = p(1)
363+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
364+
case (4)
365+
! (m1*m2*m3*m4)*m5
366+
temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
367+
m = p(1)
368+
n = p(6)
369+
k = p(5)
370+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
371+
case default
372+
error stop "stdlib_matmul: internal error: unexpected s(i,j)"
373+
end select
374+
375+
end function stdlib_matmul_pure_${s}$
376+
225377
#:endfor
226378
end submodule stdlib_intrinsics_matmul

0 commit comments

Comments
 (0)