@@ -4,6 +4,7 @@ import Replicate, {
44  Model , 
55  Prediction , 
66  validateWebhook , 
7+   parseProgressFromLogs , 
78}  from  "replicate" ; 
89import  nock  from  "nock" ; 
910import  fetch  from  "cross-fetch" ; 
@@ -888,63 +889,124 @@ describe("Replicate client", () => {
888889  } ) ; 
889890
890891  describe ( "run" ,  ( )  =>  { 
891-     test ( "Calls the correct API routes for a version" ,  async  ( )  =>  { 
892-       const  firstPollingRequest  =  true ; 
893- 
892+     test ( "Calls the correct API routes" ,  async  ( )  =>  { 
894893      nock ( BASE_URL ) 
895894        . post ( "/predictions" ) 
896895        . reply ( 201 ,  { 
897896          id : "ufawqhfynnddngldkgtslldrkq" , 
898897          status : "starting" , 
898+           logs : null , 
899899        } ) 
900900        . get ( "/predictions/ufawqhfynnddngldkgtslldrkq" ) 
901-         . twice ( ) 
902901        . reply ( 200 ,  { 
903902          id : "ufawqhfynnddngldkgtslldrkq" , 
904903          status : "processing" , 
904+           logs : [ 
905+             "Using seed: 12345" , 
906+             "0%|          | 0/5 [00:00<?, ?it/s]" , 
907+             "20%|██        | 1/5 [00:00<00:01, 21.38it/s]" , 
908+             "40%|████▍     | 2/5 [00:01<00:01, 22.46it/s]" , 
909+           ] . join ( "\n" ) , 
910+         } ) 
911+         . get ( "/predictions/ufawqhfynnddngldkgtslldrkq" ) 
912+         . reply ( 200 ,  { 
913+           id : "ufawqhfynnddngldkgtslldrkq" , 
914+           status : "processing" , 
915+           logs : [ 
916+             "Using seed: 12345" , 
917+             "0%|          | 0/5 [00:00<?, ?it/s]" , 
918+             "20%|██        | 1/5 [00:00<00:01, 21.38it/s]" , 
919+             "40%|████▍     | 2/5 [00:01<00:01, 22.46it/s]" , 
920+             "60%|████▍     | 3/5 [00:01<00:01, 22.46it/s]" , 
921+             "80%|████████  | 4/5 [00:01<00:00, 22.86it/s]" , 
922+           ] . join ( "\n" ) , 
905923        } ) 
906924        . get ( "/predictions/ufawqhfynnddngldkgtslldrkq" ) 
907925        . reply ( 200 ,  { 
908926          id : "ufawqhfynnddngldkgtslldrkq" , 
909927          status : "succeeded" , 
910928          output : "Goodbye!" , 
929+           logs : [ 
930+             "Using seed: 12345" , 
931+             "0%|          | 0/5 [00:00<?, ?it/s]" , 
932+             "20%|██        | 1/5 [00:00<00:01, 21.38it/s]" , 
933+             "40%|████▍     | 2/5 [00:01<00:01, 22.46it/s]" , 
934+             "60%|████▍     | 3/5 [00:01<00:01, 22.46it/s]" , 
935+             "80%|████████  | 4/5 [00:01<00:00, 22.86it/s]" , 
936+             "100%|██████████| 5/5 [00:02<00:00, 22.26it/s]" , 
937+           ] . join ( "\n" ) , 
911938        } ) ; 
912939
913-       const  progress  =  jest . fn ( ) ; 
940+       const  callback  =  jest . fn ( ) ; 
914941
915942      const  output  =  await  client . run ( 
916943        "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" , 
917944        { 
918945          input : {  text : "Hello, world!"  } , 
919946          wait : {  interval : 1  } , 
920947        } , 
921-         progress 
948+         ( prediction )  =>  { 
949+           const  progress  =  parseProgressFromLogs ( prediction ) ; 
950+           callback ( prediction ,  progress ) ; 
951+         } 
922952      ) ; 
923953
924954      expect ( output ) . toBe ( "Goodbye!" ) ; 
925955
926-       expect ( progress ) . toHaveBeenNthCalledWith ( 1 ,  { 
927-         id : "ufawqhfynnddngldkgtslldrkq" , 
928-         status : "starting" , 
929-       } ) ; 
956+       expect ( callback ) . toHaveBeenNthCalledWith ( 
957+         1 , 
958+         { 
959+           id : "ufawqhfynnddngldkgtslldrkq" , 
960+           status : "starting" , 
961+           logs : null , 
962+         } , 
963+         null 
964+       ) ; 
930965
931-       expect ( progress ) . toHaveBeenNthCalledWith ( 2 ,  { 
932-         id : "ufawqhfynnddngldkgtslldrkq" , 
933-         status : "processing" , 
934-       } ) ; 
966+       expect ( callback ) . toHaveBeenNthCalledWith ( 
967+         2 , 
968+         { 
969+           id : "ufawqhfynnddngldkgtslldrkq" , 
970+           status : "processing" , 
971+           logs : expect . any ( String ) , 
972+         } , 
973+         { 
974+           percentage : 0.4 , 
975+           current : 2 , 
976+           total : 5 , 
977+         } 
978+       ) ; 
935979
936-       expect ( progress ) . toHaveBeenNthCalledWith ( 3 ,  { 
937-         id : "ufawqhfynnddngldkgtslldrkq" , 
938-         status : "processing" , 
939-       } ) ; 
980+       expect ( callback ) . toHaveBeenNthCalledWith ( 
981+         3 , 
982+         { 
983+           id : "ufawqhfynnddngldkgtslldrkq" , 
984+           status : "processing" , 
985+           logs : expect . any ( String ) , 
986+         } , 
987+         { 
988+           percentage : 0.8 , 
989+           current : 4 , 
990+           total : 5 , 
991+         } 
992+       ) ; 
940993
941-       expect ( progress ) . toHaveBeenNthCalledWith ( 4 ,  { 
942-         id : "ufawqhfynnddngldkgtslldrkq" , 
943-         status : "succeeded" , 
944-         output : "Goodbye!" , 
945-       } ) ; 
994+       expect ( callback ) . toHaveBeenNthCalledWith ( 
995+         4 , 
996+         { 
997+           id : "ufawqhfynnddngldkgtslldrkq" , 
998+           status : "succeeded" , 
999+           logs : expect . any ( String ) , 
1000+           output : "Goodbye!" , 
1001+         } , 
1002+         { 
1003+           percentage : 1.0 , 
1004+           current : 5 , 
1005+           total : 5 , 
1006+         } 
1007+       ) ; 
9461008
947-       expect ( progress ) . toHaveBeenCalledTimes ( 4 ) ; 
1009+       expect ( callback ) . toHaveBeenCalledTimes ( 4 ) ; 
9481010    } ) ; 
9491011
9501012    test ( "Calls the correct API routes for a model" ,  async  ( )  =>  { 
0 commit comments