From 0053e980363aade9b9158db9ca083753b598b2c8 Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Wed, 16 Feb 2022 12:25:52 +0100 Subject: [PATCH] added regtests for pytorch module --- regtest/.gitignore | 1 + regtest/pytorch/Makefile | 2 + .../pytorch/rt-pytorch_model/COLVAR.reference | 14 ++++++ .../rt-pytorch_model/DERIVATIVES.reference | 12 +++++ regtest/pytorch/rt-pytorch_model/Makefile | 1 + regtest/pytorch/rt-pytorch_model/alanine.xtc | Bin 0 -> 8756 bytes regtest/pytorch/rt-pytorch_model/config | 10 ++++ .../rt-pytorch_model/create-pytorch-model.py | 43 +++++++++++++++++ regtest/pytorch/rt-pytorch_model/plumed.dat | 14 ++++++ .../pytorch/rt-pytorch_model/torch_model.ptc | Bin 0 -> 1502 bytes .../rt-pytorch_model_2d/COLVAR.reference | 16 +++++++ .../rt-pytorch_model_2d/DERIVATIVES.reference | 23 +++++++++ regtest/pytorch/rt-pytorch_model_2d/Makefile | 1 + .../pytorch/rt-pytorch_model_2d/alanine.xtc | Bin 0 -> 8756 bytes regtest/pytorch/rt-pytorch_model_2d/config | 10 ++++ .../create-pytorch-model.py | 45 ++++++++++++++++++ .../pytorch/rt-pytorch_model_2d/plumed.dat | 15 ++++++ .../rt-pytorch_model_2d/torch_model.ptc | Bin 0 -> 1502 bytes .../COLVAR.reference | 14 ++++++ .../DERIVATIVES.reference | 12 +++++ .../rt-pytorch_model_derivatives/Makefile | 1 + .../rt-pytorch_model_derivatives/alanine.xtc | Bin 0 -> 8756 bytes .../rt-pytorch_model_derivatives/config | 10 ++++ .../create-pytorch-model.py | 43 +++++++++++++++++ .../rt-pytorch_model_derivatives/plumed.dat | 17 +++++++ .../torch_model.ptc | Bin 0 -> 1502 bytes .../rt-pytorch_model_script/COLVAR.reference | 14 ++++++ .../DERIVATIVES.reference | 12 +++++ .../pytorch/rt-pytorch_model_script/Makefile | 1 + .../rt-pytorch_model_script/alanine.xtc | Bin 0 -> 8756 bytes .../pytorch/rt-pytorch_model_script/config | 10 ++++ .../create-pytorch-model.py | 28 +++++++++++ .../rt-pytorch_model_script/plumed.dat | 14 ++++++ .../rt-pytorch_model_script/torch_model.ptc | Bin 0 -> 1630 bytes 34 files changed, 383 insertions(+) create mode 100644 regtest/pytorch/Makefile create mode 100644 regtest/pytorch/rt-pytorch_model/COLVAR.reference create mode 100644 regtest/pytorch/rt-pytorch_model/DERIVATIVES.reference create mode 100644 regtest/pytorch/rt-pytorch_model/Makefile create mode 100644 regtest/pytorch/rt-pytorch_model/alanine.xtc create mode 100644 regtest/pytorch/rt-pytorch_model/config create mode 100644 regtest/pytorch/rt-pytorch_model/create-pytorch-model.py create mode 100644 regtest/pytorch/rt-pytorch_model/plumed.dat create mode 100644 regtest/pytorch/rt-pytorch_model/torch_model.ptc create mode 100644 regtest/pytorch/rt-pytorch_model_2d/COLVAR.reference create mode 100644 regtest/pytorch/rt-pytorch_model_2d/DERIVATIVES.reference create mode 100644 regtest/pytorch/rt-pytorch_model_2d/Makefile create mode 100644 regtest/pytorch/rt-pytorch_model_2d/alanine.xtc create mode 100644 regtest/pytorch/rt-pytorch_model_2d/config create mode 100644 regtest/pytorch/rt-pytorch_model_2d/create-pytorch-model.py create mode 100644 regtest/pytorch/rt-pytorch_model_2d/plumed.dat create mode 100644 regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/COLVAR.reference create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/DERIVATIVES.reference create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/Makefile create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/alanine.xtc create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/config create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/create-pytorch-model.py create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/plumed.dat create mode 100644 regtest/pytorch/rt-pytorch_model_derivatives/torch_model.ptc create mode 100644 regtest/pytorch/rt-pytorch_model_script/COLVAR.reference create mode 100644 regtest/pytorch/rt-pytorch_model_script/DERIVATIVES.reference create mode 100644 regtest/pytorch/rt-pytorch_model_script/Makefile create mode 100644 regtest/pytorch/rt-pytorch_model_script/alanine.xtc create mode 100644 regtest/pytorch/rt-pytorch_model_script/config create mode 100644 regtest/pytorch/rt-pytorch_model_script/create-pytorch-model.py create mode 100644 regtest/pytorch/rt-pytorch_model_script/plumed.dat create mode 100644 regtest/pytorch/rt-pytorch_model_script/torch_model.ptc diff --git a/regtest/.gitignore b/regtest/.gitignore index 7e71e66bea..606374a229 100644 --- a/regtest/.gitignore +++ b/regtest/.gitignore @@ -31,6 +31,7 @@ !/tools !/sasa !/s2cm +!/pytorch # These files we just want to ignore completely tmp report.txt diff --git a/regtest/pytorch/Makefile b/regtest/pytorch/Makefile new file mode 100644 index 0000000000..42480767ae --- /dev/null +++ b/regtest/pytorch/Makefile @@ -0,0 +1,2 @@ +include ../scripts/module.make + diff --git a/regtest/pytorch/rt-pytorch_model/COLVAR.reference b/regtest/pytorch/rt-pytorch_model/COLVAR.reference new file mode 100644 index 0000000000..de11fe158c --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model/COLVAR.reference @@ -0,0 +1,14 @@ +#! FIELDS time phi model.node-0 +#! SET min_phi -pi +#! SET max_phi pi + 0.000000 -2.85656 -0.281187 + 5.000000 -2.29224 -0.750853 + 10.000000 -2.13168 -0.846786 + 15.000000 -2.92489 -0.215007 + 20.000000 -2.56135 -0.548228 + 25.000000 -1.10925 -0.895364 + 30.000000 1.47871 0.995763 + 35.000000 -1.27313 -0.956024 + 40.000000 -1.10608 -0.893949 + 45.000000 -3.03879 -0.102622 + 50.000000 -1.69099 -0.992786 diff --git a/regtest/pytorch/rt-pytorch_model/DERIVATIVES.reference b/regtest/pytorch/rt-pytorch_model/DERIVATIVES.reference new file mode 100644 index 0000000000..7b5269019b --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model/DERIVATIVES.reference @@ -0,0 +1,12 @@ +#! FIELDS time parameter model.node-0 + 0.000000 0 -0.9596529007 + 5.000000 0 -0.6604691148 + 10.000000 0 -0.5319328904 + 15.000000 0 -0.9766126275 + 20.000000 0 -0.8363288045 + 25.000000 0 0.4453351200 + 30.000000 0 0.0919610485 + 35.000000 0 0.2932883203 + 40.000000 0 0.4481694996 + 45.000000 0 -0.9947203994 + 50.000000 0 -0.1199020892 diff --git a/regtest/pytorch/rt-pytorch_model/Makefile b/regtest/pytorch/rt-pytorch_model/Makefile new file mode 100644 index 0000000000..3703b27cea --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model/Makefile @@ -0,0 +1 @@ +include ../../scripts/test.make diff --git a/regtest/pytorch/rt-pytorch_model/alanine.xtc b/regtest/pytorch/rt-pytorch_model/alanine.xtc new file mode 100644 index 0000000000000000000000000000000000000000..182058e94cceed7318fef51f1a0fc31cdba2fc81 GIT binary patch literal 8756 zcmb8#c{J4R`v-6}t0j#>ww@Sk)=^~3I%8k6WzAL|JE2IMN6|30NcBV#A^Q>%MJPpt zEQ!dLBw0!!$?qERr>F6*{@Q2*6*kBp+5IuC%5#tHMA^xV3M-E1goNp0P z{Ho_+u_b$!MLZ_=nLam>{>=B?Z2Kl%0`<&&ejF*_?w37=*2AVEMt?YI2CWTkS~XW$PKG7PCe~^+TeA+Ki2IV)=;p%18WOdGhi*ls{~xE5FTGA`#t)U8tZ7XMj|cr|!wwq6D6Pq5a5^BQD%LLds(nl(U)Ti&%L z?iKIPj<4j9;imK+xfi1S$83gVx=OEG#JJ8Fj6^4pGS2i{wH5?ynPf|O$Hi8!Uh`RJ zyv5^=CSMWdKLZI=w)krkOT30l9 zTUI{FEqSX8Gk;2qbD{AyrdwLm&+wgJRQV(c;+79T&DTpt33?8=hwP${JXvz{zSj9h z4MkN!zNgkt+l`bGq6G{>@=}t{kRkQIx~8aUq|t0$gt|6^wF9h2Aj=X09bv82wTii3 z$3(mMl}Mt+he#D(nUM%<1j82vpYQkD4m3@ywrz9!wtHci`SeNSzpM4o!K_iG`d8W_ zc9txNxHqgc(p|k~<=U?{7H3L-DU_A|1MaYtaKx!XQ zuPu`xQ%f?jzkpq`)4_%EVds+bZ3ZO60B=p$sMN6Y1Nz&H8Z*pXl&OihuMOW2OALmq z^Hv$G0^ETw-dtf>kbhP9)W#sJmEN<}boHr4exp$6AM2!}#w@K?BwYU}6R_?8>kU{x zz-R$$t*$%?*78=nw=wVsIGd7F2u%3GnlP1E@iSapI!meN{!#2Umc3tar23s)`S~H+ zm-qf2aTQ#y8hh8^>7m1qaOfBNxaXi_L)T-wc+#1}x=U#gyI$8VEPg~LnytBDMS|5A zoKlcw0D<8CZM~PMAn#2vW)SqnIg^j2;P4veAs$ovcw6_Y1&7`}X?eA|zj@5h!s>!Ra0S+=?rs!Z{r0g9q=pY=y5t z&Vn@(tWRK#g7r3puJynwLWRTWAud7+q2@$U9%91xwMr`X#3H7z%bjB+I0QN?jqJ9+ z8()5B62)Em==f=tc0t^iCdnmsx*lWP$te>D!|d%ODc8wcxA^nsnhZBtBhrXEG+W^- zkVRmv1!org+$vzDmSD{~93GDK)nRLtTJ_x$Rg#9otMBtZ=tDlw^JAZ1ijT1T;Pg*B zF8&lm@bYScbg$AxYvhzpK}A3np|QK8HFEi3OWUkS@=9WP!Pp6DX+gp!Ye6y5j%KSV zSQo$=4OW=77#Xm3Z&=-@S~=VmShe_@Ow^+06Xfv`?^O(pI4VhBODiKEJ?YBM3|B~R zH&Z?_S%3bW=W8BVjlf#5VLhW!fOW&!poOsn0y!=?E&cYt9C}@H6*t@^=V$vv)5dp>GC19d zq);a8e(uW(=`4Gu5>MckT<&P=`Sof%^zlu**Ba;~v`3e1bcSmVY&ZYSf#r5$7|nHs zJqpg2dg{*yG9o(?Gg(I-Gy7L-h75q4|Sfv5J)_+y6&eFB}Wz$t9tI{#UC(ci& zDoW4oJzO4dBpm7;r?C<%9AC4*-t|gO?7OS6u5o57=8egZB9ruTfoSU3UY}W8@k2y- zC;!K}3SRg7Z^$om|ItNVrnvmGT&FBce4nljB%+XorpYF1H zE6LU;aRKE$9PY}F=_Rq%3+kWde|@^n6dT|lp=f-|WM5~E_;~5nyK~m1n6~oHde-Xc ztmbVKGyHjKf2_K*#H%!0_k*=?b@dlC6S^I;q`?aH{e5k{E-lBNs+-u#$5I~}g&V@+ zHAmPKryoypMu;b-RtK`FiOwh1-kr<+LA~{LZTqSQc9N#iOQiHCut(1+Ws}59deqFwNQL&cQx{=KLb`obVe`|kk zOI%iy@FD`A`_9ecC;cYtbsQ0PAsTytYJCDd&;wTLwY3wh|IJsKAOu#=Au~cq=|o9n z95K-O_83m_g=+Xi{~kY87L{$b8O{}5XY&{hF$=9c zy2sDve5OR_9k)2pO_12+lrbT`qS^T`y#LeSyaP_^{+b8Q`Hh+`-melds9Hhub*%VB zJ=bByn>?xyc^`ZEnnM*@($%cl!Omq#fPKirP_mPIGWW8cZsxIKtNyG}^L)#$ut{=n zzji;fwxd%E<-%l1ntmS!vL94HB}27>iCSZ{!}A3~`!!04BIg9x!#y_r`) zup_Dx$%;?-!EwquLGx4;Yw}eY&A4KJuZh|ll)bWmNRTdPR*&|`IkFd}f!Gae zgSS6Nfa8}&sev>?RLzNp>ES$H9y1qMnP=o1J=ud!e5D^b$rmZnxzoFSv6IJZvdAvC z>7wCZLsxvdVxLu9Xj@e)zj;ZCq$GT}yXZ-&`yVUvuVNC0W-H7l6r8J8O2G=hmFXSOz;aBlb!~s_q&{D5kRq1h@4R_cEB z0<3e81$k>us@b?(Nu0xTzkMUGo@gDpY>Ej#POoCqkZ*g0NO!Yu5pnUno=b*nCy!|9 zvv($Q1c&-J&v?CjJXOA$^co#b``+K+A}hJW;LgcyerH8kFV&TAvL+6Z;5*QGU0njJ z2v`@_t$QF)0oIzenGJz8!0~DE9hBZej(PYt9TkhS`E@oS^vvt9U0DIk`#GJ>Bz#?t z<{$cUz=HGh*3Us#n9h0$eY`qzh<}DX#i-VRYwHVDcZ7g&NE$gfYLm5UmIPf!qZQ88 zD0Q$xpQ4wbuI3O(J+x7Kz{G4ZNB*>^6@$bUZ|SJ#j68VR87rI5yHH)2ZEmU0Jx{&W zs*QJ*YK~r9svGxxw{J(ITvw&?a!$3l8!Jh!K8Iz@l;qOb9PCv|-!n<{9SZ%^10!@K zUYe~iPuK3XVGlsVi3dY%0_vu{W=&q?;P4N3ix9;R5h&Uj0_i2jjW%!J);%#ll;(sU z?7}ci*$Xu#zxj1v!1;}7ya{w!`ETd;{6id6rXgE6GCdbh^% z@II-{dlrhoOM<5-H1*h0#;5i!WH_vUri@5@@L$V(foEF{K#ooCOkdi z3y(jNK6~7Vz0CCaM4DGpl~;)xvb_7AzHoH`U&Y7AS}xCS&k{ zcuV1Q9!rz((!`PZ@xt<2XWwp(=&{<177ytTWlf|aCi;B>V&?mEo98nF`>#i4H<$Xy zo?8ly{!HS(*D)@>$$CzWBtx?m?o&|aV0{8sY9CPFe?3^#emzuBWW{k9k&hON3FD6( zL%HGA?nZG9*9akJKk}X8ZJ}`@7}KBdYy3L(EyqIr)r&v$&!mu(#^ke2)3uqPT~$0KYzp4dklME1}WPb zVUU%7>9)cZoq^aZ<*trBws{qs<`$~Ok)%Si^(a_1z?uoxd9YGXLugpk>&w7i|1{@T zbrrXst=Q{n_v*z0wqz7?4`0td7aMfqn#E+wT&R&=v5Y;PGIC`rHkYB+V|?F5U*gu# zm@9P}rTK3Ds+MhuayHthZa?z8`=_Q~{YXTbouS|i0p}>>KrPXe;N#hF?)O59=(MWU zAE~T)@T9X|qA<9{_16-s@Fnd%^Mcvnn_`2FWZbA^(0Xm;L!888pdWiU2C7PU_WV6Dwo#z1j-uRg&> zF_j@+vK@UdUi+N9==5nPv$LX}6=}XHs?9Y-*O)}Puyb?AXXT0$I=9~$oc&Sq;LPb} z-W3jV@Jgx?N9(dNW?RPyf2f#^v|JSTi$45kT z`aC0qzI`61d?z6By}<=s(~9qRR1qY~`DHw>=si64_eW!g=D(d%46SepF|XuEi32fJ zWmb&Jg5{BHMB}4$ z?%K&`J3|kjb8nEdLpU5QzNVuuRAv_*$3ijOWQ{8(!P5vDt#x3v0qeWpP7HOfYQtJP zw`i3raCm<=XOML8{D<;3@$N`MYK&;F{x24ORl4zxHqjc9Kn0d;k1$=n2gk^yz?x(F zgO<*VL1hf=KU}&k6@_=Hp^8|o(M22@r@2K$O(1rCA3#?;CfU(!h5Hp07OW(&CO{VS z0-fW!)!=j2Uv5r%RX0q-4_)Q!VHQ2_zqO#&)=Eh9bMvx7y||`y??uI?F9O$3@7B@Z zwG#5Nq0r!)&yEMv&QE6!j^CLCro6F zKb!;Jt~-sJ+KYAE^=z7X(>vqIDhlH1WxURcZzvC&cDf(D!_sQ^)HaYXAAec;ErZ-E zPi&U$_BaYp;`l<+nK=)(P`tj+Ys-<`%2{H~p>mz?t5mv80JBii}6;(T}t1=GN;SCU|*r@YB3HcRrP*$VFq3Z7f8 zJ#&Y1%i5X6213_nfd9cH5rdO*%%a~TJEIO~2s#%TswuwvHE^-GXLN}-C^DOe>aRmQu3dgdSSN@es_U=H#jA7iyvRmLW(8662KL*sSz z6ncRAT<8tdGzhHJxhl6|H5cfX*LAUYB{n+kPtMqqnSN|F(qrnXW=v8cy|*$gAFEZUrtRiIWGjj)nf|fe;+& z1kKhKsOxdCHiNYfvfxgg4$iW_XF!_n7HmjmftGkl?<>lSlCt!hs7%EV5_UU`zSj+0 zAD{K8bTR&Tx=ME{$o`U>KTF;F8{U0SfP5Ls16Z3-D zjefn&OZ~qUG}aXw9=QxwZ?GPOEb2ahU$+`5l5KSC!kmm^>2Ce|R`KY^%Wk)dt+J#f zj)oO5*N8@?YojSQ8xHpD(!|(De|vlrz2`ZSX~q3|gs;}4laKhc%~f+``uRVpUD$R& L{`DX)9o_!{2f={` literal 0 HcmV?d00001 diff --git a/regtest/pytorch/rt-pytorch_model/config b/regtest/pytorch/rt-pytorch_model/config new file mode 100644 index 0000000000..da5a9cd874 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model/config @@ -0,0 +1,10 @@ +plumed_needs=libtorch +plumed_modules=pytorch +type=driver +arg="--plumed plumed.dat --mf_xtc alanine.xtc" + +# note: model has been previously created with create-pytorch-model.py +# the following crashes in CI +# function plumed_regtest_before(){ +# python create-pytorch-model.py +#} diff --git a/regtest/pytorch/rt-pytorch_model/create-pytorch-model.py b/regtest/pytorch/rt-pytorch_model/create-pytorch-model.py new file mode 100644 index 0000000000..2442615499 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model/create-pytorch-model.py @@ -0,0 +1,43 @@ +import torch +print(torch.__version__) + +def my_torch_cv(x): + ''' + Here goes the definition of the CV. + + Inputs: + x (torch.tensor): input, either scalar or 1-D array + Return: + y (torch.tensor): collective variable (scalar) + ''' + # CV definition + y = torch.sin(x) + + return y + +input_size = 1 + +# -- DEFINE INPUT -- +#random +#x = torch.rand(input_size, dtype=torch.float32, requires_grad=True).unsqueeze(0) +#or by choosing the value(s) of the array +x = torch.tensor([0.], dtype=torch.float32, requires_grad=True) + +# -- CALCULATE CV -- +y = my_torch_cv(x) + +# -- CALCULATE DERIVATIVES -- +for yy in y: + dy = torch.autograd.grad(yy, x, create_graph=True) + # -- PRINT -- + print('CV TEST') + print('n_input\t: {}'.format(input_size)) + print('x\t: {}'.format(x)) + print('cv\t: {}'.format(yy)) + print('der\t: {}'.format(dy)) + +# Compile via tracing +traced_cv = torch.jit.trace ( my_torch_cv, example_inputs=x ) +filename='torch_model.pt' +traced_cv.save(filename) + diff --git a/regtest/pytorch/rt-pytorch_model/plumed.dat b/regtest/pytorch/rt-pytorch_model/plumed.dat new file mode 100644 index 0000000000..9d011dafa9 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model/plumed.dat @@ -0,0 +1,14 @@ +# vim:ft=plumed + +#define input x +phi: TORSION ATOMS=5,7,9,15 + +#load model computing y=sin(x) +model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi + +#output derivatives dy/dx +DUMPDERIVATIVES ARG=model.* STRIDE=5 FILE=DERIVATIVES + +#print colvar +PRINT FMT=%g STRIDE=5 FILE=COLVAR ARG=phi,model.* +ENDPLUMED diff --git a/regtest/pytorch/rt-pytorch_model/torch_model.ptc b/regtest/pytorch/rt-pytorch_model/torch_model.ptc new file mode 100644 index 0000000000000000000000000000000000000000..1e06d5e5043c133010318f87bec6a0a81d4de6cf GIT binary patch literal 1502 zcmWIWW@cev;NW1u03r;03?=zR$r({G#&2q7;qd)SNUc zpkdI!(!;P*2dF_op~6ZbBsH%%zerO-*AB`Axk5pqD7B=tC{F>bU#~bbPoqMU3lxaG z;>lC@0|O9*L4ha&j84q>RAsY#{j@!%NNAUZgYPTClB*np>PE|WLgyhgz` zmnq-Y9VnF9-ESLxuGA%Y!{Zxavriv$x!2DA;AX)kqs?36bmrE*oLus2PGQ)FA2O>Y z)0@42{HXo-ev!9C^DQZsz1w;fzAh4Zddq8};tlgxE$$g#uO*aDpZIkXTfg)nhlDM6 zSU*hJ`d37wcWprWXRX}h9NPF3w921$u@7 zgaf=8K@_}9K+dBgAPE$Jt%N`~0Xc!nqUi1d@{moy$k`C15Gfit=+#k-I?qVJC`9^3 zcOP=tOQIO+z>IDvQXHZigB*y$D8@`dGX@%M0p4tEI#7jj%(`$btWXw=&H{Q31cX3| c0|ww@Sk)=^~3I%8k6WzAL|JE2IMN6|30NcBV#A^Q>%MJPpt zEQ!dLBw0!!$?qERr>F6*{@Q2*6*kBp+5IuC%5#tHMA^xV3M-E1goNp0P z{Ho_+u_b$!MLZ_=nLam>{>=B?Z2Kl%0`<&&ejF*_?w37=*2AVEMt?YI2CWTkS~XW$PKG7PCe~^+TeA+Ki2IV)=;p%18WOdGhi*ls{~xE5FTGA`#t)U8tZ7XMj|cr|!wwq6D6Pq5a5^BQD%LLds(nl(U)Ti&%L z?iKIPj<4j9;imK+xfi1S$83gVx=OEG#JJ8Fj6^4pGS2i{wH5?ynPf|O$Hi8!Uh`RJ zyv5^=CSMWdKLZI=w)krkOT30l9 zTUI{FEqSX8Gk;2qbD{AyrdwLm&+wgJRQV(c;+79T&DTpt33?8=hwP${JXvz{zSj9h z4MkN!zNgkt+l`bGq6G{>@=}t{kRkQIx~8aUq|t0$gt|6^wF9h2Aj=X09bv82wTii3 z$3(mMl}Mt+he#D(nUM%<1j82vpYQkD4m3@ywrz9!wtHci`SeNSzpM4o!K_iG`d8W_ zc9txNxHqgc(p|k~<=U?{7H3L-DU_A|1MaYtaKx!XQ zuPu`xQ%f?jzkpq`)4_%EVds+bZ3ZO60B=p$sMN6Y1Nz&H8Z*pXl&OihuMOW2OALmq z^Hv$G0^ETw-dtf>kbhP9)W#sJmEN<}boHr4exp$6AM2!}#w@K?BwYU}6R_?8>kU{x zz-R$$t*$%?*78=nw=wVsIGd7F2u%3GnlP1E@iSapI!meN{!#2Umc3tar23s)`S~H+ zm-qf2aTQ#y8hh8^>7m1qaOfBNxaXi_L)T-wc+#1}x=U#gyI$8VEPg~LnytBDMS|5A zoKlcw0D<8CZM~PMAn#2vW)SqnIg^j2;P4veAs$ovcw6_Y1&7`}X?eA|zj@5h!s>!Ra0S+=?rs!Z{r0g9q=pY=y5t z&Vn@(tWRK#g7r3puJynwLWRTWAud7+q2@$U9%91xwMr`X#3H7z%bjB+I0QN?jqJ9+ z8()5B62)Em==f=tc0t^iCdnmsx*lWP$te>D!|d%ODc8wcxA^nsnhZBtBhrXEG+W^- zkVRmv1!org+$vzDmSD{~93GDK)nRLtTJ_x$Rg#9otMBtZ=tDlw^JAZ1ijT1T;Pg*B zF8&lm@bYScbg$AxYvhzpK}A3np|QK8HFEi3OWUkS@=9WP!Pp6DX+gp!Ye6y5j%KSV zSQo$=4OW=77#Xm3Z&=-@S~=VmShe_@Ow^+06Xfv`?^O(pI4VhBODiKEJ?YBM3|B~R zH&Z?_S%3bW=W8BVjlf#5VLhW!fOW&!poOsn0y!=?E&cYt9C}@H6*t@^=V$vv)5dp>GC19d zq);a8e(uW(=`4Gu5>MckT<&P=`Sof%^zlu**Ba;~v`3e1bcSmVY&ZYSf#r5$7|nHs zJqpg2dg{*yG9o(?Gg(I-Gy7L-h75q4|Sfv5J)_+y6&eFB}Wz$t9tI{#UC(ci& zDoW4oJzO4dBpm7;r?C<%9AC4*-t|gO?7OS6u5o57=8egZB9ruTfoSU3UY}W8@k2y- zC;!K}3SRg7Z^$om|ItNVrnvmGT&FBce4nljB%+XorpYF1H zE6LU;aRKE$9PY}F=_Rq%3+kWde|@^n6dT|lp=f-|WM5~E_;~5nyK~m1n6~oHde-Xc ztmbVKGyHjKf2_K*#H%!0_k*=?b@dlC6S^I;q`?aH{e5k{E-lBNs+-u#$5I~}g&V@+ zHAmPKryoypMu;b-RtK`FiOwh1-kr<+LA~{LZTqSQc9N#iOQiHCut(1+Ws}59deqFwNQL&cQx{=KLb`obVe`|kk zOI%iy@FD`A`_9ecC;cYtbsQ0PAsTytYJCDd&;wTLwY3wh|IJsKAOu#=Au~cq=|o9n z95K-O_83m_g=+Xi{~kY87L{$b8O{}5XY&{hF$=9c zy2sDve5OR_9k)2pO_12+lrbT`qS^T`y#LeSyaP_^{+b8Q`Hh+`-melds9Hhub*%VB zJ=bByn>?xyc^`ZEnnM*@($%cl!Omq#fPKirP_mPIGWW8cZsxIKtNyG}^L)#$ut{=n zzji;fwxd%E<-%l1ntmS!vL94HB}27>iCSZ{!}A3~`!!04BIg9x!#y_r`) zup_Dx$%;?-!EwquLGx4;Yw}eY&A4KJuZh|ll)bWmNRTdPR*&|`IkFd}f!Gae zgSS6Nfa8}&sev>?RLzNp>ES$H9y1qMnP=o1J=ud!e5D^b$rmZnxzoFSv6IJZvdAvC z>7wCZLsxvdVxLu9Xj@e)zj;ZCq$GT}yXZ-&`yVUvuVNC0W-H7l6r8J8O2G=hmFXSOz;aBlb!~s_q&{D5kRq1h@4R_cEB z0<3e81$k>us@b?(Nu0xTzkMUGo@gDpY>Ej#POoCqkZ*g0NO!Yu5pnUno=b*nCy!|9 zvv($Q1c&-J&v?CjJXOA$^co#b``+K+A}hJW;LgcyerH8kFV&TAvL+6Z;5*QGU0njJ z2v`@_t$QF)0oIzenGJz8!0~DE9hBZej(PYt9TkhS`E@oS^vvt9U0DIk`#GJ>Bz#?t z<{$cUz=HGh*3Us#n9h0$eY`qzh<}DX#i-VRYwHVDcZ7g&NE$gfYLm5UmIPf!qZQ88 zD0Q$xpQ4wbuI3O(J+x7Kz{G4ZNB*>^6@$bUZ|SJ#j68VR87rI5yHH)2ZEmU0Jx{&W zs*QJ*YK~r9svGxxw{J(ITvw&?a!$3l8!Jh!K8Iz@l;qOb9PCv|-!n<{9SZ%^10!@K zUYe~iPuK3XVGlsVi3dY%0_vu{W=&q?;P4N3ix9;R5h&Uj0_i2jjW%!J);%#ll;(sU z?7}ci*$Xu#zxj1v!1;}7ya{w!`ETd;{6id6rXgE6GCdbh^% z@II-{dlrhoOM<5-H1*h0#;5i!WH_vUri@5@@L$V(foEF{K#ooCOkdi z3y(jNK6~7Vz0CCaM4DGpl~;)xvb_7AzHoH`U&Y7AS}xCS&k{ zcuV1Q9!rz((!`PZ@xt<2XWwp(=&{<177ytTWlf|aCi;B>V&?mEo98nF`>#i4H<$Xy zo?8ly{!HS(*D)@>$$CzWBtx?m?o&|aV0{8sY9CPFe?3^#emzuBWW{k9k&hON3FD6( zL%HGA?nZG9*9akJKk}X8ZJ}`@7}KBdYy3L(EyqIr)r&v$&!mu(#^ke2)3uqPT~$0KYzp4dklME1}WPb zVUU%7>9)cZoq^aZ<*trBws{qs<`$~Ok)%Si^(a_1z?uoxd9YGXLugpk>&w7i|1{@T zbrrXst=Q{n_v*z0wqz7?4`0td7aMfqn#E+wT&R&=v5Y;PGIC`rHkYB+V|?F5U*gu# zm@9P}rTK3Ds+MhuayHthZa?z8`=_Q~{YXTbouS|i0p}>>KrPXe;N#hF?)O59=(MWU zAE~T)@T9X|qA<9{_16-s@Fnd%^Mcvnn_`2FWZbA^(0Xm;L!888pdWiU2C7PU_WV6Dwo#z1j-uRg&> zF_j@+vK@UdUi+N9==5nPv$LX}6=}XHs?9Y-*O)}Puyb?AXXT0$I=9~$oc&Sq;LPb} z-W3jV@Jgx?N9(dNW?RPyf2f#^v|JSTi$45kT z`aC0qzI`61d?z6By}<=s(~9qRR1qY~`DHw>=si64_eW!g=D(d%46SepF|XuEi32fJ zWmb&Jg5{BHMB}4$ z?%K&`J3|kjb8nEdLpU5QzNVuuRAv_*$3ijOWQ{8(!P5vDt#x3v0qeWpP7HOfYQtJP zw`i3raCm<=XOML8{D<;3@$N`MYK&;F{x24ORl4zxHqjc9Kn0d;k1$=n2gk^yz?x(F zgO<*VL1hf=KU}&k6@_=Hp^8|o(M22@r@2K$O(1rCA3#?;CfU(!h5Hp07OW(&CO{VS z0-fW!)!=j2Uv5r%RX0q-4_)Q!VHQ2_zqO#&)=Eh9bMvx7y||`y??uI?F9O$3@7B@Z zwG#5Nq0r!)&yEMv&QE6!j^CLCro6F zKb!;Jt~-sJ+KYAE^=z7X(>vqIDhlH1WxURcZzvC&cDf(D!_sQ^)HaYXAAec;ErZ-E zPi&U$_BaYp;`l<+nK=)(P`tj+Ys-<`%2{H~p>mz?t5mv80JBii}6;(T}t1=GN;SCU|*r@YB3HcRrP*$VFq3Z7f8 zJ#&Y1%i5X6213_nfd9cH5rdO*%%a~TJEIO~2s#%TswuwvHE^-GXLN}-C^DOe>aRmQu3dgdSSN@es_U=H#jA7iyvRmLW(8662KL*sSz z6ncRAT<8tdGzhHJxhl6|H5cfX*LAUYB{n+kPtMqqnSN|F(qrnXW=v8cy|*$gAFEZUrtRiIWGjj)nf|fe;+& z1kKhKsOxdCHiNYfvfxgg4$iW_XF!_n7HmjmftGkl?<>lSlCt!hs7%EV5_UU`zSj+0 zAD{K8bTR&Tx=ME{$o`U>KTF;F8{U0SfP5Ls16Z3-D zjefn&OZ~qUG}aXw9=QxwZ?GPOEb2ahU$+`5l5KSC!kmm^>2Ce|R`KY^%Wk)dt+J#f zj)oO5*N8@?YojSQ8xHpD(!|(De|vlrz2`ZSX~q3|gs;}4laKhc%~f+``uRVpUD$R& L{`DX)9o_!{2f={` literal 0 HcmV?d00001 diff --git a/regtest/pytorch/rt-pytorch_model_2d/config b/regtest/pytorch/rt-pytorch_model_2d/config new file mode 100644 index 0000000000..da5a9cd874 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_2d/config @@ -0,0 +1,10 @@ +plumed_needs=libtorch +plumed_modules=pytorch +type=driver +arg="--plumed plumed.dat --mf_xtc alanine.xtc" + +# note: model has been previously created with create-pytorch-model.py +# the following crashes in CI +# function plumed_regtest_before(){ +# python create-pytorch-model.py +#} diff --git a/regtest/pytorch/rt-pytorch_model_2d/create-pytorch-model.py b/regtest/pytorch/rt-pytorch_model_2d/create-pytorch-model.py new file mode 100644 index 0000000000..b3ec622a3e --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_2d/create-pytorch-model.py @@ -0,0 +1,45 @@ +import torch +print(torch.__version__) + +def my_torch_cv(x): + ''' + Here goes the definition of the CV. + + Inputs: + x (torch.tensor): input, either scalar or 1-D array + Return: + y (torch.tensor): collective variable (scalar) + ''' + # CV definition + #y1 = torch.sin(x) + #y2 = torch.cos(x) + #y = torch.cat((y1,y2)) + y = torch.sin(x) + return y + +input_size = 2 + +# -- DEFINE INPUT -- +#random +#x = torch.rand(input_size, dtype=torch.float32, requires_grad=True).unsqueeze(0) +#or by choosing the value(s) of the array +x = torch.tensor([0.,1.57], dtype=torch.float32, requires_grad=True) + +# -- CALCULATE CV -- +y = my_torch_cv(x) + +# -- CALCULATE DERIVATIVES -- +for yy in y: + dy = torch.autograd.grad(yy, x, create_graph=True) + # -- PRINT -- + print('CV TEST') + print('n_input\t: {}'.format(input_size)) + print('x\t: {}'.format(x)) + print('cv\t: {}'.format(yy)) + print('der\t: {}'.format(dy)) + +# Compile via tracing +traced_cv = torch.jit.trace ( my_torch_cv, example_inputs=x ) +filename='torch_model.pt' +traced_cv.save(filename) + diff --git a/regtest/pytorch/rt-pytorch_model_2d/plumed.dat b/regtest/pytorch/rt-pytorch_model_2d/plumed.dat new file mode 100644 index 0000000000..03f8c3bda3 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_2d/plumed.dat @@ -0,0 +1,15 @@ +# vim:ft=plumed + +#define input x +phi: TORSION ATOMS=5,7,9,15 +psi: TORSION ATOMS=7,9,15,17 + +#load model computing [sin(x),cos(x)] +model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi,psi + +#output derivatives dy/dx +DUMPDERIVATIVES ARG=model.* STRIDE=5 FILE=DERIVATIVES + +#print colvar +PRINT FMT=%g STRIDE=5 FILE=COLVAR ARG=phi,psi,model.* +ENDPLUMED diff --git a/regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc b/regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc new file mode 100644 index 0000000000000000000000000000000000000000..a5fe64483a74c5fd83139e446caac54aca83f687 GIT binary patch literal 1502 zcmWIWW@cev;NW1u03r;03?=zR$r({G#&2q7;qd)SNUc zpkdI!(!;P*2dF_op~6ZbBsH%%zerO-*AB`Axk5pqD7B=tC{F>bU#~bbPoqMU3lxaG z;>lC@0|O9*L4ha&j84q>RAsY#{j@!%NNAUZgYPTJ^q#DJ&mGV`>&J&ejc zYhC$QHV7p3-`E|x^GXb}=Vg_Bt9Kb*{Qi)!!tByL!_AJ&I!5toZ-NX$N<8)bAm8hPACJWzs&qR3evwZI0;>xmmUiA4Fd=V zcr$`1c$t8lM@2vqC;(dtfo=kF0+mJ4-38ww@Sk)=^~3I%8k6WzAL|JE2IMN6|30NcBV#A^Q>%MJPpt zEQ!dLBw0!!$?qERr>F6*{@Q2*6*kBp+5IuC%5#tHMA^xV3M-E1goNp0P z{Ho_+u_b$!MLZ_=nLam>{>=B?Z2Kl%0`<&&ejF*_?w37=*2AVEMt?YI2CWTkS~XW$PKG7PCe~^+TeA+Ki2IV)=;p%18WOdGhi*ls{~xE5FTGA`#t)U8tZ7XMj|cr|!wwq6D6Pq5a5^BQD%LLds(nl(U)Ti&%L z?iKIPj<4j9;imK+xfi1S$83gVx=OEG#JJ8Fj6^4pGS2i{wH5?ynPf|O$Hi8!Uh`RJ zyv5^=CSMWdKLZI=w)krkOT30l9 zTUI{FEqSX8Gk;2qbD{AyrdwLm&+wgJRQV(c;+79T&DTpt33?8=hwP${JXvz{zSj9h z4MkN!zNgkt+l`bGq6G{>@=}t{kRkQIx~8aUq|t0$gt|6^wF9h2Aj=X09bv82wTii3 z$3(mMl}Mt+he#D(nUM%<1j82vpYQkD4m3@ywrz9!wtHci`SeNSzpM4o!K_iG`d8W_ zc9txNxHqgc(p|k~<=U?{7H3L-DU_A|1MaYtaKx!XQ zuPu`xQ%f?jzkpq`)4_%EVds+bZ3ZO60B=p$sMN6Y1Nz&H8Z*pXl&OihuMOW2OALmq z^Hv$G0^ETw-dtf>kbhP9)W#sJmEN<}boHr4exp$6AM2!}#w@K?BwYU}6R_?8>kU{x zz-R$$t*$%?*78=nw=wVsIGd7F2u%3GnlP1E@iSapI!meN{!#2Umc3tar23s)`S~H+ zm-qf2aTQ#y8hh8^>7m1qaOfBNxaXi_L)T-wc+#1}x=U#gyI$8VEPg~LnytBDMS|5A zoKlcw0D<8CZM~PMAn#2vW)SqnIg^j2;P4veAs$ovcw6_Y1&7`}X?eA|zj@5h!s>!Ra0S+=?rs!Z{r0g9q=pY=y5t z&Vn@(tWRK#g7r3puJynwLWRTWAud7+q2@$U9%91xwMr`X#3H7z%bjB+I0QN?jqJ9+ z8()5B62)Em==f=tc0t^iCdnmsx*lWP$te>D!|d%ODc8wcxA^nsnhZBtBhrXEG+W^- zkVRmv1!org+$vzDmSD{~93GDK)nRLtTJ_x$Rg#9otMBtZ=tDlw^JAZ1ijT1T;Pg*B zF8&lm@bYScbg$AxYvhzpK}A3np|QK8HFEi3OWUkS@=9WP!Pp6DX+gp!Ye6y5j%KSV zSQo$=4OW=77#Xm3Z&=-@S~=VmShe_@Ow^+06Xfv`?^O(pI4VhBODiKEJ?YBM3|B~R zH&Z?_S%3bW=W8BVjlf#5VLhW!fOW&!poOsn0y!=?E&cYt9C}@H6*t@^=V$vv)5dp>GC19d zq);a8e(uW(=`4Gu5>MckT<&P=`Sof%^zlu**Ba;~v`3e1bcSmVY&ZYSf#r5$7|nHs zJqpg2dg{*yG9o(?Gg(I-Gy7L-h75q4|Sfv5J)_+y6&eFB}Wz$t9tI{#UC(ci& zDoW4oJzO4dBpm7;r?C<%9AC4*-t|gO?7OS6u5o57=8egZB9ruTfoSU3UY}W8@k2y- zC;!K}3SRg7Z^$om|ItNVrnvmGT&FBce4nljB%+XorpYF1H zE6LU;aRKE$9PY}F=_Rq%3+kWde|@^n6dT|lp=f-|WM5~E_;~5nyK~m1n6~oHde-Xc ztmbVKGyHjKf2_K*#H%!0_k*=?b@dlC6S^I;q`?aH{e5k{E-lBNs+-u#$5I~}g&V@+ zHAmPKryoypMu;b-RtK`FiOwh1-kr<+LA~{LZTqSQc9N#iOQiHCut(1+Ws}59deqFwNQL&cQx{=KLb`obVe`|kk zOI%iy@FD`A`_9ecC;cYtbsQ0PAsTytYJCDd&;wTLwY3wh|IJsKAOu#=Au~cq=|o9n z95K-O_83m_g=+Xi{~kY87L{$b8O{}5XY&{hF$=9c zy2sDve5OR_9k)2pO_12+lrbT`qS^T`y#LeSyaP_^{+b8Q`Hh+`-melds9Hhub*%VB zJ=bByn>?xyc^`ZEnnM*@($%cl!Omq#fPKirP_mPIGWW8cZsxIKtNyG}^L)#$ut{=n zzji;fwxd%E<-%l1ntmS!vL94HB}27>iCSZ{!}A3~`!!04BIg9x!#y_r`) zup_Dx$%;?-!EwquLGx4;Yw}eY&A4KJuZh|ll)bWmNRTdPR*&|`IkFd}f!Gae zgSS6Nfa8}&sev>?RLzNp>ES$H9y1qMnP=o1J=ud!e5D^b$rmZnxzoFSv6IJZvdAvC z>7wCZLsxvdVxLu9Xj@e)zj;ZCq$GT}yXZ-&`yVUvuVNC0W-H7l6r8J8O2G=hmFXSOz;aBlb!~s_q&{D5kRq1h@4R_cEB z0<3e81$k>us@b?(Nu0xTzkMUGo@gDpY>Ej#POoCqkZ*g0NO!Yu5pnUno=b*nCy!|9 zvv($Q1c&-J&v?CjJXOA$^co#b``+K+A}hJW;LgcyerH8kFV&TAvL+6Z;5*QGU0njJ z2v`@_t$QF)0oIzenGJz8!0~DE9hBZej(PYt9TkhS`E@oS^vvt9U0DIk`#GJ>Bz#?t z<{$cUz=HGh*3Us#n9h0$eY`qzh<}DX#i-VRYwHVDcZ7g&NE$gfYLm5UmIPf!qZQ88 zD0Q$xpQ4wbuI3O(J+x7Kz{G4ZNB*>^6@$bUZ|SJ#j68VR87rI5yHH)2ZEmU0Jx{&W zs*QJ*YK~r9svGxxw{J(ITvw&?a!$3l8!Jh!K8Iz@l;qOb9PCv|-!n<{9SZ%^10!@K zUYe~iPuK3XVGlsVi3dY%0_vu{W=&q?;P4N3ix9;R5h&Uj0_i2jjW%!J);%#ll;(sU z?7}ci*$Xu#zxj1v!1;}7ya{w!`ETd;{6id6rXgE6GCdbh^% z@II-{dlrhoOM<5-H1*h0#;5i!WH_vUri@5@@L$V(foEF{K#ooCOkdi z3y(jNK6~7Vz0CCaM4DGpl~;)xvb_7AzHoH`U&Y7AS}xCS&k{ zcuV1Q9!rz((!`PZ@xt<2XWwp(=&{<177ytTWlf|aCi;B>V&?mEo98nF`>#i4H<$Xy zo?8ly{!HS(*D)@>$$CzWBtx?m?o&|aV0{8sY9CPFe?3^#emzuBWW{k9k&hON3FD6( zL%HGA?nZG9*9akJKk}X8ZJ}`@7}KBdYy3L(EyqIr)r&v$&!mu(#^ke2)3uqPT~$0KYzp4dklME1}WPb zVUU%7>9)cZoq^aZ<*trBws{qs<`$~Ok)%Si^(a_1z?uoxd9YGXLugpk>&w7i|1{@T zbrrXst=Q{n_v*z0wqz7?4`0td7aMfqn#E+wT&R&=v5Y;PGIC`rHkYB+V|?F5U*gu# zm@9P}rTK3Ds+MhuayHthZa?z8`=_Q~{YXTbouS|i0p}>>KrPXe;N#hF?)O59=(MWU zAE~T)@T9X|qA<9{_16-s@Fnd%^Mcvnn_`2FWZbA^(0Xm;L!888pdWiU2C7PU_WV6Dwo#z1j-uRg&> zF_j@+vK@UdUi+N9==5nPv$LX}6=}XHs?9Y-*O)}Puyb?AXXT0$I=9~$oc&Sq;LPb} z-W3jV@Jgx?N9(dNW?RPyf2f#^v|JSTi$45kT z`aC0qzI`61d?z6By}<=s(~9qRR1qY~`DHw>=si64_eW!g=D(d%46SepF|XuEi32fJ zWmb&Jg5{BHMB}4$ z?%K&`J3|kjb8nEdLpU5QzNVuuRAv_*$3ijOWQ{8(!P5vDt#x3v0qeWpP7HOfYQtJP zw`i3raCm<=XOML8{D<;3@$N`MYK&;F{x24ORl4zxHqjc9Kn0d;k1$=n2gk^yz?x(F zgO<*VL1hf=KU}&k6@_=Hp^8|o(M22@r@2K$O(1rCA3#?;CfU(!h5Hp07OW(&CO{VS z0-fW!)!=j2Uv5r%RX0q-4_)Q!VHQ2_zqO#&)=Eh9bMvx7y||`y??uI?F9O$3@7B@Z zwG#5Nq0r!)&yEMv&QE6!j^CLCro6F zKb!;Jt~-sJ+KYAE^=z7X(>vqIDhlH1WxURcZzvC&cDf(D!_sQ^)HaYXAAec;ErZ-E zPi&U$_BaYp;`l<+nK=)(P`tj+Ys-<`%2{H~p>mz?t5mv80JBii}6;(T}t1=GN;SCU|*r@YB3HcRrP*$VFq3Z7f8 zJ#&Y1%i5X6213_nfd9cH5rdO*%%a~TJEIO~2s#%TswuwvHE^-GXLN}-C^DOe>aRmQu3dgdSSN@es_U=H#jA7iyvRmLW(8662KL*sSz z6ncRAT<8tdGzhHJxhl6|H5cfX*LAUYB{n+kPtMqqnSN|F(qrnXW=v8cy|*$gAFEZUrtRiIWGjj)nf|fe;+& z1kKhKsOxdCHiNYfvfxgg4$iW_XF!_n7HmjmftGkl?<>lSlCt!hs7%EV5_UU`zSj+0 zAD{K8bTR&Tx=ME{$o`U>KTF;F8{U0SfP5Ls16Z3-D zjefn&OZ~qUG}aXw9=QxwZ?GPOEb2ahU$+`5l5KSC!kmm^>2Ce|R`KY^%Wk)dt+J#f zj)oO5*N8@?YojSQ8xHpD(!|(De|vlrz2`ZSX~q3|gs;}4laKhc%~f+``uRVpUD$R& L{`DX)9o_!{2f={` literal 0 HcmV?d00001 diff --git a/regtest/pytorch/rt-pytorch_model_derivatives/config b/regtest/pytorch/rt-pytorch_model_derivatives/config new file mode 100644 index 0000000000..da5a9cd874 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_derivatives/config @@ -0,0 +1,10 @@ +plumed_needs=libtorch +plumed_modules=pytorch +type=driver +arg="--plumed plumed.dat --mf_xtc alanine.xtc" + +# note: model has been previously created with create-pytorch-model.py +# the following crashes in CI +# function plumed_regtest_before(){ +# python create-pytorch-model.py +#} diff --git a/regtest/pytorch/rt-pytorch_model_derivatives/create-pytorch-model.py b/regtest/pytorch/rt-pytorch_model_derivatives/create-pytorch-model.py new file mode 100644 index 0000000000..2442615499 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_derivatives/create-pytorch-model.py @@ -0,0 +1,43 @@ +import torch +print(torch.__version__) + +def my_torch_cv(x): + ''' + Here goes the definition of the CV. + + Inputs: + x (torch.tensor): input, either scalar or 1-D array + Return: + y (torch.tensor): collective variable (scalar) + ''' + # CV definition + y = torch.sin(x) + + return y + +input_size = 1 + +# -- DEFINE INPUT -- +#random +#x = torch.rand(input_size, dtype=torch.float32, requires_grad=True).unsqueeze(0) +#or by choosing the value(s) of the array +x = torch.tensor([0.], dtype=torch.float32, requires_grad=True) + +# -- CALCULATE CV -- +y = my_torch_cv(x) + +# -- CALCULATE DERIVATIVES -- +for yy in y: + dy = torch.autograd.grad(yy, x, create_graph=True) + # -- PRINT -- + print('CV TEST') + print('n_input\t: {}'.format(input_size)) + print('x\t: {}'.format(x)) + print('cv\t: {}'.format(yy)) + print('der\t: {}'.format(dy)) + +# Compile via tracing +traced_cv = torch.jit.trace ( my_torch_cv, example_inputs=x ) +filename='torch_model.pt' +traced_cv.save(filename) + diff --git a/regtest/pytorch/rt-pytorch_model_derivatives/plumed.dat b/regtest/pytorch/rt-pytorch_model_derivatives/plumed.dat new file mode 100644 index 0000000000..088471c90d --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_derivatives/plumed.dat @@ -0,0 +1,17 @@ +# vim:ft=plumed + +#define input x +phi: TORSION ATOMS=5,7,9,15 + +#load model computing y=sin(x) +model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi +#compute sin(x) with python +sinphi: CUSTOM ARG=phi FUNC=sin(x) PERIODIC=NO + +#output derivatives dy/dx (copy the latter to reference file and change header) +DUMPDERIVATIVES ARG=model.* STRIDE=5 FILE=DERIVATIVES FMT=%10.5f +DUMPDERIVATIVES ARG=sinphi STRIDE=5 FILE=DERIVATIVES_REF FMT=%10.5f + +#print colvar +PRINT FMT=%g STRIDE=5 FILE=COLVAR ARG=phi,model.*,sinphi +ENDPLUMED diff --git a/regtest/pytorch/rt-pytorch_model_derivatives/torch_model.ptc b/regtest/pytorch/rt-pytorch_model_derivatives/torch_model.ptc new file mode 100644 index 0000000000000000000000000000000000000000..1e06d5e5043c133010318f87bec6a0a81d4de6cf GIT binary patch literal 1502 zcmWIWW@cev;NW1u03r;03?=zR$r({G#&2q7;qd)SNUc zpkdI!(!;P*2dF_op~6ZbBsH%%zerO-*AB`Axk5pqD7B=tC{F>bU#~bbPoqMU3lxaG z;>lC@0|O9*L4ha&j84q>RAsY#{j@!%NNAUZgYPTClB*np>PE|WLgyhgz` zmnq-Y9VnF9-ESLxuGA%Y!{Zxavriv$x!2DA;AX)kqs?36bmrE*oLus2PGQ)FA2O>Y z)0@42{HXo-ev!9C^DQZsz1w;fzAh4Zddq8};tlgxE$$g#uO*aDpZIkXTfg)nhlDM6 zSU*hJ`d37wcWprWXRX}h9NPF3w921$u@7 zgaf=8K@_}9K+dBgAPE$Jt%N`~0Xc!nqUi1d@{moy$k`C15Gfit=+#k-I?qVJC`9^3 zcOP=tOQIO+z>IDvQXHZigB*y$D8@`dGX@%M0p4tEI#7jj%(`$btWXw=&H{Q31cX3| c0|ww@Sk)=^~3I%8k6WzAL|JE2IMN6|30NcBV#A^Q>%MJPpt zEQ!dLBw0!!$?qERr>F6*{@Q2*6*kBp+5IuC%5#tHMA^xV3M-E1goNp0P z{Ho_+u_b$!MLZ_=nLam>{>=B?Z2Kl%0`<&&ejF*_?w37=*2AVEMt?YI2CWTkS~XW$PKG7PCe~^+TeA+Ki2IV)=;p%18WOdGhi*ls{~xE5FTGA`#t)U8tZ7XMj|cr|!wwq6D6Pq5a5^BQD%LLds(nl(U)Ti&%L z?iKIPj<4j9;imK+xfi1S$83gVx=OEG#JJ8Fj6^4pGS2i{wH5?ynPf|O$Hi8!Uh`RJ zyv5^=CSMWdKLZI=w)krkOT30l9 zTUI{FEqSX8Gk;2qbD{AyrdwLm&+wgJRQV(c;+79T&DTpt33?8=hwP${JXvz{zSj9h z4MkN!zNgkt+l`bGq6G{>@=}t{kRkQIx~8aUq|t0$gt|6^wF9h2Aj=X09bv82wTii3 z$3(mMl}Mt+he#D(nUM%<1j82vpYQkD4m3@ywrz9!wtHci`SeNSzpM4o!K_iG`d8W_ zc9txNxHqgc(p|k~<=U?{7H3L-DU_A|1MaYtaKx!XQ zuPu`xQ%f?jzkpq`)4_%EVds+bZ3ZO60B=p$sMN6Y1Nz&H8Z*pXl&OihuMOW2OALmq z^Hv$G0^ETw-dtf>kbhP9)W#sJmEN<}boHr4exp$6AM2!}#w@K?BwYU}6R_?8>kU{x zz-R$$t*$%?*78=nw=wVsIGd7F2u%3GnlP1E@iSapI!meN{!#2Umc3tar23s)`S~H+ zm-qf2aTQ#y8hh8^>7m1qaOfBNxaXi_L)T-wc+#1}x=U#gyI$8VEPg~LnytBDMS|5A zoKlcw0D<8CZM~PMAn#2vW)SqnIg^j2;P4veAs$ovcw6_Y1&7`}X?eA|zj@5h!s>!Ra0S+=?rs!Z{r0g9q=pY=y5t z&Vn@(tWRK#g7r3puJynwLWRTWAud7+q2@$U9%91xwMr`X#3H7z%bjB+I0QN?jqJ9+ z8()5B62)Em==f=tc0t^iCdnmsx*lWP$te>D!|d%ODc8wcxA^nsnhZBtBhrXEG+W^- zkVRmv1!org+$vzDmSD{~93GDK)nRLtTJ_x$Rg#9otMBtZ=tDlw^JAZ1ijT1T;Pg*B zF8&lm@bYScbg$AxYvhzpK}A3np|QK8HFEi3OWUkS@=9WP!Pp6DX+gp!Ye6y5j%KSV zSQo$=4OW=77#Xm3Z&=-@S~=VmShe_@Ow^+06Xfv`?^O(pI4VhBODiKEJ?YBM3|B~R zH&Z?_S%3bW=W8BVjlf#5VLhW!fOW&!poOsn0y!=?E&cYt9C}@H6*t@^=V$vv)5dp>GC19d zq);a8e(uW(=`4Gu5>MckT<&P=`Sof%^zlu**Ba;~v`3e1bcSmVY&ZYSf#r5$7|nHs zJqpg2dg{*yG9o(?Gg(I-Gy7L-h75q4|Sfv5J)_+y6&eFB}Wz$t9tI{#UC(ci& zDoW4oJzO4dBpm7;r?C<%9AC4*-t|gO?7OS6u5o57=8egZB9ruTfoSU3UY}W8@k2y- zC;!K}3SRg7Z^$om|ItNVrnvmGT&FBce4nljB%+XorpYF1H zE6LU;aRKE$9PY}F=_Rq%3+kWde|@^n6dT|lp=f-|WM5~E_;~5nyK~m1n6~oHde-Xc ztmbVKGyHjKf2_K*#H%!0_k*=?b@dlC6S^I;q`?aH{e5k{E-lBNs+-u#$5I~}g&V@+ zHAmPKryoypMu;b-RtK`FiOwh1-kr<+LA~{LZTqSQc9N#iOQiHCut(1+Ws}59deqFwNQL&cQx{=KLb`obVe`|kk zOI%iy@FD`A`_9ecC;cYtbsQ0PAsTytYJCDd&;wTLwY3wh|IJsKAOu#=Au~cq=|o9n z95K-O_83m_g=+Xi{~kY87L{$b8O{}5XY&{hF$=9c zy2sDve5OR_9k)2pO_12+lrbT`qS^T`y#LeSyaP_^{+b8Q`Hh+`-melds9Hhub*%VB zJ=bByn>?xyc^`ZEnnM*@($%cl!Omq#fPKirP_mPIGWW8cZsxIKtNyG}^L)#$ut{=n zzji;fwxd%E<-%l1ntmS!vL94HB}27>iCSZ{!}A3~`!!04BIg9x!#y_r`) zup_Dx$%;?-!EwquLGx4;Yw}eY&A4KJuZh|ll)bWmNRTdPR*&|`IkFd}f!Gae zgSS6Nfa8}&sev>?RLzNp>ES$H9y1qMnP=o1J=ud!e5D^b$rmZnxzoFSv6IJZvdAvC z>7wCZLsxvdVxLu9Xj@e)zj;ZCq$GT}yXZ-&`yVUvuVNC0W-H7l6r8J8O2G=hmFXSOz;aBlb!~s_q&{D5kRq1h@4R_cEB z0<3e81$k>us@b?(Nu0xTzkMUGo@gDpY>Ej#POoCqkZ*g0NO!Yu5pnUno=b*nCy!|9 zvv($Q1c&-J&v?CjJXOA$^co#b``+K+A}hJW;LgcyerH8kFV&TAvL+6Z;5*QGU0njJ z2v`@_t$QF)0oIzenGJz8!0~DE9hBZej(PYt9TkhS`E@oS^vvt9U0DIk`#GJ>Bz#?t z<{$cUz=HGh*3Us#n9h0$eY`qzh<}DX#i-VRYwHVDcZ7g&NE$gfYLm5UmIPf!qZQ88 zD0Q$xpQ4wbuI3O(J+x7Kz{G4ZNB*>^6@$bUZ|SJ#j68VR87rI5yHH)2ZEmU0Jx{&W zs*QJ*YK~r9svGxxw{J(ITvw&?a!$3l8!Jh!K8Iz@l;qOb9PCv|-!n<{9SZ%^10!@K zUYe~iPuK3XVGlsVi3dY%0_vu{W=&q?;P4N3ix9;R5h&Uj0_i2jjW%!J);%#ll;(sU z?7}ci*$Xu#zxj1v!1;}7ya{w!`ETd;{6id6rXgE6GCdbh^% z@II-{dlrhoOM<5-H1*h0#;5i!WH_vUri@5@@L$V(foEF{K#ooCOkdi z3y(jNK6~7Vz0CCaM4DGpl~;)xvb_7AzHoH`U&Y7AS}xCS&k{ zcuV1Q9!rz((!`PZ@xt<2XWwp(=&{<177ytTWlf|aCi;B>V&?mEo98nF`>#i4H<$Xy zo?8ly{!HS(*D)@>$$CzWBtx?m?o&|aV0{8sY9CPFe?3^#emzuBWW{k9k&hON3FD6( zL%HGA?nZG9*9akJKk}X8ZJ}`@7}KBdYy3L(EyqIr)r&v$&!mu(#^ke2)3uqPT~$0KYzp4dklME1}WPb zVUU%7>9)cZoq^aZ<*trBws{qs<`$~Ok)%Si^(a_1z?uoxd9YGXLugpk>&w7i|1{@T zbrrXst=Q{n_v*z0wqz7?4`0td7aMfqn#E+wT&R&=v5Y;PGIC`rHkYB+V|?F5U*gu# zm@9P}rTK3Ds+MhuayHthZa?z8`=_Q~{YXTbouS|i0p}>>KrPXe;N#hF?)O59=(MWU zAE~T)@T9X|qA<9{_16-s@Fnd%^Mcvnn_`2FWZbA^(0Xm;L!888pdWiU2C7PU_WV6Dwo#z1j-uRg&> zF_j@+vK@UdUi+N9==5nPv$LX}6=}XHs?9Y-*O)}Puyb?AXXT0$I=9~$oc&Sq;LPb} z-W3jV@Jgx?N9(dNW?RPyf2f#^v|JSTi$45kT z`aC0qzI`61d?z6By}<=s(~9qRR1qY~`DHw>=si64_eW!g=D(d%46SepF|XuEi32fJ zWmb&Jg5{BHMB}4$ z?%K&`J3|kjb8nEdLpU5QzNVuuRAv_*$3ijOWQ{8(!P5vDt#x3v0qeWpP7HOfYQtJP zw`i3raCm<=XOML8{D<;3@$N`MYK&;F{x24ORl4zxHqjc9Kn0d;k1$=n2gk^yz?x(F zgO<*VL1hf=KU}&k6@_=Hp^8|o(M22@r@2K$O(1rCA3#?;CfU(!h5Hp07OW(&CO{VS z0-fW!)!=j2Uv5r%RX0q-4_)Q!VHQ2_zqO#&)=Eh9bMvx7y||`y??uI?F9O$3@7B@Z zwG#5Nq0r!)&yEMv&QE6!j^CLCro6F zKb!;Jt~-sJ+KYAE^=z7X(>vqIDhlH1WxURcZzvC&cDf(D!_sQ^)HaYXAAec;ErZ-E zPi&U$_BaYp;`l<+nK=)(P`tj+Ys-<`%2{H~p>mz?t5mv80JBii}6;(T}t1=GN;SCU|*r@YB3HcRrP*$VFq3Z7f8 zJ#&Y1%i5X6213_nfd9cH5rdO*%%a~TJEIO~2s#%TswuwvHE^-GXLN}-C^DOe>aRmQu3dgdSSN@es_U=H#jA7iyvRmLW(8662KL*sSz z6ncRAT<8tdGzhHJxhl6|H5cfX*LAUYB{n+kPtMqqnSN|F(qrnXW=v8cy|*$gAFEZUrtRiIWGjj)nf|fe;+& z1kKhKsOxdCHiNYfvfxgg4$iW_XF!_n7HmjmftGkl?<>lSlCt!hs7%EV5_UU`zSj+0 zAD{K8bTR&Tx=ME{$o`U>KTF;F8{U0SfP5Ls16Z3-D zjefn&OZ~qUG}aXw9=QxwZ?GPOEb2ahU$+`5l5KSC!kmm^>2Ce|R`KY^%Wk)dt+J#f zj)oO5*N8@?YojSQ8xHpD(!|(De|vlrz2`ZSX~q3|gs;}4laKhc%~f+``uRVpUD$R& L{`DX)9o_!{2f={` literal 0 HcmV?d00001 diff --git a/regtest/pytorch/rt-pytorch_model_script/config b/regtest/pytorch/rt-pytorch_model_script/config new file mode 100644 index 0000000000..da5a9cd874 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_script/config @@ -0,0 +1,10 @@ +plumed_needs=libtorch +plumed_modules=pytorch +type=driver +arg="--plumed plumed.dat --mf_xtc alanine.xtc" + +# note: model has been previously created with create-pytorch-model.py +# the following crashes in CI +# function plumed_regtest_before(){ +# python create-pytorch-model.py +#} diff --git a/regtest/pytorch/rt-pytorch_model_script/create-pytorch-model.py b/regtest/pytorch/rt-pytorch_model_script/create-pytorch-model.py new file mode 100644 index 0000000000..0233a5a838 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_script/create-pytorch-model.py @@ -0,0 +1,28 @@ +import torch +print(torch.__version__) + +def my_torch_cv(x): + ''' + Here goes the definition of the CV. + + Inputs: + x (torch.tensor): input, either scalar or 1-D array + Return: + y (torch.tensor): collective variable (scalar) + ''' + if x > 0: + # CV definition + y = torch.sin(x) + else: + y = torch.tan(x) + + return y + +input_size = 1 + +# Compile via scripting +scripted_cv = torch.jit.script( my_torch_cv ) + +filename='torch_model.pt' +scripted_cv.save(filename) + diff --git a/regtest/pytorch/rt-pytorch_model_script/plumed.dat b/regtest/pytorch/rt-pytorch_model_script/plumed.dat new file mode 100644 index 0000000000..2eeafc8ad7 --- /dev/null +++ b/regtest/pytorch/rt-pytorch_model_script/plumed.dat @@ -0,0 +1,14 @@ +# vim:ft=plumed + +#define input x +phi: TORSION ATOMS=5,7,9,15 + +#load model computing y=sin(x) +model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi + +#output derivatives dy/dx +DUMPDERIVATIVES ARG=model.* STRIDE=5 FILE=DERIVATIVES FMT=%10.5f + +#print colvar +PRINT FMT=%g STRIDE=5 FILE=COLVAR ARG=phi,model.* +ENDPLUMED diff --git a/regtest/pytorch/rt-pytorch_model_script/torch_model.ptc b/regtest/pytorch/rt-pytorch_model_script/torch_model.ptc new file mode 100644 index 0000000000000000000000000000000000000000..c5da0fd3853548b2aef7235e521429073717e1f1 GIT binary patch literal 1630 zcmWIWW@cev;NW1u03r;03?=zR$r=p!K`$(P*`R z?$-_X910%#sVsTBME8h`n&!6c=1Y`3!i{U*r!Om3bBme#GJ|KG`o5OBGL>zLp{yBy zjdJEq=~#Tt;^9AyEpylYb?!X-MBUzDA@9M|OOso9|CrtB_Bgjn#=CKM;On&V&F2^s z0(VU8Z&c-b{(e=L(aj)jFGpjQjaMwZc6J_^b^pc^i^`2BnohdjlAiQaa>k`Ia@k)Q zLE)s*!?k)5FpM}E@r9E*11J7)(o0EADou|EMx}7Qr>ZtylD+Mn*6B6zr1#W9p?0SLjj6uJ zdR|?gk)5(rX4#F~oO47kTPTO?jqs?F46i zO5^0#%E&ctTQdW@=UG?ptG#u5iP&D3e6hQcUtV4KJEz*?*`ec0uC4q%_bZpZ?u&A7 z_PX@#hvd0`8UOre9kQl19u%`bjxLa|0mf=16C`G_=S4{dTeO(XD=taQD=CI#NK0~4 ze^gWhlO{MP=IV2X0}BEM5C-K$?Cuao%ZX*FMa7x(ZYWZ`qZ@-9h{7nwhybG#*=^8h3-D%R(}60KW7dUhVTH0_bQaKSATSA( dIDi1Cj~zt+fhqt=2Y9oxfy7vW5TqWW765c^?Lq(m literal 0 HcmV?d00001