5
5
6
6
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
7
7
use stdlib_linalg_blas, only: gemm
8
+ use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR
8
9
use stdlib_constants
9
10
implicit none
10
11
12
+ character(len=*), parameter :: this = "stdlib_matmul"
13
+
11
14
contains
12
15
13
16
! Algorithm for the optimal parenthesization of matrices
@@ -122,41 +125,76 @@ contains
122
125
123
126
end function matmul_chain_mult_${s}$_4
124
127
125
- pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
128
+ pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
129
+ ${t}$, intent(out), allocatable :: res(:,:)
126
130
${t}$, intent(in) :: m1(:,:), m2(:,:)
127
131
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
128
- ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
132
+ type(linalg_state_type), intent(out), optional :: err
133
+ ${t}$, allocatable :: temp(:,:), temp1(:,:)
129
134
integer :: p(6), num_present, m, n, k
130
135
integer, allocatable :: s(:,:)
131
136
137
+ type(linalg_state_type) :: err0
138
+
132
139
p(1) = size(m1, 1)
133
140
p(2) = size(m2, 1)
134
141
p(3) = size(m2, 2)
135
142
143
+ if (size(m1, 2) /= p(2)) then
144
+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1, m2 not of compatible sizes')
145
+ call linalg_error_handling(err0, err)
146
+ allocate(res(0, 0))
147
+ return
148
+ end if
149
+
136
150
num_present = 2
137
151
if (present(m3)) then
152
+
153
+ if (size(m3, 1) /= p(3)) then
154
+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2, m3 not of compatible sizes')
155
+ call linalg_error_handling(err0, err)
156
+ allocate(res(0, 0))
157
+ return
158
+ end if
159
+
138
160
p(3) = size(m3, 1)
139
161
p(4) = size(m3, 2)
140
162
num_present = num_present + 1
141
163
end if
142
164
if (present(m4)) then
165
+
166
+ if (size(m4, 1) /= p(4)) then
167
+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3, m4 not of compatible sizes')
168
+ call linalg_error_handling(err0, err)
169
+ allocate(res(0, 0))
170
+ return
171
+ end if
172
+
143
173
p(4) = size(m4, 1)
144
174
p(5) = size(m4, 2)
145
175
num_present = num_present + 1
146
176
end if
147
177
if (present(m5)) then
178
+
179
+ if (size(m5, 1) /= p(5)) then
180
+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4, m5 not of compatible sizes')
181
+ call linalg_error_handling(err0, err)
182
+ allocate(res(0, 0))
183
+ return
184
+ end if
185
+
148
186
p(5) = size(m5, 1)
149
187
p(6) = size(m5, 2)
150
188
num_present = num_present + 1
151
189
end if
152
190
153
- allocate(r (p(1), p(num_present + 1)))
191
+ allocate(res (p(1), p(num_present + 1)))
154
192
155
193
if (num_present == 2) then
156
194
m = p(1)
157
195
n = p(3)
158
196
k = p(2)
159
- call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r , m)
197
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res , m)
160
198
return
161
199
end if
162
200
@@ -166,10 +204,10 @@ contains
166
204
s = matmul_chain_order(p(1: num_present + 1))
167
205
168
206
if (num_present == 3) then
169
- r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
207
+ res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
170
208
return
171
209
else if (num_present == 4) then
172
- r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
210
+ res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
173
211
return
174
212
end if
175
213
@@ -182,7 +220,7 @@ contains
182
220
m = p(1)
183
221
n = p(6)
184
222
k = p(2)
185
- call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r , m)
223
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res , m)
186
224
case (2)
187
225
! (m1*m2)*(m3*m4*m5)
188
226
m = p(1)
@@ -195,7 +233,7 @@ contains
195
233
196
234
k = n
197
235
n = p(6)
198
- call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r , m)
236
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res , m)
199
237
case (3)
200
238
! (m1*m2*m3)*(m4*m5)
201
239
temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p)
@@ -208,18 +246,35 @@ contains
208
246
209
247
k = m
210
248
m = p(1)
211
- call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r , m)
249
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res , m)
212
250
case (4)
213
251
! (m1*m2*m3*m4)*m5
214
252
temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
215
253
m = p(1)
216
254
n = p(6)
217
255
k = p(5)
218
- call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r , m)
256
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res , m)
219
257
case default
220
- error stop "stdlib_matmul: error: unexpected s(i,j)"
258
+ error stop "stdlib_matmul: internal error: unexpected s(i,j)"
221
259
end select
222
260
261
+ end subroutine stdlib_matmul_sub_${s}$
262
+
263
! Pure function form of the chained matrix product (2 to 5 factors).
! All the work is delegated to the generic subroutine `stdlib_matmul_sub`;
! no `err` argument is forwarded, so how a shape mismatch is reported is
! decided inside that subroutine (it hands its state to
! `linalg_error_handling` together with the absent `err`).
pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
    ! Product of every matrix that is present, in argument order
    ${t}$, allocatable :: r(:,:)
    ! Mandatory first and second factors
    ${t}$, intent(in) :: m1(:,:)
    ${t}$, intent(in) :: m2(:,:)
    ! Optional trailing factors; absent ones are passed through as absent
    ${t}$, intent(in), optional :: m3(:,:)
    ${t}$, intent(in), optional :: m4(:,:)
    ${t}$, intent(in), optional :: m5(:,:)

    call stdlib_matmul_sub(res=r, m1=m1, m2=m2, m3=m3, m4=m4, m5=m5)
end function stdlib_matmul_pure_${s}$
270
+
271
! Function form of the chained matrix product (2 to 5 factors) that hands
! the caller a state object. All the work is delegated to the generic
! subroutine `stdlib_matmul_sub`, with `err` forwarded to it.
! Not declared pure: it has an intent(out) dummy argument.
module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
    ! Product of every matrix that is present, in argument order
    ${t}$, allocatable :: r(:,:)
    ! Mandatory first and second factors
    ${t}$, intent(in) :: m1(:,:)
    ${t}$, intent(in) :: m2(:,:)
    ! Optional trailing factors; absent ones are passed through as absent
    ${t}$, intent(in), optional :: m3(:,:)
    ${t}$, intent(in), optional :: m4(:,:)
    ${t}$, intent(in), optional :: m5(:,:)
    ! State object forwarded to stdlib_matmul_sub
    type(linalg_state_type), intent(out) :: err

    call stdlib_matmul_sub(res=r, m1=m1, m2=m2, m3=m3, m4=m4, m5=m5, err=err)
end function stdlib_matmul_${s}$
224
279
225
280
#:endfor
0 commit comments