@@ -4,6 +4,7 @@ import Replicate, {
4
4
Model ,
5
5
Prediction ,
6
6
validateWebhook ,
7
+ parseProgressFromLogs ,
7
8
} from "replicate" ;
8
9
import nock from "nock" ;
9
10
import fetch from "cross-fetch" ;
@@ -888,63 +889,124 @@ describe("Replicate client", () => {
888
889
} ) ;
889
890
890
891
describe ( "run" , ( ) => {
891
- test ( "Calls the correct API routes for a version" , async ( ) => {
892
- const firstPollingRequest = true ;
893
-
892
+ test ( "Calls the correct API routes" , async ( ) => {
894
893
nock ( BASE_URL )
895
894
. post ( "/predictions" )
896
895
. reply ( 201 , {
897
896
id : "ufawqhfynnddngldkgtslldrkq" ,
898
897
status : "starting" ,
898
+ logs : null ,
899
899
} )
900
900
. get ( "/predictions/ufawqhfynnddngldkgtslldrkq" )
901
- . twice ( )
902
901
. reply ( 200 , {
903
902
id : "ufawqhfynnddngldkgtslldrkq" ,
904
903
status : "processing" ,
904
+ logs : [
905
+ "Using seed: 12345" ,
906
+ "0%| | 0/5 [00:00<?, ?it/s]" ,
907
+ "20%|██ | 1/5 [00:00<00:01, 21.38it/s]" ,
908
+ "40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]" ,
909
+ ] . join ( "\n" ) ,
910
+ } )
911
+ . get ( "/predictions/ufawqhfynnddngldkgtslldrkq" )
912
+ . reply ( 200 , {
913
+ id : "ufawqhfynnddngldkgtslldrkq" ,
914
+ status : "processing" ,
915
+ logs : [
916
+ "Using seed: 12345" ,
917
+ "0%| | 0/5 [00:00<?, ?it/s]" ,
918
+ "20%|██ | 1/5 [00:00<00:01, 21.38it/s]" ,
919
+ "40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]" ,
920
+ "60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]" ,
921
+ "80%|████████ | 4/5 [00:01<00:00, 22.86it/s]" ,
922
+ ] . join ( "\n" ) ,
905
923
} )
906
924
. get ( "/predictions/ufawqhfynnddngldkgtslldrkq" )
907
925
. reply ( 200 , {
908
926
id : "ufawqhfynnddngldkgtslldrkq" ,
909
927
status : "succeeded" ,
910
928
output : "Goodbye!" ,
929
+ logs : [
930
+ "Using seed: 12345" ,
931
+ "0%| | 0/5 [00:00<?, ?it/s]" ,
932
+ "20%|██ | 1/5 [00:00<00:01, 21.38it/s]" ,
933
+ "40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]" ,
934
+ "60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]" ,
935
+ "80%|████████ | 4/5 [00:01<00:00, 22.86it/s]" ,
936
+ "100%|██████████| 5/5 [00:02<00:00, 22.26it/s]" ,
937
+ ] . join ( "\n" ) ,
911
938
} ) ;
912
939
913
- const progress = jest . fn ( ) ;
940
+ const callback = jest . fn ( ) ;
914
941
915
942
const output = await client . run (
916
943
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
917
944
{
918
945
input : { text : "Hello, world!" } ,
919
946
wait : { interval : 1 } ,
920
947
} ,
921
- progress
948
+ ( prediction ) => {
949
+ const progress = parseProgressFromLogs ( prediction ) ;
950
+ callback ( prediction , progress ) ;
951
+ }
922
952
) ;
923
953
924
954
expect ( output ) . toBe ( "Goodbye!" ) ;
925
955
926
- expect ( progress ) . toHaveBeenNthCalledWith ( 1 , {
927
- id : "ufawqhfynnddngldkgtslldrkq" ,
928
- status : "starting" ,
929
- } ) ;
956
+ expect ( callback ) . toHaveBeenNthCalledWith (
957
+ 1 ,
958
+ {
959
+ id : "ufawqhfynnddngldkgtslldrkq" ,
960
+ status : "starting" ,
961
+ logs : null ,
962
+ } ,
963
+ null
964
+ ) ;
930
965
931
- expect ( progress ) . toHaveBeenNthCalledWith ( 2 , {
932
- id : "ufawqhfynnddngldkgtslldrkq" ,
933
- status : "processing" ,
934
- } ) ;
966
+ expect ( callback ) . toHaveBeenNthCalledWith (
967
+ 2 ,
968
+ {
969
+ id : "ufawqhfynnddngldkgtslldrkq" ,
970
+ status : "processing" ,
971
+ logs : expect . any ( String ) ,
972
+ } ,
973
+ {
974
+ percentage : 0.4 ,
975
+ current : 2 ,
976
+ total : 5 ,
977
+ }
978
+ ) ;
935
979
936
- expect ( progress ) . toHaveBeenNthCalledWith ( 3 , {
937
- id : "ufawqhfynnddngldkgtslldrkq" ,
938
- status : "processing" ,
939
- } ) ;
980
+ expect ( callback ) . toHaveBeenNthCalledWith (
981
+ 3 ,
982
+ {
983
+ id : "ufawqhfynnddngldkgtslldrkq" ,
984
+ status : "processing" ,
985
+ logs : expect . any ( String ) ,
986
+ } ,
987
+ {
988
+ percentage : 0.8 ,
989
+ current : 4 ,
990
+ total : 5 ,
991
+ }
992
+ ) ;
940
993
941
- expect ( progress ) . toHaveBeenNthCalledWith ( 4 , {
942
- id : "ufawqhfynnddngldkgtslldrkq" ,
943
- status : "succeeded" ,
944
- output : "Goodbye!" ,
945
- } ) ;
994
+ expect ( callback ) . toHaveBeenNthCalledWith (
995
+ 4 ,
996
+ {
997
+ id : "ufawqhfynnddngldkgtslldrkq" ,
998
+ status : "succeeded" ,
999
+ logs : expect . any ( String ) ,
1000
+ output : "Goodbye!" ,
1001
+ } ,
1002
+ {
1003
+ percentage : 1.0 ,
1004
+ current : 5 ,
1005
+ total : 5 ,
1006
+ }
1007
+ ) ;
946
1008
947
- expect ( progress ) . toHaveBeenCalledTimes ( 4 ) ;
1009
+ expect ( callback ) . toHaveBeenCalledTimes ( 4 ) ;
948
1010
} ) ;
949
1011
950
1012
test ( "Calls the correct API routes for a model" , async ( ) => {
0 commit comments