|
6 | 6 | submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
|
7 | 7 | use stdlib_linalg_blas, only: gemm
|
8 | 8 | use stdlib_constants
|
| 9 | + use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR |
9 | 10 | implicit none
|
10 | 11 |
|
| 12 | + character(len=*), parameter :: this = "stdlib_matmul" |
| 13 | + |
11 | 14 | contains
|
12 | 15 |
|
13 | 16 | ! Algorithm for the optimal parenthesization of matrices
|
@@ -71,7 +74,7 @@ contains
|
71 | 74 | k = p(start + 1)
|
72 | 75 | call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
|
73 | 76 | else
|
74 |
| - error stop "stdlib_matmul: error: unexpected s(i,j)" |
| 77 | + error stop "stdlib_matmul: internal error: unexpected s(i,j)" |
75 | 78 | end if
|
76 | 79 |
|
77 | 80 | end function matmul_chain_mult_${s}$_3
|
@@ -117,34 +120,64 @@ contains
|
117 | 120 | k = p(start + 3)
|
118 | 121 | call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
|
119 | 122 | else
|
120 |
| - error stop "stdlib_matmul: error: unexpected s(i,j)" |
| 123 | + error stop "stdlib_matmul: internal error: unexpected s(i,j)" |
121 | 124 | end if
|
122 | 125 |
|
123 | 126 | end function matmul_chain_mult_${s}$_4
|
124 | 127 |
|
125 |
| - pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) |
| 128 | + module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r) |
126 | 129 | ${t}$, intent(in) :: m1(:,:), m2(:,:)
|
127 | 130 | ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
|
| 131 | + type(linalg_state_type), intent(out) :: err |
128 | 132 | ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
|
129 | 133 | integer :: p(6), num_present, m, n, k
|
130 | 134 | integer, allocatable :: s(:,:)
|
131 | 135 |
|
| 136 | + type(linalg_state_type) :: err0 |
| 137 | + |
132 | 138 | p(1) = size(m1, 1)
|
133 | 139 | p(2) = size(m2, 1)
|
134 | 140 | p(3) = size(m2, 2)
|
135 | 141 |
|
| 142 | + if (size(m1, 2) /= p(2)) then |
| 143 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1, m2 not of compatible sizes') |
| 144 | + call linalg_error_handling(err0, err) |
| 145 | + return |
| 146 | + end if |
| 147 | + |
136 | 148 | num_present = 2
|
137 | 149 | if (present(m3)) then
|
| 150 | + |
| 151 | + if (size(m3, 1) /= p(3)) then |
| 152 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2, m3 not of compatible sizes') |
| 153 | + call linalg_error_handling(err0, err) |
| 154 | + return |
| 155 | + end if |
| 156 | + |
138 | 157 | p(3) = size(m3, 1)
|
139 | 158 | p(4) = size(m3, 2)
|
140 | 159 | num_present = num_present + 1
|
141 | 160 | end if
|
142 | 161 | if (present(m4)) then
|
| 162 | + |
| 163 | + if (size(m4, 1) /= p(4)) then |
| 164 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3, m4 not of compatible sizes') |
| 165 | + call linalg_error_handling(err0, err) |
| 166 | + return |
| 167 | + end if |
| 168 | + |
143 | 169 | p(4) = size(m4, 1)
|
144 | 170 | p(5) = size(m4, 2)
|
145 | 171 | num_present = num_present + 1
|
146 | 172 | end if
|
147 | 173 | if (present(m5)) then
|
| 174 | + |
| 175 | + if (size(m5, 1) /= p(5)) then |
| 176 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4, m5 not of compatible sizes') |
| 177 | + call linalg_error_handling(err0, err) |
| 178 | + return |
| 179 | + end if |
| 180 | + |
148 | 181 | p(5) = size(m5, 1)
|
149 | 182 | p(6) = size(m5, 2)
|
150 | 183 | num_present = num_present + 1
|
@@ -217,10 +250,129 @@ contains
|
217 | 250 | k = p(5)
|
218 | 251 | call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
|
219 | 252 | case default
|
220 |
| - error stop "stdlib_matmul: error: unexpected s(i,j)" |
| 253 | + error stop "stdlib_matmul: internal error: unexpected s(i,j)" |
221 | 254 | end select
|
222 | 255 |
|
223 | 256 | end function stdlib_matmul_${s}$
|
224 | 257 |
|
| 258 | + pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r) |
| 259 | + ${t}$, intent(in) :: m1(:,:), m2(:,:) |
| 260 | + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) |
| 261 | + ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) |
| 262 | + integer :: p(6), num_present, m, n, k |
| 263 | + integer, allocatable :: s(:,:) |
| 264 | + |
| 265 | + p(1) = size(m1, 1) |
| 266 | + p(2) = size(m2, 1) |
| 267 | + p(3) = size(m2, 2) |
| 268 | + |
| 269 | + if (size(m1, 2) /= p(2)) then |
| 270 | + error stop 'matrices m1, m2 not of compatible sizes' |
| 271 | + end if |
| 272 | + |
| 273 | + num_present = 2 |
| 274 | + if (present(m3)) then |
| 275 | + |
| 276 | + if (size(m3, 1) /= p(3)) then |
| 277 | + error stop 'matrices m2, m3 not of compatible sizes' |
| 278 | + end if |
| 279 | + |
| 280 | + p(3) = size(m3, 1) |
| 281 | + p(4) = size(m3, 2) |
| 282 | + num_present = num_present + 1 |
| 283 | + end if |
| 284 | + if (present(m4)) then |
| 285 | + |
| 286 | + if (size(m4, 1) /= p(4)) then |
| 287 | + error stop 'matrices m3, m4 not of compatible sizes' |
| 288 | + end if |
| 289 | + |
| 290 | + p(4) = size(m4, 1) |
| 291 | + p(5) = size(m4, 2) |
| 292 | + num_present = num_present + 1 |
| 293 | + end if |
| 294 | + if (present(m5)) then |
| 295 | + |
| 296 | + if (size(m5, 1) /= p(5)) then |
| 297 | + error stop 'matrices m4, m5 not of compatible sizes' |
| 298 | + end if |
| 299 | + |
| 300 | + p(5) = size(m5, 1) |
| 301 | + p(6) = size(m5, 2) |
| 302 | + num_present = num_present + 1 |
| 303 | + end if |
| 304 | + |
| 305 | + allocate(r(p(1), p(num_present + 1))) |
| 306 | + |
| 307 | + if (num_present == 2) then |
| 308 | + m = p(1) |
| 309 | + n = p(3) |
| 310 | + k = p(2) |
| 311 | + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m) |
| 312 | + return |
| 313 | + end if |
| 314 | + |
| 315 | + ! Now num_present >= 3 |
| 316 | + allocate(s(1:num_present - 1, 2:num_present)) |
| 317 | + |
| 318 | + s = matmul_chain_order(p(1: num_present + 1)) |
| 319 | + |
| 320 | + if (num_present == 3) then |
| 321 | + r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4)) |
| 322 | + return |
| 323 | + else if (num_present == 4) then |
| 324 | + r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5)) |
| 325 | + return |
| 326 | + end if |
| 327 | + |
| 328 | + ! Now num_present is 5 |
| 329 | + |
| 330 | + select case (s(1, 5)) |
| 331 | + case (1) |
| 332 | + ! m1*(m2*m3*m4*m5) |
| 333 | + temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p) |
| 334 | + m = p(1) |
| 335 | + n = p(6) |
| 336 | + k = p(2) |
| 337 | + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) |
| 338 | + case (2) |
| 339 | + ! (m1*m2)*(m3*m4*m5) |
| 340 | + m = p(1) |
| 341 | + n = p(3) |
| 342 | + k = p(2) |
| 343 | + allocate(temp(m,n)) |
| 344 | + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) |
| 345 | + |
| 346 | + temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p) |
| 347 | + |
| 348 | + k = n |
| 349 | + n = p(6) |
| 350 | + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) |
| 351 | + case (3) |
| 352 | + ! (m1*m2*m3)*(m4*m5) |
| 353 | + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p) |
| 354 | + |
| 355 | + m = p(4) |
| 356 | + n = p(6) |
| 357 | + k = p(5) |
| 358 | + allocate(temp1(m,n)) |
| 359 | + call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m) |
| 360 | + |
| 361 | + k = m |
| 362 | + m = p(1) |
| 363 | + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) |
| 364 | + case (4) |
| 365 | + ! (m1*m2*m3*m4)*m5 |
| 366 | + temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p) |
| 367 | + m = p(1) |
| 368 | + n = p(6) |
| 369 | + k = p(5) |
| 370 | + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m) |
| 371 | + case default |
| 372 | + error stop "stdlib_matmul: internal error: unexpected s(i,j)" |
| 373 | + end select |
| 374 | + |
| 375 | + end function stdlib_matmul_pure_${s}$ |
| 376 | + |
225 | 377 | #:endfor
|
226 | 378 | end submodule stdlib_intrinsics_matmul
|
0 commit comments