@@ -681,76 +681,6 @@ private void testOpenAITextEmbeddingModel(String charset, Consumer<Map> verifyRe
         }
     }
 
-    public void testCohereGenerateTextModel() throws IOException, InterruptedException {
-        // Skip test if key is null
-        if (COHERE_KEY == null) {
-            return;
-        }
-        String entity = "{\n"
-            + "  \"name\": \"Cohere generate text model Connector\",\n"
-            + "  \"description\": \"The connector to public Cohere generate text model service\",\n"
-            + "  \"version\": 1,\n"
-            + "  \"client_config\": {\n"
-            + "    \"max_connection\": 20,\n"
-            + "    \"connection_timeout\": 50000,\n"
-            + "    \"read_timeout\": 50000\n"
-            + "  },\n"
-            + "  \"protocol\": \"http\",\n"
-            + "  \"parameters\": {\n"
-            + "    \"endpoint\": \"api.cohere.ai\",\n"
-            + "    \"auth\": \"API_Key\",\n"
-            + "    \"content_type\": \"application/json\",\n"
-            + "    \"max_tokens\": \"20\"\n"
-            + "  },\n"
-            + "  \"credential\": {\n"
-            + "    \"cohere_key\": \""
-            + COHERE_KEY
-            + "\"\n"
-            + "  },\n"
-            + "  \"actions\": [\n"
-            + "    {\n"
-            + "      \"action_type\": \"predict\",\n"
-            + "      \"method\": \"POST\",\n"
-            + "      \"url\": \"https://${parameters.endpoint}/v1/generate\",\n"
-            + "      \"headers\": {\n"
-            + "        \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
-            + "      },\n"
-            + "      \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n"
-            + "    }\n"
-            + "  ]\n"
-            + "}";
-        Response response = createConnector(entity);
-        Map responseMap = parseResponseToMap(response);
-        String connectorId = (String) responseMap.get("connector_id");
-        response = registerRemoteModel("cohere generate text model", connectorId);
-        responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
-        String modelId = (String) responseMap.get("model_id");
-        response = deployRemoteModel(modelId);
-        responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        String predictInput = "{\n"
-            + "  \"parameters\": {\n"
-            + "    \"prompt\": \"Once upon a time in a magical land called\",\n"
-            + "    \"max_tokens\": 40\n"
-            + "  }\n"
-            + "}";
-        response = predictRemoteModel(modelId, predictInput);
-        responseMap = parseResponseToMap(response);
-        List responseList = (List) responseMap.get("inference_results");
-        responseMap = (Map) responseList.get(0);
-        responseList = (List) responseMap.get("output");
-        responseMap = (Map) responseList.get(0);
-        responseMap = (Map) responseMap.get("dataAsMap");
-        responseList = (List) responseMap.get("generations");
-        responseMap = (Map) responseList.get(0);
-        assertFalse(((String) responseMap.get("text")).isEmpty());
-    }
-
     public static Response createConnector(String input) throws IOException {
         try {
             return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);