diff --git a/.release-please-manifest.json b/.release-please-manifest.json index f14b480a..aaf968a1 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.2" + ".": "0.1.0-alpha.3" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 58f04461..e97b0bff 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 41 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-f53d9282224f2c3943d83d014d64ba61271f3aedef59197cc4dae0102d2b365d.yml -openapi_spec_hash: a884fe7d04e9f64675e3943a962ebb65 -config_hash: 73457be4d72f0bf4c22de49f2b2d4ec3 +configured_endpoints: 45 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-bfa7422593036f383fcc5209e8a52705f582be9480f90747f8962a46ed5b1152.yml +openapi_spec_hash: 400da476d5f86a3493bf6dacfe6826f0 +config_hash: 87a5832ab2ecefe567d22108531232f5 diff --git a/CHANGELOG.md b/CHANGELOG.md index 71158bb4..f23c324e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,63 @@ # Changelog +## 0.1.0-alpha.3 (2025-11-27) + +Full Changelog: [v0.1.0-alpha.2...v0.1.0-alpha.3](https://github.com/togethercomputer/together-go/compare/v0.1.0-alpha.2...v0.1.0-alpha.3) + +### ⚠ BREAKING CHANGES + +* **api:** Change call signature for `audio.create` to `audio.speech.create` to match spec with python library and add space for future APIs +* **api:** Update method signature for reranking to `rerank.create()` +* **api:** Change Fine Tuning method name from `download()` to `content()` to align with other namespaces +* **api:** For the TS SDK the `images.create` is now `images.generate` +* **api:** Access to the api for listing checkpoints has changed its name to `list_checkpoints` +* **api:** Access to fine tuning APIs namespace has changed from `fine_tune` to `fine_tuning` +* **api:** The default max retries for api calls has changed from 5 to 2. This may result in more frequent non-200 responses. + +### Features + +* **api:** Add audio.voices.list sdk ([0aacf5d](https://github.com/togethercomputer/together-go/commit/0aacf5d6e75b326f74921d3c3dd2e88ed1f32b96)) +* **api:** Add batches.cancel API ([a8e2951](https://github.com/togethercomputer/together-go/commit/a8e29513556c09d803000fcdadb1a4604736f036)) +* **api:** Add endpoints.list_avzones ([c07fe49](https://github.com/togethercomputer/together-go/commit/c07fe49d9ea1202facfd0c4c6019e9119b636297)) +* **api:** Add fine_tune.delete API ([dceaf80](https://github.com/togethercomputer/together-go/commit/dceaf802fb93e5ef9e2fb7ed04b80ff4a6dbacec)) +* **api:** api update ([be1a06c](https://github.com/togethercomputer/together-go/commit/be1a06c5a1fdbf6106d5bb288b2d1568aec85d2b)) +* **api:** api update ([638ebc7](https://github.com/togethercomputer/together-go/commit/638ebc7e0e5ce88228a4ffb6f6f1bd904bb09f8c)) +* **api:** api update ([5169921](https://github.com/togethercomputer/together-go/commit/51699214da2ce174e13634af5ca11ac9936cbd13)) +* **api:** api update ([48737a3](https://github.com/togethercomputer/together-go/commit/48737a323b7fb5f5272e5c208c7078084c8507ab)) +* **api:** api update ([5699d1a](https://github.com/togethercomputer/together-go/commit/5699d1ae3e6e40db16dc6ee1eb136c399e660d1d)) +* **api:** api update ([a51a5f3](https://github.com/togethercomputer/together-go/commit/a51a5f3c92661e87dcc2af8b86d709a0c755016e)) +* **api:** api update ([2c9ca33](https://github.com/togethercomputer/together-go/commit/2c9ca331f6de3549c4dedb5a78dbc50609d49cf7)) +* **api:** api update ([e53f0e7](https://github.com/togethercomputer/together-go/commit/e53f0e79fcf213f40e6537c07f18629d09d9302d)) +* **api:** api update ([c64fc6d](https://github.com/togethercomputer/together-go/commit/c64fc6db4e868f65c32d836636b91b65f03b0e8a)) +* **api:** Change fine tuning download method to `.create` ([faeb0f8](https://github.com/togethercomputer/together-go/commit/faeb0f8c877a78cf8c77395bf48f7dd55efc131b)) +* **api:** Change image creation signature to `images.generate` ([5ec73f1](https://github.com/togethercomputer/together-go/commit/5ec73f1b1f6ac90d705d2a4d03b21a607922eff7)) +* **api:** Change rerank method signature ([15519be](https://github.com/togethercomputer/together-go/commit/15519be8442800c3ddbfe1fa808f4d21027a91e1)) +* **api:** Change the default max retries from 5 to 2 ([becb776](https://github.com/togethercomputer/together-go/commit/becb77688fe7899159c09a34df6061c0a46716f1)) +* **api:** Change TTS call signature ([f906b2e](https://github.com/togethercomputer/together-go/commit/f906b2e51c2a043cdf1bc6e5d057d25787dc9bfb)) +* **api:** Fix internal references for VideoJob spec ([bbf9a21](https://github.com/togethercomputer/together-go/commit/bbf9a21463c634a490c389d407f1f07fc098dc7a)) +* **api:** manual updates ([651e447](https://github.com/togethercomputer/together-go/commit/651e4473bdc16002ec4fc12374b6d445eaf70bdd)) +* **api:** Update Eval APIs ([a2baaa3](https://github.com/togethercomputer/together-go/commit/a2baaa3a41e18d0d9356ffd0f46ae5558a5048a4)) + + +### Bug Fixes + +* **client:** correctly specify Accept header with */* instead of empty ([6d504e0](https://github.com/togethercomputer/together-go/commit/6d504e008ed150736cf1e3498d603f9bd9782418)) +* remove invalid cast ([72d5d52](https://github.com/togethercomputer/together-go/commit/72d5d521cd2fa88dbf31edb4e5dd9fc9fd195ed4)) + + +### Chores + +* **api:** Cleanup some exported types ([aade2f0](https://github.com/togethercomputer/together-go/commit/aade2f06b3c4c705e08fdec6dd4847aec4c8409a)) +* **api:** Remove API that is not intended to be public. ([df90a15](https://github.com/togethercomputer/together-go/commit/df90a159de5d770e1e9049d9a9dc7b7befe66921)) +* bump gjson version ([704f413](https://github.com/togethercomputer/together-go/commit/704f413869dd157060e6cf96d9f1cfb6aa9cc424)) +* **internal:** grammar fix (it's -> its) ([97b3fc5](https://github.com/togethercomputer/together-go/commit/97b3fc5382b8e4d36340516f06b6c60c7daeb95e)) + + +### Styles + +* **api:** Change fine tuning method `retrieve_checkpoints` to `list_checkpoints` ([7e12276](https://github.com/togethercomputer/together-go/commit/7e1227697c594718cedaa70e55be4b7ffd08e836)) +* **api:** Change fine tuning namespace to `fine_tuning` ([cfb8297](https://github.com/togethercomputer/together-go/commit/cfb82976c6e84cba4cab24fbe07616fa7d8e561f)) + ## 0.1.0-alpha.2 (2025-10-30) Full Changelog: [v0.1.0-alpha.1...v0.1.0-alpha.2](https://github.com/togethercomputer/together-go/compare/v0.1.0-alpha.1...v0.1.0-alpha.2) diff --git a/README.md b/README.md index 79dfc9e6..18fa406c 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Or to pin the version: ```sh -go get -u 'github.com/togethercomputer/together-go@v0.1.0-alpha.2' +go get -u 'github.com/togethercomputer/together-go@v0.1.0-alpha.3' ``` @@ -135,7 +135,7 @@ custom := param.Override[together.FooParams](12) ### Request unions -Unions are represented as a struct with fields prefixed by "Of" for each of it's variants, +Unions are represented as a struct with fields prefixed by "Of" for each of its variants, only one field can be non-zero. The non-zero field will be serialized. Sub-properties of the union can be accessed via methods on the union struct. @@ -399,7 +399,7 @@ together.FileUploadParams{ ### Retries -Certain errors will be automatically retried 5 times by default, with a short exponential backoff. +Certain errors will be automatically retried 2 times by default, with a short exponential backoff. We retry by default all connection errors, 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors. diff --git a/api.md b/api.md index db547f74..6ee48468 100644 --- a/api.md +++ b/api.md @@ -1,13 +1,3 @@ -# together - -Response Types: - -- together.RerankResponse - -Methods: - -- client.Rerank(ctx context.Context, body together.RerankParams) (together.RerankResponse, error) - # Chat ## Completions @@ -66,60 +56,44 @@ Params Types: Response Types: +- together.FileList - together.FilePurpose +- together.FileResponse - together.FileType -- together.FileGetResponse -- together.FileListResponse - together.FileDeleteResponse -- together.FileUploadResponse Methods: -- client.Files.Get(ctx context.Context, id string) (together.FileGetResponse, error) -- client.Files.List(ctx context.Context) (together.FileListResponse, error) +- client.Files.Get(ctx context.Context, id string) (together.FileResponse, error) +- client.Files.List(ctx context.Context) (together.FileList, error) - client.Files.Delete(ctx context.Context, id string) (together.FileDeleteResponse, error) - client.Files.Content(ctx context.Context, id string) (http.Response, error) -- client.Files.Upload(ctx context.Context, body together.FileUploadParams) (together.FileUploadResponse, error) +- client.Files.Upload(ctx context.Context, body together.FileUploadParams) (together.FileResponse, error) -# FineTune - -Params Types: - -- together.CosineLrSchedulerArgsParam -- together.FullTrainingTypeParam -- together.LinearLrSchedulerArgsParam -- together.LoRaTrainingTypeParam -- together.LrSchedulerParam -- together.TrainingMethodDpoParam -- together.TrainingMethodSftParam +# FineTuning Response Types: -- together.CosineLrSchedulerArgs -- together.FineTune -- together.FineTuneEvent -- together.FullTrainingType -- together.LinearLrSchedulerArgs -- together.LoRaTrainingType -- together.LrScheduler -- together.TrainingMethodDpo -- together.TrainingMethodSft -- together.FineTuneNewResponse -- together.FineTuneListResponse -- together.FineTuneCancelResponse -- together.FineTuneDownloadResponse -- together.FineTuneListEventsResponse -- together.FineTuneGetCheckpointsResponse +- together.FinetuneEvent +- together.FinetuneEventType +- together.FinetuneResponse +- together.FineTuningNewResponse +- together.FineTuningListResponse +- together.FineTuningDeleteResponse +- together.FineTuningCancelResponse +- together.FineTuningListCheckpointsResponse +- together.FineTuningListEventsResponse Methods: -- client.FineTune.New(ctx context.Context, body together.FineTuneNewParams) (together.FineTuneNewResponse, error) -- client.FineTune.Get(ctx context.Context, id string) (together.FineTune, error) -- client.FineTune.List(ctx context.Context) (together.FineTuneListResponse, error) -- client.FineTune.Cancel(ctx context.Context, id string) (together.FineTuneCancelResponse, error) -- client.FineTune.Download(ctx context.Context, query together.FineTuneDownloadParams) (together.FineTuneDownloadResponse, error) -- client.FineTune.ListEvents(ctx context.Context, id string) (together.FineTuneListEventsResponse, error) -- client.FineTune.GetCheckpoints(ctx context.Context, id string) (together.FineTuneGetCheckpointsResponse, error) +- client.FineTuning.New(ctx context.Context, body together.FineTuningNewParams) (together.FineTuningNewResponse, error) +- client.FineTuning.Get(ctx context.Context, id string) (together.FinetuneResponse, error) +- client.FineTuning.List(ctx context.Context) (together.FineTuningListResponse, error) +- client.FineTuning.Delete(ctx context.Context, id string, body together.FineTuningDeleteParams) (together.FineTuningDeleteResponse, error) +- client.FineTuning.Cancel(ctx context.Context, id string) (together.FineTuningCancelResponse, error) +- client.FineTuning.Content(ctx context.Context, query together.FineTuningContentParams) (http.Response, error) +- client.FineTuning.ListCheckpoints(ctx context.Context, id string) (together.FineTuningListCheckpointsResponse, error) +- client.FineTuning.ListEvents(ctx context.Context, id string) (together.FineTuningListEventsResponse, error) # CodeInterpreter @@ -151,18 +125,17 @@ Response Types: Methods: -- client.Images.New(ctx context.Context, body together.ImageNewParams) (together.ImageFile, error) +- client.Images.Generate(ctx context.Context, body together.ImageGenerateParams) (together.ImageFile, error) # Videos Response Types: - together.VideoJob -- together.VideoNewResponse Methods: -- client.Videos.New(ctx context.Context, body together.VideoNewParams) (together.VideoNewResponse, error) +- client.Videos.New(ctx context.Context, body together.VideoNewParams) (together.VideoJob, error) - client.Videos.Get(ctx context.Context, id string) (together.VideoJob, error) # Audio @@ -171,9 +144,21 @@ Response Types: - together.AudioSpeechStreamChunk +## Speech + Methods: -- client.Audio.New(ctx context.Context, body together.AudioNewParams) (http.Response, error) +- client.Audio.Speech.New(ctx context.Context, body together.AudioSpeechNewParams) (http.Response, error) + +## Voices + +Response Types: + +- together.AudioVoiceListResponse + +Methods: + +- client.Audio.Voices.List(ctx context.Context) (together.AudioVoiceListResponse, error) ## Transcriptions @@ -199,12 +184,12 @@ Methods: Response Types: -- together.ModelListResponse +- together.ModelObject - together.ModelUploadResponse Methods: -- client.Models.List(ctx context.Context) ([]together.ModelListResponse, error) +- client.Models.List(ctx context.Context) ([]together.ModelObject, error) - client.Models.Upload(ctx context.Context, body together.ModelUploadParams) (together.ModelUploadResponse, error) # Jobs @@ -228,18 +213,18 @@ Params Types: Response Types: - together.Autoscaling -- together.EndpointNewResponse -- together.EndpointGetResponse -- together.EndpointUpdateResponse +- together.DedicatedEndpoint - together.EndpointListResponse +- together.EndpointListAvzonesResponse Methods: -- client.Endpoints.New(ctx context.Context, body together.EndpointNewParams) (together.EndpointNewResponse, error) -- client.Endpoints.Get(ctx context.Context, endpointID string) (together.EndpointGetResponse, error) -- client.Endpoints.Update(ctx context.Context, endpointID string, body together.EndpointUpdateParams) (together.EndpointUpdateResponse, error) +- client.Endpoints.New(ctx context.Context, body together.EndpointNewParams) (together.DedicatedEndpoint, error) +- client.Endpoints.Get(ctx context.Context, endpointID string) (together.DedicatedEndpoint, error) +- client.Endpoints.Update(ctx context.Context, endpointID string, body together.EndpointUpdateParams) (together.DedicatedEndpoint, error) - client.Endpoints.List(ctx context.Context, query together.EndpointListParams) (together.EndpointListResponse, error) - client.Endpoints.Delete(ctx context.Context, endpointID string) error +- client.Endpoints.ListAvzones(ctx context.Context) (together.EndpointListAvzonesResponse, error) # Hardware @@ -251,32 +236,41 @@ Methods: - client.Hardware.List(ctx context.Context, query together.HardwareListParams) (together.HardwareListResponse, error) +# Rerank + +Response Types: + +- together.RerankNewResponse + +Methods: + +- client.Rerank.New(ctx context.Context, body together.RerankNewParams) (together.RerankNewResponse, error) + # Batches Response Types: +- together.BatchJob - together.BatchNewResponse -- together.BatchGetResponse -- together.BatchListResponse Methods: - client.Batches.New(ctx context.Context, body together.BatchNewParams) (together.BatchNewResponse, error) -- client.Batches.Get(ctx context.Context, id string) (together.BatchGetResponse, error) -- client.Batches.List(ctx context.Context) ([]together.BatchListResponse, error) +- client.Batches.Get(ctx context.Context, id string) (together.BatchJob, error) +- client.Batches.List(ctx context.Context) ([]together.BatchJob, error) +- client.Batches.Cancel(ctx context.Context, id string) (together.BatchJob, error) # Evals Response Types: -- together.EvalGetResponse -- together.EvalListResponse -- together.EvalGetAllowedModelsResponse -- together.EvalGetStatusResponse +- together.EvaluationJob +- together.EvalNewResponse +- together.EvalStatusResponse Methods: -- client.Evals.Get(ctx context.Context, id string) (together.EvalGetResponse, error) -- client.Evals.List(ctx context.Context, query together.EvalListParams) ([]together.EvalListResponse, error) -- client.Evals.GetAllowedModels(ctx context.Context) (together.EvalGetAllowedModelsResponse, error) -- client.Evals.GetStatus(ctx context.Context, id string) (together.EvalGetStatusResponse, error) +- client.Evals.New(ctx context.Context, body together.EvalNewParams) (together.EvalNewResponse, error) +- client.Evals.Get(ctx context.Context, id string) (together.EvaluationJob, error) +- client.Evals.List(ctx context.Context, query together.EvalListParams) ([]together.EvaluationJob, error) +- client.Evals.Status(ctx context.Context, id string) (together.EvalStatusResponse, error) diff --git a/audio.go b/audio.go index ae30c7c7..edc12c3e 100644 --- a/audio.go +++ b/audio.go @@ -3,16 +3,9 @@ package together import ( - "context" - "net/http" - "slices" - "github.com/togethercomputer/together-go/internal/apijson" - "github.com/togethercomputer/together-go/internal/requestconfig" "github.com/togethercomputer/together-go/option" - "github.com/togethercomputer/together-go/packages/param" "github.com/togethercomputer/together-go/packages/respjson" - "github.com/togethercomputer/together-go/packages/ssestream" ) // AudioService contains methods and other services that help with interacting with @@ -23,6 +16,8 @@ import ( // the [NewAudioService] method instead. type AudioService struct { Options []option.RequestOption + Speech AudioSpeechService + Voices AudioVoiceService Transcriptions AudioTranscriptionService Translations AudioTranslationService } @@ -33,33 +28,13 @@ type AudioService struct { func NewAudioService(opts ...option.RequestOption) (r AudioService) { r = AudioService{} r.Options = opts + r.Speech = NewAudioSpeechService(opts...) + r.Voices = NewAudioVoiceService(opts...) r.Transcriptions = NewAudioTranscriptionService(opts...) r.Translations = NewAudioTranslationService(opts...) return } -// Generate audio from input text -func (r *AudioService) New(ctx context.Context, body AudioNewParams, opts ...option.RequestOption) (res *http.Response, err error) { - opts = slices.Concat(r.Options, opts) - opts = append([]option.RequestOption{option.WithHeader("Accept", "application/octet-stream")}, opts...) - path := "audio/speech" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) - return -} - -// Generate audio from input text -func (r *AudioService) NewStreaming(ctx context.Context, body AudioNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[AudioSpeechStreamChunk]) { - var ( - raw *http.Response - err error - ) - opts = slices.Concat(r.Options, opts) - opts = append([]option.RequestOption{option.WithHeader("Accept", "application/octet-stream"), option.WithJSONSet("stream", true)}, opts...) - path := "audio/speech" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) - return ssestream.NewStream[AudioSpeechStreamChunk](ssestream.NewDecoder(raw), err) -} - type AudioSpeechStreamChunk struct { // base64 encoded audio stream B64 string `json:"b64,required"` @@ -87,105 +62,3 @@ type AudioSpeechStreamChunkObject string const ( AudioSpeechStreamChunkObjectAudioTtsChunk AudioSpeechStreamChunkObject = "audio.tts.chunk" ) - -type AudioNewParams struct { - // Input text to generate the audio for - Input string `json:"input,required"` - // The name of the model to query. - // - // [See all of Together AI's chat models](https://docs.together.ai/docs/serverless-models#audio-models) - // The current supported tts models are: - cartesia/sonic - hexgrad/Kokoro-82M - - // canopylabs/orpheus-3b-0.1-ft - Model AudioNewParamsModel `json:"model,omitzero,required"` - // The voice to use for generating the audio. The voices supported are different - // for each model. For eg - for canopylabs/orpheus-3b-0.1-ft, one of the voices - // supported is tara, for hexgrad/Kokoro-82M, one of the voices supported is - // af_alloy and for cartesia/sonic, one of the voices supported is "friendly - // sidekick". - // - // You can view the voices supported for each model using the /v1/voices endpoint - // sending the model name as the query parameter. - // [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). - Voice string `json:"voice,required"` - // Sampling rate to use for the output audio. The default sampling rate for - // canopylabs/orpheus-3b-0.1-ft and hexgrad/Kokoro-82M is 24000 and for - // cartesia/sonic is 44100. - SampleRate param.Opt[float64] `json:"sample_rate,omitzero"` - // Language of input text. - // - // Any of "en", "de", "fr", "es", "hi", "it", "ja", "ko", "nl", "pl", "pt", "ru", - // "sv", "tr", "zh". - Language AudioNewParamsLanguage `json:"language,omitzero"` - // Audio encoding of response - // - // Any of "pcm_f32le", "pcm_s16le", "pcm_mulaw", "pcm_alaw". - ResponseEncoding AudioNewParamsResponseEncoding `json:"response_encoding,omitzero"` - // The format of audio output. Supported formats are mp3, wav, raw if streaming is - // false. If streaming is true, the only supported format is raw. - // - // Any of "mp3", "wav", "raw". - ResponseFormat AudioNewParamsResponseFormat `json:"response_format,omitzero"` - paramObj -} - -func (r AudioNewParams) MarshalJSON() (data []byte, err error) { - type shadow AudioNewParams - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *AudioNewParams) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// The name of the model to query. -// -// [See all of Together AI's chat models](https://docs.together.ai/docs/serverless-models#audio-models) -// The current supported tts models are: - cartesia/sonic - hexgrad/Kokoro-82M - -// canopylabs/orpheus-3b-0.1-ft -type AudioNewParamsModel string - -const ( - AudioNewParamsModelCartesiaSonic AudioNewParamsModel = "cartesia/sonic" - AudioNewParamsModelHexgradKokoro82M AudioNewParamsModel = "hexgrad/Kokoro-82M" - AudioNewParamsModelCanopylabsOrpheus3b0_1Ft AudioNewParamsModel = "canopylabs/orpheus-3b-0.1-ft" -) - -// Language of input text. -type AudioNewParamsLanguage string - -const ( - AudioNewParamsLanguageEn AudioNewParamsLanguage = "en" - AudioNewParamsLanguageDe AudioNewParamsLanguage = "de" - AudioNewParamsLanguageFr AudioNewParamsLanguage = "fr" - AudioNewParamsLanguageEs AudioNewParamsLanguage = "es" - AudioNewParamsLanguageHi AudioNewParamsLanguage = "hi" - AudioNewParamsLanguageIt AudioNewParamsLanguage = "it" - AudioNewParamsLanguageJa AudioNewParamsLanguage = "ja" - AudioNewParamsLanguageKo AudioNewParamsLanguage = "ko" - AudioNewParamsLanguageNl AudioNewParamsLanguage = "nl" - AudioNewParamsLanguagePl AudioNewParamsLanguage = "pl" - AudioNewParamsLanguagePt AudioNewParamsLanguage = "pt" - AudioNewParamsLanguageRu AudioNewParamsLanguage = "ru" - AudioNewParamsLanguageSv AudioNewParamsLanguage = "sv" - AudioNewParamsLanguageTr AudioNewParamsLanguage = "tr" - AudioNewParamsLanguageZh AudioNewParamsLanguage = "zh" -) - -// Audio encoding of response -type AudioNewParamsResponseEncoding string - -const ( - AudioNewParamsResponseEncodingPcmF32le AudioNewParamsResponseEncoding = "pcm_f32le" - AudioNewParamsResponseEncodingPcmS16le AudioNewParamsResponseEncoding = "pcm_s16le" - AudioNewParamsResponseEncodingPcmMulaw AudioNewParamsResponseEncoding = "pcm_mulaw" - AudioNewParamsResponseEncodingPcmAlaw AudioNewParamsResponseEncoding = "pcm_alaw" -) - -// The format of audio output. Supported formats are mp3, wav, raw if streaming is -// false. If streaming is true, the only supported format is raw. -type AudioNewParamsResponseFormat string - -const ( - AudioNewParamsResponseFormatMP3 AudioNewParamsResponseFormat = "mp3" - AudioNewParamsResponseFormatWav AudioNewParamsResponseFormat = "wav" - AudioNewParamsResponseFormatRaw AudioNewParamsResponseFormat = "raw" -) diff --git a/audiospeech.go b/audiospeech.go new file mode 100644 index 00000000..6f725284 --- /dev/null +++ b/audiospeech.go @@ -0,0 +1,158 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package together + +import ( + "context" + "net/http" + "slices" + + "github.com/togethercomputer/together-go/internal/apijson" + "github.com/togethercomputer/together-go/internal/requestconfig" + "github.com/togethercomputer/together-go/option" + "github.com/togethercomputer/together-go/packages/param" + "github.com/togethercomputer/together-go/packages/ssestream" +) + +// AudioSpeechService contains methods and other services that help with +// interacting with the together API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewAudioSpeechService] method instead. +type AudioSpeechService struct { + Options []option.RequestOption +} + +// NewAudioSpeechService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewAudioSpeechService(opts ...option.RequestOption) (r AudioSpeechService) { + r = AudioSpeechService{} + r.Options = opts + return +} + +// Generate audio from input text +func (r *AudioSpeechService) New(ctx context.Context, body AudioSpeechNewParams, opts ...option.RequestOption) (res *http.Response, err error) { + opts = slices.Concat(r.Options, opts) + opts = append([]option.RequestOption{option.WithHeader("Accept", "application/octet-stream")}, opts...) + path := "audio/speech" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Generate audio from input text +func (r *AudioSpeechService) NewStreaming(ctx context.Context, body AudioSpeechNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[AudioSpeechStreamChunk]) { + var ( + raw *http.Response + err error + ) + opts = slices.Concat(r.Options, opts) + opts = append([]option.RequestOption{option.WithHeader("Accept", "application/octet-stream"), option.WithJSONSet("stream", true)}, opts...) + path := "audio/speech" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[AudioSpeechStreamChunk](ssestream.NewDecoder(raw), err) +} + +type AudioSpeechNewParams struct { + // Input text to generate the audio for + Input string `json:"input,required"` + // The name of the model to query. + // + // [See all of Together AI's chat models](https://docs.together.ai/docs/serverless-models#audio-models) + // The current supported tts models are: - cartesia/sonic - hexgrad/Kokoro-82M - + // canopylabs/orpheus-3b-0.1-ft + Model AudioSpeechNewParamsModel `json:"model,omitzero,required"` + // The voice to use for generating the audio. The voices supported are different + // for each model. For eg - for canopylabs/orpheus-3b-0.1-ft, one of the voices + // supported is tara, for hexgrad/Kokoro-82M, one of the voices supported is + // af_alloy and for cartesia/sonic, one of the voices supported is "friendly + // sidekick". + // + // You can view the voices supported for each model using the /v1/voices endpoint + // sending the model name as the query parameter. + // [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + Voice string `json:"voice,required"` + // Sampling rate to use for the output audio. The default sampling rate for + // canopylabs/orpheus-3b-0.1-ft and hexgrad/Kokoro-82M is 24000 and for + // cartesia/sonic is 44100. + SampleRate param.Opt[int64] `json:"sample_rate,omitzero"` + // Language of input text. + // + // Any of "en", "de", "fr", "es", "hi", "it", "ja", "ko", "nl", "pl", "pt", "ru", + // "sv", "tr", "zh". + Language AudioSpeechNewParamsLanguage `json:"language,omitzero"` + // Audio encoding of response + // + // Any of "pcm_f32le", "pcm_s16le", "pcm_mulaw", "pcm_alaw". + ResponseEncoding AudioSpeechNewParamsResponseEncoding `json:"response_encoding,omitzero"` + // The format of audio output. Supported formats are mp3, wav, raw if streaming is + // false. If streaming is true, the only supported format is raw. + // + // Any of "mp3", "wav", "raw". + ResponseFormat AudioSpeechNewParamsResponseFormat `json:"response_format,omitzero"` + paramObj +} + +func (r AudioSpeechNewParams) MarshalJSON() (data []byte, err error) { + type shadow AudioSpeechNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *AudioSpeechNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The name of the model to query. +// +// [See all of Together AI's chat models](https://docs.together.ai/docs/serverless-models#audio-models) +// The current supported tts models are: - cartesia/sonic - hexgrad/Kokoro-82M - +// canopylabs/orpheus-3b-0.1-ft +type AudioSpeechNewParamsModel string + +const ( + AudioSpeechNewParamsModelCartesiaSonic AudioSpeechNewParamsModel = "cartesia/sonic" + AudioSpeechNewParamsModelHexgradKokoro82M AudioSpeechNewParamsModel = "hexgrad/Kokoro-82M" + AudioSpeechNewParamsModelCanopylabsOrpheus3b0_1Ft AudioSpeechNewParamsModel = "canopylabs/orpheus-3b-0.1-ft" +) + +// Language of input text. +type AudioSpeechNewParamsLanguage string + +const ( + AudioSpeechNewParamsLanguageEn AudioSpeechNewParamsLanguage = "en" + AudioSpeechNewParamsLanguageDe AudioSpeechNewParamsLanguage = "de" + AudioSpeechNewParamsLanguageFr AudioSpeechNewParamsLanguage = "fr" + AudioSpeechNewParamsLanguageEs AudioSpeechNewParamsLanguage = "es" + AudioSpeechNewParamsLanguageHi AudioSpeechNewParamsLanguage = "hi" + AudioSpeechNewParamsLanguageIt AudioSpeechNewParamsLanguage = "it" + AudioSpeechNewParamsLanguageJa AudioSpeechNewParamsLanguage = "ja" + AudioSpeechNewParamsLanguageKo AudioSpeechNewParamsLanguage = "ko" + AudioSpeechNewParamsLanguageNl AudioSpeechNewParamsLanguage = "nl" + AudioSpeechNewParamsLanguagePl AudioSpeechNewParamsLanguage = "pl" + AudioSpeechNewParamsLanguagePt AudioSpeechNewParamsLanguage = "pt" + AudioSpeechNewParamsLanguageRu AudioSpeechNewParamsLanguage = "ru" + AudioSpeechNewParamsLanguageSv AudioSpeechNewParamsLanguage = "sv" + AudioSpeechNewParamsLanguageTr AudioSpeechNewParamsLanguage = "tr" + AudioSpeechNewParamsLanguageZh AudioSpeechNewParamsLanguage = "zh" +) + +// Audio encoding of response +type AudioSpeechNewParamsResponseEncoding string + +const ( + AudioSpeechNewParamsResponseEncodingPcmF32le AudioSpeechNewParamsResponseEncoding = "pcm_f32le" + AudioSpeechNewParamsResponseEncodingPcmS16le AudioSpeechNewParamsResponseEncoding = "pcm_s16le" + AudioSpeechNewParamsResponseEncodingPcmMulaw AudioSpeechNewParamsResponseEncoding = "pcm_mulaw" + AudioSpeechNewParamsResponseEncodingPcmAlaw AudioSpeechNewParamsResponseEncoding = "pcm_alaw" +) + +// The format of audio output. Supported formats are mp3, wav, raw if streaming is +// false. If streaming is true, the only supported format is raw. +type AudioSpeechNewParamsResponseFormat string + +const ( + AudioSpeechNewParamsResponseFormatMP3 AudioSpeechNewParamsResponseFormat = "mp3" + AudioSpeechNewParamsResponseFormatWav AudioSpeechNewParamsResponseFormat = "wav" + AudioSpeechNewParamsResponseFormatRaw AudioSpeechNewParamsResponseFormat = "raw" +) diff --git a/audio_test.go b/audiospeech_test.go similarity index 71% rename from audio_test.go rename to audiospeech_test.go index cc6a0eae..2e00bc13 100644 --- a/audio_test.go +++ b/audiospeech_test.go @@ -15,7 +15,7 @@ import ( "github.com/togethercomputer/together-go/option" ) -func TestAudioNewWithOptionalParams(t *testing.T) { +func TestAudioSpeechNewWithOptionalParams(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.Write([]byte("abc")) @@ -26,14 +26,14 @@ func TestAudioNewWithOptionalParams(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - resp, err := client.Audio.New(context.TODO(), together.AudioNewParams{ + resp, err := client.Audio.Speech.New(context.TODO(), together.AudioSpeechNewParams{ Input: "input", - Model: together.AudioNewParamsModelCartesiaSonic, + Model: together.AudioSpeechNewParamsModelCartesiaSonic, Voice: "voice", - Language: together.AudioNewParamsLanguageEn, - ResponseEncoding: together.AudioNewParamsResponseEncodingPcmF32le, - ResponseFormat: together.AudioNewParamsResponseFormatMP3, - SampleRate: together.Float(0), + Language: together.AudioSpeechNewParamsLanguageEn, + ResponseEncoding: together.AudioSpeechNewParamsResponseEncodingPcmF32le, + ResponseFormat: together.AudioSpeechNewParamsResponseFormatMP3, + SampleRate: together.Int(0), }) if err != nil { var apierr *together.Error diff --git a/audiovoice.go b/audiovoice.go new file mode 100644 index 00000000..e6004b85 --- /dev/null +++ b/audiovoice.go @@ -0,0 +1,95 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package together + +import ( + "context" + "net/http" + "slices" + + "github.com/togethercomputer/together-go/internal/apijson" + "github.com/togethercomputer/together-go/internal/requestconfig" + "github.com/togethercomputer/together-go/option" + "github.com/togethercomputer/together-go/packages/respjson" +) + +// AudioVoiceService contains methods and other services that help with interacting +// with the together API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewAudioVoiceService] method instead. +type AudioVoiceService struct { + Options []option.RequestOption +} + +// NewAudioVoiceService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewAudioVoiceService(opts ...option.RequestOption) (r AudioVoiceService) { + r = AudioVoiceService{} + r.Options = opts + return +} + +// Fetch available voices for each model +func (r *AudioVoiceService) List(ctx context.Context, opts ...option.RequestOption) (res *AudioVoiceListResponse, err error) { + opts = slices.Concat(r.Options, opts) + path := "voices" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Response containing a list of models and their available voices. +type AudioVoiceListResponse struct { + Data []AudioVoiceListResponseData `json:"data,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AudioVoiceListResponse) RawJSON() string { return r.JSON.raw } +func (r *AudioVoiceListResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Represents a model with its available voices. +type AudioVoiceListResponseData struct { + Model string `json:"model,required"` + Voices []AudioVoiceListResponseDataVoice `json:"voices,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Model respjson.Field + Voices respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AudioVoiceListResponseData) RawJSON() string { return r.JSON.raw } +func (r *AudioVoiceListResponseData) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type AudioVoiceListResponseDataVoice struct { + ID string `json:"id,required"` + Name string `json:"name,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Name respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r AudioVoiceListResponseDataVoice) RawJSON() string { return r.JSON.raw } +func (r *AudioVoiceListResponseDataVoice) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} diff --git a/audiovoice_test.go b/audiovoice_test.go new file mode 100644 index 00000000..0dfff34f --- /dev/null +++ b/audiovoice_test.go @@ -0,0 +1,36 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package together_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/togethercomputer/together-go" + "github.com/togethercomputer/together-go/internal/testutil" + "github.com/togethercomputer/together-go/option" +) + +func TestAudioVoiceList(t *testing.T) { + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := together.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + _, err := client.Audio.Voices.List(context.TODO()) + if err != nil { + var apierr *together.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/batch.go b/batch.go index 413e2487..4d4870e4 100644 --- a/batch.go +++ b/batch.go @@ -45,7 +45,7 @@ func (r *BatchService) New(ctx context.Context, body BatchNewParams, opts ...opt } // Get details of a batch job by ID -func (r *BatchService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *BatchGetResponse, err error) { +func (r *BatchService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *BatchJob, err error) { opts = slices.Concat(r.Options, opts) if id == "" { err = errors.New("missing required id parameter") @@ -57,81 +57,26 @@ func (r *BatchService) Get(ctx context.Context, id string, opts ...option.Reques } // List all batch jobs for the authenticated user -func (r *BatchService) List(ctx context.Context, opts ...option.RequestOption) (res *[]BatchListResponse, err error) { +func (r *BatchService) List(ctx context.Context, opts ...option.RequestOption) (res *[]BatchJob, err error) { opts = slices.Concat(r.Options, opts) path := "batches" err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) return } -type BatchNewResponse struct { - Job BatchNewResponseJob `json:"job"` - Warning string `json:"warning"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Job respjson.Field - Warning respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r BatchNewResponse) RawJSON() string { return r.JSON.raw } -func (r *BatchNewResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type BatchNewResponseJob struct { - ID string `json:"id" format:"uuid"` - CompletedAt time.Time `json:"completed_at" format:"date-time"` - CreatedAt time.Time `json:"created_at" format:"date-time"` - Endpoint string `json:"endpoint"` - Error string `json:"error"` - ErrorFileID string `json:"error_file_id"` - // Size of input file in bytes - FileSizeBytes int64 `json:"file_size_bytes"` - InputFileID string `json:"input_file_id"` - JobDeadline time.Time `json:"job_deadline" format:"date-time"` - // Model used for processing requests - ModelID string `json:"model_id"` - OutputFileID string `json:"output_file_id"` - // Completion progress (0.0 to 100) - Progress float64 `json:"progress"` - // Current status of the batch job - // - // Any of "VALIDATING", "IN_PROGRESS", "COMPLETED", "FAILED", "EXPIRED", - // "CANCELLED". - Status string `json:"status"` - UserID string `json:"user_id"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - CompletedAt respjson.Field - CreatedAt respjson.Field - Endpoint respjson.Field - Error respjson.Field - ErrorFileID respjson.Field - FileSizeBytes respjson.Field - InputFileID respjson.Field - JobDeadline respjson.Field - ModelID respjson.Field - OutputFileID respjson.Field - Progress respjson.Field - Status respjson.Field - UserID respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r BatchNewResponseJob) RawJSON() string { return r.JSON.raw } -func (r *BatchNewResponseJob) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) +// Cancel a batch job by ID +func (r *BatchService) Cancel(ctx context.Context, id string, opts ...option.RequestOption) (res *BatchJob, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("batches/%s/cancel", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return } -type BatchGetResponse struct { +type BatchJob struct { ID string `json:"id" format:"uuid"` CompletedAt time.Time `json:"completed_at" format:"date-time"` CreatedAt time.Time `json:"created_at" format:"date-time"` @@ -151,8 +96,8 @@ type BatchGetResponse struct { // // Any of "VALIDATING", "IN_PROGRESS", "COMPLETED", "FAILED", "EXPIRED", // "CANCELLED". - Status BatchGetResponseStatus `json:"status"` - UserID string `json:"user_id"` + Status BatchJobStatus `json:"status"` + UserID string `json:"user_id"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { ID respjson.Field @@ -175,84 +120,41 @@ type BatchGetResponse struct { } // Returns the unmodified JSON received from the API -func (r BatchGetResponse) RawJSON() string { return r.JSON.raw } -func (r *BatchGetResponse) UnmarshalJSON(data []byte) error { +func (r BatchJob) RawJSON() string { return r.JSON.raw } +func (r *BatchJob) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // Current status of the batch job -type BatchGetResponseStatus string +type BatchJobStatus string const ( - BatchGetResponseStatusValidating BatchGetResponseStatus = "VALIDATING" - BatchGetResponseStatusInProgress BatchGetResponseStatus = "IN_PROGRESS" - BatchGetResponseStatusCompleted BatchGetResponseStatus = "COMPLETED" - BatchGetResponseStatusFailed BatchGetResponseStatus = "FAILED" - BatchGetResponseStatusExpired BatchGetResponseStatus = "EXPIRED" - BatchGetResponseStatusCancelled BatchGetResponseStatus = "CANCELLED" + BatchJobStatusValidating BatchJobStatus = "VALIDATING" + BatchJobStatusInProgress BatchJobStatus = "IN_PROGRESS" + BatchJobStatusCompleted BatchJobStatus = "COMPLETED" + BatchJobStatusFailed BatchJobStatus = "FAILED" + BatchJobStatusExpired BatchJobStatus = "EXPIRED" + BatchJobStatusCancelled BatchJobStatus = "CANCELLED" ) -type BatchListResponse struct { - ID string `json:"id" format:"uuid"` - CompletedAt time.Time `json:"completed_at" format:"date-time"` - CreatedAt time.Time `json:"created_at" format:"date-time"` - Endpoint string `json:"endpoint"` - Error string `json:"error"` - ErrorFileID string `json:"error_file_id"` - // Size of input file in bytes - FileSizeBytes int64 `json:"file_size_bytes"` - InputFileID string `json:"input_file_id"` - JobDeadline time.Time `json:"job_deadline" format:"date-time"` - // Model used for processing requests - ModelID string `json:"model_id"` - OutputFileID string `json:"output_file_id"` - // Completion progress (0.0 to 100) - Progress float64 `json:"progress"` - // Current status of the batch job - // - // Any of "VALIDATING", "IN_PROGRESS", "COMPLETED", "FAILED", "EXPIRED", - // "CANCELLED". - Status BatchListResponseStatus `json:"status"` - UserID string `json:"user_id"` +type BatchNewResponse struct { + Job BatchJob `json:"job"` + Warning string `json:"warning"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { - ID respjson.Field - CompletedAt respjson.Field - CreatedAt respjson.Field - Endpoint respjson.Field - Error respjson.Field - ErrorFileID respjson.Field - FileSizeBytes respjson.Field - InputFileID respjson.Field - JobDeadline respjson.Field - ModelID respjson.Field - OutputFileID respjson.Field - Progress respjson.Field - Status respjson.Field - UserID respjson.Field - ExtraFields map[string]respjson.Field - raw string + Job respjson.Field + Warning respjson.Field + ExtraFields map[string]respjson.Field + raw string } `json:"-"` } // Returns the unmodified JSON received from the API -func (r BatchListResponse) RawJSON() string { return r.JSON.raw } -func (r *BatchListResponse) UnmarshalJSON(data []byte) error { +func (r BatchNewResponse) RawJSON() string { return r.JSON.raw } +func (r *BatchNewResponse) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -// Current status of the batch job -type BatchListResponseStatus string - -const ( - BatchListResponseStatusValidating BatchListResponseStatus = "VALIDATING" - BatchListResponseStatusInProgress BatchListResponseStatus = "IN_PROGRESS" - BatchListResponseStatusCompleted BatchListResponseStatus = "COMPLETED" - BatchListResponseStatusFailed BatchListResponseStatus = "FAILED" - BatchListResponseStatusExpired BatchListResponseStatus = "EXPIRED" - BatchListResponseStatusCancelled BatchListResponseStatus = "CANCELLED" -) - type BatchNewParams struct { // The endpoint to use for batch processing Endpoint string `json:"endpoint,required"` diff --git a/batch_test.go b/batch_test.go index 5b8ce3b6..14357add 100644 --- a/batch_test.go +++ b/batch_test.go @@ -84,3 +84,25 @@ func TestBatchList(t *testing.T) { t.Fatalf("err should be nil: %s", err.Error()) } } + +func TestBatchCancel(t *testing.T) { + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := together.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + _, err := client.Batches.Cancel(context.TODO(), "batch_job_abc123def456") + if err != nil { + var apierr *together.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/chatcompletion.go b/chatcompletion.go index 86eaa812..31e6257f 100644 --- a/chatcompletion.go +++ b/chatcompletion.go @@ -224,6 +224,7 @@ type ChatCompletionChunkChoiceDelta struct { Content string `json:"content,nullable"` // Deprecated: deprecated FunctionCall ChatCompletionChunkChoiceDeltaFunctionCall `json:"function_call,nullable"` + Reasoning string `json:"reasoning,nullable"` TokenID int64 `json:"token_id"` ToolCalls []ToolChoice `json:"tool_calls"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. @@ -231,6 +232,7 @@ type ChatCompletionChunkChoiceDelta struct { Role respjson.Field Content respjson.Field FunctionCall respjson.Field + Reasoning respjson.Field TokenID respjson.Field ToolCalls respjson.Field ExtraFields map[string]respjson.Field diff --git a/client.go b/client.go index 787d25ba..707bd0cf 100644 --- a/client.go +++ b/client.go @@ -21,7 +21,7 @@ type Client struct { Completions CompletionService Embeddings EmbeddingService Files FileService - FineTune FineTuneService + FineTuning FineTuningService CodeInterpreter CodeInterpreterService Images ImageService Videos VideoService @@ -30,6 +30,7 @@ type Client struct { Jobs JobService Endpoints EndpointService Hardware HardwareService + Rerank RerankService Batches BatchService Evals EvalService } @@ -60,7 +61,7 @@ func NewClient(opts ...option.RequestOption) (r Client) { r.Completions = NewCompletionService(opts...) r.Embeddings = NewEmbeddingService(opts...) r.Files = NewFileService(opts...) - r.FineTune = NewFineTuneService(opts...) + r.FineTuning = NewFineTuningService(opts...) r.CodeInterpreter = NewCodeInterpreterService(opts...) r.Images = NewImageService(opts...) r.Videos = NewVideoService(opts...) @@ -69,6 +70,7 @@ func NewClient(opts ...option.RequestOption) (r Client) { r.Jobs = NewJobService(opts...) r.Endpoints = NewEndpointService(opts...) r.Hardware = NewHardwareService(opts...) + r.Rerank = NewRerankService(opts...) r.Batches = NewBatchService(opts...) r.Evals = NewEvalService(opts...) @@ -143,11 +145,3 @@ func (r *Client) Patch(ctx context.Context, path string, params any, res any, op func (r *Client) Delete(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { return r.Execute(ctx, http.MethodDelete, path, params, res, opts...) } - -// Query a reranker model -func (r *Client) Rerank(ctx context.Context, body RerankParams, opts ...option.RequestOption) (res *RerankResponse, err error) { - opts = slices.Concat(r.Options, opts) - path := "rerank" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) - return -} diff --git a/client_test.go b/client_test.go index daf801f0..113c6e6e 100644 --- a/client_test.go +++ b/client_test.go @@ -89,11 +89,11 @@ func TestRetryAfter(t *testing.T) { } attempts := len(retryCountHeaders) - if attempts != 6 { - t.Errorf("Expected %d attempts, got %d", 6, attempts) + if attempts != 3 { + t.Errorf("Expected %d attempts, got %d", 3, attempts) } - expectedRetryCountHeaders := []string{"0", "1", "2", "3", "4", "5"} + expectedRetryCountHeaders := []string{"0", "1", "2"} if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) } @@ -133,7 +133,7 @@ func TestDeleteRetryCountHeader(t *testing.T) { t.Error("Expected there to be a cancel error") } - expectedRetryCountHeaders := []string{"", "", "", "", "", ""} + expectedRetryCountHeaders := []string{"", "", ""} if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) } @@ -173,7 +173,7 @@ func TestOverwriteRetryCountHeader(t *testing.T) { t.Error("Expected there to be a cancel error") } - expectedRetryCountHeaders := []string{"42", "42", "42", "42", "42", "42"} + expectedRetryCountHeaders := []string{"42", "42", "42"} if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) } @@ -211,7 +211,7 @@ func TestRetryAfterMs(t *testing.T) { if err == nil { t.Error("Expected there to be a cancel error") } - if want := 6; attempts != want { + if want := 3; attempts != want { t.Errorf("Expected %d attempts, got %d", want, attempts) } } diff --git a/completion.go b/completion.go index 98114588..0aaa211f 100644 --- a/completion.go +++ b/completion.go @@ -212,6 +212,7 @@ type CompletionChunkChoiceDelta struct { Content string `json:"content,nullable"` // Deprecated: deprecated FunctionCall CompletionChunkChoiceDeltaFunctionCall `json:"function_call,nullable"` + Reasoning string `json:"reasoning,nullable"` TokenID int64 `json:"token_id"` ToolCalls []ToolChoice `json:"tool_calls"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. @@ -219,6 +220,7 @@ type CompletionChunkChoiceDelta struct { Role respjson.Field Content respjson.Field FunctionCall respjson.Field + Reasoning respjson.Field TokenID respjson.Field ToolCalls respjson.Field ExtraFields map[string]respjson.Field diff --git a/endpoint.go b/endpoint.go index 203e3421..d37f22ea 100644 --- a/endpoint.go +++ b/endpoint.go @@ -42,7 +42,7 @@ func NewEndpointService(opts ...option.RequestOption) (r EndpointService) { // Creates a new dedicated endpoint for serving models. The endpoint will // automatically start after creation. You can deploy any supported model on // hardware configurations that meet the model's requirements. -func (r *EndpointService) New(ctx context.Context, body EndpointNewParams, opts ...option.RequestOption) (res *EndpointNewResponse, err error) { +func (r *EndpointService) New(ctx context.Context, body EndpointNewParams, opts ...option.RequestOption) (res *DedicatedEndpoint, err error) { opts = slices.Concat(r.Options, opts) path := "endpoints" err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) @@ -51,7 +51,7 @@ func (r *EndpointService) New(ctx context.Context, body EndpointNewParams, opts // Retrieves details about a specific endpoint, including its current state, // configuration, and scaling settings. -func (r *EndpointService) Get(ctx context.Context, endpointID string, opts ...option.RequestOption) (res *EndpointGetResponse, err error) { +func (r *EndpointService) Get(ctx context.Context, endpointID string, opts ...option.RequestOption) (res *DedicatedEndpoint, err error) { opts = slices.Concat(r.Options, opts) if endpointID == "" { err = errors.New("missing required endpointId parameter") @@ -64,7 +64,7 @@ func (r *EndpointService) Get(ctx context.Context, endpointID string, opts ...op // Updates an existing endpoint's configuration. You can modify the display name, // autoscaling settings, or change the endpoint's state (start/stop). -func (r *EndpointService) Update(ctx context.Context, endpointID string, body EndpointUpdateParams, opts ...option.RequestOption) (res *EndpointUpdateResponse, err error) { +func (r *EndpointService) Update(ctx context.Context, endpointID string, body EndpointUpdateParams, opts ...option.RequestOption) (res *DedicatedEndpoint, err error) { opts = slices.Concat(r.Options, opts) if endpointID == "" { err = errors.New("missing required endpointId parameter") @@ -87,7 +87,7 @@ func (r *EndpointService) List(ctx context.Context, query EndpointListParams, op // Permanently deletes an endpoint. This action cannot be undone. func (r *EndpointService) Delete(ctx context.Context, endpointID string, opts ...option.RequestOption) (err error) { opts = slices.Concat(r.Options, opts) - opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "*/*")}, opts...) if endpointID == "" { err = errors.New("missing required endpointId parameter") return @@ -97,6 +97,14 @@ func (r *EndpointService) Delete(ctx context.Context, endpointID string, opts .. return } +// List all available availability zones. +func (r *EndpointService) ListAvzones(ctx context.Context, opts ...option.RequestOption) (res *EndpointListAvzonesResponse, err error) { + opts = slices.Concat(r.Options, opts) + path := "clusters/availability-zones" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + // Configuration for automatic scaling of replicas based on demand. type Autoscaling struct { // The maximum number of replicas to scale up to under load @@ -147,87 +155,7 @@ func (r *AutoscalingParam) UnmarshalJSON(data []byte) error { } // Details about a dedicated endpoint deployment -type EndpointNewResponse struct { - // Unique identifier for the endpoint - ID string `json:"id,required"` - // Configuration for automatic scaling of the endpoint - Autoscaling Autoscaling `json:"autoscaling,required"` - // Timestamp when the endpoint was created - CreatedAt time.Time `json:"created_at,required" format:"date-time"` - // Human-readable name for the endpoint - DisplayName string `json:"display_name,required"` - // The hardware configuration used for this endpoint - Hardware string `json:"hardware,required"` - // The model deployed on this endpoint - Model string `json:"model,required"` - // System name for the endpoint - Name string `json:"name,required"` - // The type of object - // - // Any of "endpoint". - Object EndpointNewResponseObject `json:"object,required"` - // The owner of this endpoint - Owner string `json:"owner,required"` - // Current state of the endpoint - // - // Any of "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "ERROR". - State EndpointNewResponseState `json:"state,required"` - // The type of endpoint - // - // Any of "dedicated". - Type EndpointNewResponseType `json:"type,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - Autoscaling respjson.Field - CreatedAt respjson.Field - DisplayName respjson.Field - Hardware respjson.Field - Model respjson.Field - Name respjson.Field - Object respjson.Field - Owner respjson.Field - State respjson.Field - Type respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r EndpointNewResponse) RawJSON() string { return r.JSON.raw } -func (r *EndpointNewResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// The type of object -type EndpointNewResponseObject string - -const ( - EndpointNewResponseObjectEndpoint EndpointNewResponseObject = "endpoint" -) - -// Current state of the endpoint -type EndpointNewResponseState string - -const ( - EndpointNewResponseStatePending EndpointNewResponseState = "PENDING" - EndpointNewResponseStateStarting EndpointNewResponseState = "STARTING" - EndpointNewResponseStateStarted EndpointNewResponseState = "STARTED" - EndpointNewResponseStateStopping EndpointNewResponseState = "STOPPING" - EndpointNewResponseStateStopped EndpointNewResponseState = "STOPPED" - EndpointNewResponseStateError EndpointNewResponseState = "ERROR" -) - -// The type of endpoint -type EndpointNewResponseType string - -const ( - EndpointNewResponseTypeDedicated EndpointNewResponseType = "dedicated" -) - -// Details about a dedicated endpoint deployment -type EndpointGetResponse struct { +type DedicatedEndpoint struct { // Unique identifier for the endpoint ID string `json:"id,required"` // Configuration for automatic scaling of the endpoint @@ -245,17 +173,17 @@ type EndpointGetResponse struct { // The type of object // // Any of "endpoint". - Object EndpointGetResponseObject `json:"object,required"` + Object DedicatedEndpointObject `json:"object,required"` // The owner of this endpoint Owner string `json:"owner,required"` // Current state of the endpoint // // Any of "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "ERROR". - State EndpointGetResponseState `json:"state,required"` + State DedicatedEndpointState `json:"state,required"` // The type of endpoint // // Any of "dedicated". - Type EndpointGetResponseType `json:"type,required"` + Type DedicatedEndpointType `json:"type,required"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { ID respjson.Field @@ -275,115 +203,35 @@ type EndpointGetResponse struct { } // Returns the unmodified JSON received from the API -func (r EndpointGetResponse) RawJSON() string { return r.JSON.raw } -func (r *EndpointGetResponse) UnmarshalJSON(data []byte) error { +func (r DedicatedEndpoint) RawJSON() string { return r.JSON.raw } +func (r *DedicatedEndpoint) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // The type of object -type EndpointGetResponseObject string +type DedicatedEndpointObject string const ( - EndpointGetResponseObjectEndpoint EndpointGetResponseObject = "endpoint" + DedicatedEndpointObjectEndpoint DedicatedEndpointObject = "endpoint" ) // Current state of the endpoint -type EndpointGetResponseState string +type DedicatedEndpointState string const ( - EndpointGetResponseStatePending EndpointGetResponseState = "PENDING" - EndpointGetResponseStateStarting EndpointGetResponseState = "STARTING" - EndpointGetResponseStateStarted EndpointGetResponseState = "STARTED" - EndpointGetResponseStateStopping EndpointGetResponseState = "STOPPING" - EndpointGetResponseStateStopped EndpointGetResponseState = "STOPPED" - EndpointGetResponseStateError EndpointGetResponseState = "ERROR" + DedicatedEndpointStatePending DedicatedEndpointState = "PENDING" + DedicatedEndpointStateStarting DedicatedEndpointState = "STARTING" + DedicatedEndpointStateStarted DedicatedEndpointState = "STARTED" + DedicatedEndpointStateStopping DedicatedEndpointState = "STOPPING" + DedicatedEndpointStateStopped DedicatedEndpointState = "STOPPED" + DedicatedEndpointStateError DedicatedEndpointState = "ERROR" ) // The type of endpoint -type EndpointGetResponseType string +type DedicatedEndpointType string const ( - EndpointGetResponseTypeDedicated EndpointGetResponseType = "dedicated" -) - -// Details about a dedicated endpoint deployment -type EndpointUpdateResponse struct { - // Unique identifier for the endpoint - ID string `json:"id,required"` - // Configuration for automatic scaling of the endpoint - Autoscaling Autoscaling `json:"autoscaling,required"` - // Timestamp when the endpoint was created - CreatedAt time.Time `json:"created_at,required" format:"date-time"` - // Human-readable name for the endpoint - DisplayName string `json:"display_name,required"` - // The hardware configuration used for this endpoint - Hardware string `json:"hardware,required"` - // The model deployed on this endpoint - Model string `json:"model,required"` - // System name for the endpoint - Name string `json:"name,required"` - // The type of object - // - // Any of "endpoint". - Object EndpointUpdateResponseObject `json:"object,required"` - // The owner of this endpoint - Owner string `json:"owner,required"` - // Current state of the endpoint - // - // Any of "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "ERROR". - State EndpointUpdateResponseState `json:"state,required"` - // The type of endpoint - // - // Any of "dedicated". - Type EndpointUpdateResponseType `json:"type,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - Autoscaling respjson.Field - CreatedAt respjson.Field - DisplayName respjson.Field - Hardware respjson.Field - Model respjson.Field - Name respjson.Field - Object respjson.Field - Owner respjson.Field - State respjson.Field - Type respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r EndpointUpdateResponse) RawJSON() string { return r.JSON.raw } -func (r *EndpointUpdateResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// The type of object -type EndpointUpdateResponseObject string - -const ( - EndpointUpdateResponseObjectEndpoint EndpointUpdateResponseObject = "endpoint" -) - -// Current state of the endpoint -type EndpointUpdateResponseState string - -const ( - EndpointUpdateResponseStatePending EndpointUpdateResponseState = "PENDING" - EndpointUpdateResponseStateStarting EndpointUpdateResponseState = "STARTING" - EndpointUpdateResponseStateStarted EndpointUpdateResponseState = "STARTED" - EndpointUpdateResponseStateStopping EndpointUpdateResponseState = "STOPPING" - EndpointUpdateResponseStateStopped EndpointUpdateResponseState = "STOPPED" - EndpointUpdateResponseStateError EndpointUpdateResponseState = "ERROR" -) - -// The type of endpoint -type EndpointUpdateResponseType string - -const ( - EndpointUpdateResponseTypeDedicated EndpointUpdateResponseType = "dedicated" + DedicatedEndpointTypeDedicated DedicatedEndpointType = "dedicated" ) type EndpointListResponse struct { @@ -456,6 +304,23 @@ const ( EndpointListResponseObjectList EndpointListResponseObject = "list" ) +// List of unique availability zones +type EndpointListAvzonesResponse struct { + Avzones []string `json:"avzones,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Avzones respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r EndpointListAvzonesResponse) RawJSON() string { return r.JSON.raw } +func (r *EndpointListAvzonesResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + type EndpointNewParams struct { // Configuration for automatic scaling of the endpoint Autoscaling AutoscalingParam `json:"autoscaling,omitzero,required"` @@ -467,6 +332,8 @@ type EndpointNewParams struct { // automatically stopped. Set to null, omit or set to 0 to disable automatic // timeout. InactiveTimeout param.Opt[int64] `json:"inactive_timeout,omitzero"` + // Create the endpoint in a specified availability zone (e.g., us-central-4b) + AvailabilityZone param.Opt[string] `json:"availability_zone,omitzero"` // Whether to disable the prompt cache for this endpoint DisablePromptCache param.Opt[bool] `json:"disable_prompt_cache,omitzero"` // Whether to disable speculative decoding for this endpoint @@ -528,10 +395,16 @@ const ( ) type EndpointListParams struct { + // If true, return only endpoints owned by the caller + Mine param.Opt[bool] `query:"mine,omitzero" json:"-"` // Filter endpoints by type // // Any of "dedicated", "serverless". Type EndpointListParamsType `query:"type,omitzero" json:"-"` + // Filter endpoints by usage type + // + // Any of "on-demand", "reserved". + UsageType EndpointListParamsUsageType `query:"usage_type,omitzero" json:"-"` paramObj } @@ -550,3 +423,11 @@ const ( EndpointListParamsTypeDedicated EndpointListParamsType = "dedicated" EndpointListParamsTypeServerless EndpointListParamsType = "serverless" ) + +// Filter endpoints by usage type +type EndpointListParamsUsageType string + +const ( + EndpointListParamsUsageTypeOnDemand EndpointListParamsUsageType = "on-demand" + EndpointListParamsUsageTypeReserved EndpointListParamsUsageType = "reserved" +) diff --git a/endpoint_test.go b/endpoint_test.go index 2f8c76fe..a871f070 100644 --- a/endpoint_test.go +++ b/endpoint_test.go @@ -32,6 +32,7 @@ func TestEndpointNewWithOptionalParams(t *testing.T) { }, Hardware: "1x_nvidia_a100_80gb_sxm", Model: "meta-llama/Llama-3-8b-chat-hf", + AvailabilityZone: together.String("availability_zone"), DisablePromptCache: together.Bool(true), DisableSpeculativeDecoding: together.Bool(true), DisplayName: together.String("My Llama3 70b endpoint"), @@ -116,7 +117,9 @@ func TestEndpointListWithOptionalParams(t *testing.T) { option.WithAPIKey("My API Key"), ) _, err := client.Endpoints.List(context.TODO(), together.EndpointListParams{ - Type: together.EndpointListParamsTypeDedicated, + Mine: together.Bool(true), + Type: together.EndpointListParamsTypeDedicated, + UsageType: together.EndpointListParamsUsageTypeOnDemand, }) if err != nil { var apierr *together.Error @@ -148,3 +151,25 @@ func TestEndpointDelete(t *testing.T) { t.Fatalf("err should be nil: %s", err.Error()) } } + +func TestEndpointListAvzones(t *testing.T) { + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := together.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + _, err := client.Endpoints.ListAvzones(context.TODO()) + if err != nil { + var apierr *together.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/eval.go b/eval.go index 30719cd6..3ec3e054 100644 --- a/eval.go +++ b/eval.go @@ -14,6 +14,7 @@ import ( "github.com/togethercomputer/together-go/internal/apijson" "github.com/togethercomputer/together-go/internal/apiquery" + "github.com/togethercomputer/together-go/internal/paramutil" "github.com/togethercomputer/together-go/internal/requestconfig" "github.com/togethercomputer/together-go/option" "github.com/togethercomputer/together-go/packages/param" @@ -39,8 +40,16 @@ func NewEvalService(opts ...option.RequestOption) (r EvalService) { return } -// Get details of a specific evaluation job -func (r *EvalService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *EvalGetResponse, err error) { +// Create an evaluation job +func (r *EvalService) New(ctx context.Context, body EvalNewParams, opts ...option.RequestOption) (res *EvalNewResponse, err error) { + opts = slices.Concat(r.Options, opts) + path := "evaluation" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Get evaluation job details +func (r *EvalService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *EvaluationJob, err error) { opts = slices.Concat(r.Options, opts) if id == "" { err = errors.New("missing required id parameter") @@ -51,24 +60,16 @@ func (r *EvalService) Get(ctx context.Context, id string, opts ...option.Request return } -// Get a list of evaluation jobs with optional filtering -func (r *EvalService) List(ctx context.Context, query EvalListParams, opts ...option.RequestOption) (res *[]EvalListResponse, err error) { +// Get all evaluation jobs +func (r *EvalService) List(ctx context.Context, query EvalListParams, opts ...option.RequestOption) (res *[]EvaluationJob, err error) { opts = slices.Concat(r.Options, opts) - path := "evaluations" + path := "evaluation" err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) return } -// Get the list of models that are allowed for evaluation -func (r *EvalService) GetAllowedModels(ctx context.Context, opts ...option.RequestOption) (res *EvalGetAllowedModelsResponse, err error) { - opts = slices.Concat(r.Options, opts) - path := "evaluations/model-list" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) - return -} - -// Get the status and results of a specific evaluation job -func (r *EvalService) GetStatus(ctx context.Context, id string, opts ...option.RequestOption) (res *EvalGetStatusResponse, err error) { +// Get evaluation job status and results +func (r *EvalService) Status(ctx context.Context, id string, opts ...option.RequestOption) (res *EvalStatusResponse, err error) { opts = slices.Concat(r.Options, opts) if id == "" { err = errors.New("missing required id parameter") @@ -79,7 +80,7 @@ func (r *EvalService) GetStatus(ctx context.Context, id string, opts ...option.R return } -type EvalGetResponse struct { +type EvaluationJob struct { // When the job was created CreatedAt time.Time `json:"created_at" format:"date-time"` // ID of the job owner (admin only) @@ -87,17 +88,17 @@ type EvalGetResponse struct { // The parameters used for this evaluation Parameters map[string]any `json:"parameters"` // Results of the evaluation (when completed) - Results EvalGetResponseResultsUnion `json:"results,nullable"` + Results EvaluationJobResultsUnion `json:"results,nullable"` // Current status of the job // // Any of "pending", "queued", "running", "completed", "error", "user_error". - Status EvalGetResponseStatus `json:"status"` + Status EvaluationJobStatus `json:"status"` // History of status updates (admin only) - StatusUpdates []EvalGetResponseStatusUpdate `json:"status_updates"` + StatusUpdates []EvaluationJobStatusUpdate `json:"status_updates"` // The type of evaluation // // Any of "classify", "score", "compare". - Type EvalGetResponseType `json:"type"` + Type EvaluationJobType `json:"type"` // When the job was last updated UpdatedAt time.Time `json:"updated_at" format:"date-time"` // The evaluation job ID @@ -119,42 +120,42 @@ type EvalGetResponse struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponse) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponse) UnmarshalJSON(data []byte) error { +func (r EvaluationJob) RawJSON() string { return r.JSON.raw } +func (r *EvaluationJob) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -// EvalGetResponseResultsUnion contains all possible properties and values from -// [EvalGetResponseResultsEvaluationClassifyResults], -// [EvalGetResponseResultsEvaluationScoreResults], -// [EvalGetResponseResultsEvaluationCompareResults], [EvalGetResponseResultsError]. +// EvaluationJobResultsUnion contains all possible properties and values from +// [EvaluationJobResultsEvaluationClassifyResults], +// [EvaluationJobResultsEvaluationScoreResults], +// [EvaluationJobResultsEvaluationCompareResults], [EvaluationJobResultsError]. // // Use the methods beginning with 'As' to cast the union to one of its variants. -type EvalGetResponseResultsUnion struct { +type EvaluationJobResultsUnion struct { GenerationFailCount float64 `json:"generation_fail_count"` - // This field is from variant [EvalGetResponseResultsEvaluationClassifyResults]. + // This field is from variant [EvaluationJobResultsEvaluationClassifyResults]. InvalidLabelCount float64 `json:"invalid_label_count"` JudgeFailCount float64 `json:"judge_fail_count"` - // This field is from variant [EvalGetResponseResultsEvaluationClassifyResults]. + // This field is from variant [EvaluationJobResultsEvaluationClassifyResults]. LabelCounts string `json:"label_counts"` - // This field is from variant [EvalGetResponseResultsEvaluationClassifyResults]. + // This field is from variant [EvaluationJobResultsEvaluationClassifyResults]. PassPercentage float64 `json:"pass_percentage"` ResultFileID string `json:"result_file_id"` - // This field is from variant [EvalGetResponseResultsEvaluationScoreResults]. - AggregatedScores EvalGetResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` - // This field is from variant [EvalGetResponseResultsEvaluationScoreResults]. + // This field is from variant [EvaluationJobResultsEvaluationScoreResults]. + AggregatedScores EvaluationJobResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` + // This field is from variant [EvaluationJobResultsEvaluationScoreResults]. FailedSamples float64 `json:"failed_samples"` - // This field is from variant [EvalGetResponseResultsEvaluationScoreResults]. + // This field is from variant [EvaluationJobResultsEvaluationScoreResults]. InvalidScoreCount float64 `json:"invalid_score_count"` - // This field is from variant [EvalGetResponseResultsEvaluationCompareResults]. + // This field is from variant [EvaluationJobResultsEvaluationCompareResults]. AWins int64 `json:"A_wins"` - // This field is from variant [EvalGetResponseResultsEvaluationCompareResults]. + // This field is from variant [EvaluationJobResultsEvaluationCompareResults]. BWins int64 `json:"B_wins"` - // This field is from variant [EvalGetResponseResultsEvaluationCompareResults]. + // This field is from variant [EvaluationJobResultsEvaluationCompareResults]. NumSamples int64 `json:"num_samples"` - // This field is from variant [EvalGetResponseResultsEvaluationCompareResults]. + // This field is from variant [EvaluationJobResultsEvaluationCompareResults]. Ties int64 `json:"Ties"` - // This field is from variant [EvalGetResponseResultsError]. + // This field is from variant [EvaluationJobResultsError]. Error string `json:"error"` JSON struct { GenerationFailCount respjson.Field @@ -175,34 +176,34 @@ type EvalGetResponseResultsUnion struct { } `json:"-"` } -func (u EvalGetResponseResultsUnion) AsEvalGetResponseResultsEvaluationClassifyResults() (v EvalGetResponseResultsEvaluationClassifyResults) { +func (u EvaluationJobResultsUnion) AsEvaluationJobResultsEvaluationClassifyResults() (v EvaluationJobResultsEvaluationClassifyResults) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } -func (u EvalGetResponseResultsUnion) AsEvalGetResponseResultsEvaluationScoreResults() (v EvalGetResponseResultsEvaluationScoreResults) { +func (u EvaluationJobResultsUnion) AsEvaluationJobResultsEvaluationScoreResults() (v EvaluationJobResultsEvaluationScoreResults) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } -func (u EvalGetResponseResultsUnion) AsEvalGetResponseResultsEvaluationCompareResults() (v EvalGetResponseResultsEvaluationCompareResults) { +func (u EvaluationJobResultsUnion) AsEvaluationJobResultsEvaluationCompareResults() (v EvaluationJobResultsEvaluationCompareResults) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } -func (u EvalGetResponseResultsUnion) AsEvalGetResponseResultsError() (v EvalGetResponseResultsError) { +func (u EvaluationJobResultsUnion) AsEvaluationJobResultsError() (v EvaluationJobResultsError) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } // Returns the unmodified JSON received from the API -func (u EvalGetResponseResultsUnion) RawJSON() string { return u.JSON.raw } +func (u EvaluationJobResultsUnion) RawJSON() string { return u.JSON.raw } -func (r *EvalGetResponseResultsUnion) UnmarshalJSON(data []byte) error { +func (r *EvaluationJobResultsUnion) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetResponseResultsEvaluationClassifyResults struct { +type EvaluationJobResultsEvaluationClassifyResults struct { // Number of failed generations. GenerationFailCount float64 `json:"generation_fail_count,nullable"` // Number of invalid labels @@ -229,13 +230,13 @@ type EvalGetResponseResultsEvaluationClassifyResults struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponseResultsEvaluationClassifyResults) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponseResultsEvaluationClassifyResults) UnmarshalJSON(data []byte) error { +func (r EvaluationJobResultsEvaluationClassifyResults) RawJSON() string { return r.JSON.raw } +func (r *EvaluationJobResultsEvaluationClassifyResults) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetResponseResultsEvaluationScoreResults struct { - AggregatedScores EvalGetResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` +type EvaluationJobResultsEvaluationScoreResults struct { + AggregatedScores EvaluationJobResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` // number of failed samples generated from model FailedSamples float64 `json:"failed_samples"` // Number of failed generations. @@ -260,12 +261,12 @@ type EvalGetResponseResultsEvaluationScoreResults struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponseResultsEvaluationScoreResults) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponseResultsEvaluationScoreResults) UnmarshalJSON(data []byte) error { +func (r EvaluationJobResultsEvaluationScoreResults) RawJSON() string { return r.JSON.raw } +func (r *EvaluationJobResultsEvaluationScoreResults) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetResponseResultsEvaluationScoreResultsAggregatedScores struct { +type EvaluationJobResultsEvaluationScoreResultsAggregatedScores struct { MeanScore float64 `json:"mean_score"` PassPercentage float64 `json:"pass_percentage"` StdScore float64 `json:"std_score"` @@ -280,14 +281,14 @@ type EvalGetResponseResultsEvaluationScoreResultsAggregatedScores struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponseResultsEvaluationScoreResultsAggregatedScores) RawJSON() string { +func (r EvaluationJobResultsEvaluationScoreResultsAggregatedScores) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponseResultsEvaluationScoreResultsAggregatedScores) UnmarshalJSON(data []byte) error { +func (r *EvaluationJobResultsEvaluationScoreResultsAggregatedScores) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetResponseResultsEvaluationCompareResults struct { +type EvaluationJobResultsEvaluationCompareResults struct { // Number of times model A won AWins int64 `json:"A_wins"` // Number of times model B won @@ -317,12 +318,12 @@ type EvalGetResponseResultsEvaluationCompareResults struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponseResultsEvaluationCompareResults) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponseResultsEvaluationCompareResults) UnmarshalJSON(data []byte) error { +func (r EvaluationJobResultsEvaluationCompareResults) RawJSON() string { return r.JSON.raw } +func (r *EvaluationJobResultsEvaluationCompareResults) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetResponseResultsError struct { +type EvaluationJobResultsError struct { Error string `json:"error"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { @@ -333,24 +334,24 @@ type EvalGetResponseResultsError struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponseResultsError) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponseResultsError) UnmarshalJSON(data []byte) error { +func (r EvaluationJobResultsError) RawJSON() string { return r.JSON.raw } +func (r *EvaluationJobResultsError) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // Current status of the job -type EvalGetResponseStatus string +type EvaluationJobStatus string const ( - EvalGetResponseStatusPending EvalGetResponseStatus = "pending" - EvalGetResponseStatusQueued EvalGetResponseStatus = "queued" - EvalGetResponseStatusRunning EvalGetResponseStatus = "running" - EvalGetResponseStatusCompleted EvalGetResponseStatus = "completed" - EvalGetResponseStatusError EvalGetResponseStatus = "error" - EvalGetResponseStatusUserError EvalGetResponseStatus = "user_error" + EvaluationJobStatusPending EvaluationJobStatus = "pending" + EvaluationJobStatusQueued EvaluationJobStatus = "queued" + EvaluationJobStatusRunning EvaluationJobStatus = "running" + EvaluationJobStatusCompleted EvaluationJobStatus = "completed" + EvaluationJobStatusError EvaluationJobStatus = "error" + EvaluationJobStatusUserError EvaluationJobStatus = "user_error" ) -type EvalGetResponseStatusUpdate struct { +type EvaluationJobStatusUpdate struct { // Additional message for this update Message string `json:"message"` // The status at this update @@ -368,99 +369,102 @@ type EvalGetResponseStatusUpdate struct { } // Returns the unmodified JSON received from the API -func (r EvalGetResponseStatusUpdate) RawJSON() string { return r.JSON.raw } -func (r *EvalGetResponseStatusUpdate) UnmarshalJSON(data []byte) error { +func (r EvaluationJobStatusUpdate) RawJSON() string { return r.JSON.raw } +func (r *EvaluationJobStatusUpdate) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // The type of evaluation -type EvalGetResponseType string +type EvaluationJobType string const ( - EvalGetResponseTypeClassify EvalGetResponseType = "classify" - EvalGetResponseTypeScore EvalGetResponseType = "score" - EvalGetResponseTypeCompare EvalGetResponseType = "compare" + EvaluationJobTypeClassify EvaluationJobType = "classify" + EvaluationJobTypeScore EvaluationJobType = "score" + EvaluationJobTypeCompare EvaluationJobType = "compare" ) -type EvalListResponse struct { - // When the job was created - CreatedAt time.Time `json:"created_at" format:"date-time"` - // ID of the job owner (admin only) - OwnerID string `json:"owner_id"` - // The parameters used for this evaluation - Parameters map[string]any `json:"parameters"` - // Results of the evaluation (when completed) - Results EvalListResponseResultsUnion `json:"results,nullable"` - // Current status of the job - // - // Any of "pending", "queued", "running", "completed", "error", "user_error". - Status EvalListResponseStatus `json:"status"` - // History of status updates (admin only) - StatusUpdates []EvalListResponseStatusUpdate `json:"status_updates"` - // The type of evaluation +type EvalNewResponse struct { + // Initial status of the job // - // Any of "classify", "score", "compare". - Type EvalListResponseType `json:"type"` - // When the job was last updated - UpdatedAt time.Time `json:"updated_at" format:"date-time"` - // The evaluation job ID + // Any of "pending". + Status EvalNewResponseStatus `json:"status"` + // The ID of the created evaluation job WorkflowID string `json:"workflow_id"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { - CreatedAt respjson.Field - OwnerID respjson.Field - Parameters respjson.Field - Results respjson.Field - Status respjson.Field - StatusUpdates respjson.Field - Type respjson.Field - UpdatedAt respjson.Field - WorkflowID respjson.Field - ExtraFields map[string]respjson.Field - raw string + Status respjson.Field + WorkflowID respjson.Field + ExtraFields map[string]respjson.Field + raw string } `json:"-"` } // Returns the unmodified JSON received from the API -func (r EvalListResponse) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponse) UnmarshalJSON(data []byte) error { +func (r EvalNewResponse) RawJSON() string { return r.JSON.raw } +func (r *EvalNewResponse) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -// EvalListResponseResultsUnion contains all possible properties and values from -// [EvalListResponseResultsEvaluationClassifyResults], -// [EvalListResponseResultsEvaluationScoreResults], -// [EvalListResponseResultsEvaluationCompareResults], -// [EvalListResponseResultsError]. +// Initial status of the job +type EvalNewResponseStatus string + +const ( + EvalNewResponseStatusPending EvalNewResponseStatus = "pending" +) + +type EvalStatusResponse struct { + // The results of the evaluation job + Results EvalStatusResponseResultsUnion `json:"results"` + // The status of the evaluation job + // + // Any of "completed", "error", "user_error", "running", "queued", "pending". + Status EvalStatusResponseStatus `json:"status"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Results respjson.Field + Status respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r EvalStatusResponse) RawJSON() string { return r.JSON.raw } +func (r *EvalStatusResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// EvalStatusResponseResultsUnion contains all possible properties and values from +// [EvalStatusResponseResultsEvaluationClassifyResults], +// [EvalStatusResponseResultsEvaluationScoreResults], +// [EvalStatusResponseResultsEvaluationCompareResults]. // // Use the methods beginning with 'As' to cast the union to one of its variants. -type EvalListResponseResultsUnion struct { +type EvalStatusResponseResultsUnion struct { GenerationFailCount float64 `json:"generation_fail_count"` - // This field is from variant [EvalListResponseResultsEvaluationClassifyResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationClassifyResults]. InvalidLabelCount float64 `json:"invalid_label_count"` JudgeFailCount float64 `json:"judge_fail_count"` - // This field is from variant [EvalListResponseResultsEvaluationClassifyResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationClassifyResults]. LabelCounts string `json:"label_counts"` - // This field is from variant [EvalListResponseResultsEvaluationClassifyResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationClassifyResults]. PassPercentage float64 `json:"pass_percentage"` ResultFileID string `json:"result_file_id"` - // This field is from variant [EvalListResponseResultsEvaluationScoreResults]. - AggregatedScores EvalListResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` - // This field is from variant [EvalListResponseResultsEvaluationScoreResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationScoreResults]. + AggregatedScores EvalStatusResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` + // This field is from variant [EvalStatusResponseResultsEvaluationScoreResults]. FailedSamples float64 `json:"failed_samples"` - // This field is from variant [EvalListResponseResultsEvaluationScoreResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationScoreResults]. InvalidScoreCount float64 `json:"invalid_score_count"` - // This field is from variant [EvalListResponseResultsEvaluationCompareResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationCompareResults]. AWins int64 `json:"A_wins"` - // This field is from variant [EvalListResponseResultsEvaluationCompareResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationCompareResults]. BWins int64 `json:"B_wins"` - // This field is from variant [EvalListResponseResultsEvaluationCompareResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationCompareResults]. NumSamples int64 `json:"num_samples"` - // This field is from variant [EvalListResponseResultsEvaluationCompareResults]. + // This field is from variant [EvalStatusResponseResultsEvaluationCompareResults]. Ties int64 `json:"Ties"` - // This field is from variant [EvalListResponseResultsError]. - Error string `json:"error"` - JSON struct { + JSON struct { GenerationFailCount respjson.Field InvalidLabelCount respjson.Field JudgeFailCount respjson.Field @@ -474,39 +478,33 @@ type EvalListResponseResultsUnion struct { BWins respjson.Field NumSamples respjson.Field Ties respjson.Field - Error respjson.Field raw string } `json:"-"` } -func (u EvalListResponseResultsUnion) AsEvalListResponseResultsEvaluationClassifyResults() (v EvalListResponseResultsEvaluationClassifyResults) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u EvalListResponseResultsUnion) AsEvalListResponseResultsEvaluationScoreResults() (v EvalListResponseResultsEvaluationScoreResults) { +func (u EvalStatusResponseResultsUnion) AsEvalStatusResponseResultsEvaluationClassifyResults() (v EvalStatusResponseResultsEvaluationClassifyResults) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } -func (u EvalListResponseResultsUnion) AsEvalListResponseResultsEvaluationCompareResults() (v EvalListResponseResultsEvaluationCompareResults) { +func (u EvalStatusResponseResultsUnion) AsEvalStatusResponseResultsEvaluationScoreResults() (v EvalStatusResponseResultsEvaluationScoreResults) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } -func (u EvalListResponseResultsUnion) AsEvalListResponseResultsError() (v EvalListResponseResultsError) { +func (u EvalStatusResponseResultsUnion) AsEvalStatusResponseResultsEvaluationCompareResults() (v EvalStatusResponseResultsEvaluationCompareResults) { apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) return } // Returns the unmodified JSON received from the API -func (u EvalListResponseResultsUnion) RawJSON() string { return u.JSON.raw } +func (u EvalStatusResponseResultsUnion) RawJSON() string { return u.JSON.raw } -func (r *EvalListResponseResultsUnion) UnmarshalJSON(data []byte) error { +func (r *EvalStatusResponseResultsUnion) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalListResponseResultsEvaluationClassifyResults struct { +type EvalStatusResponseResultsEvaluationClassifyResults struct { // Number of failed generations. GenerationFailCount float64 `json:"generation_fail_count,nullable"` // Number of invalid labels @@ -533,13 +531,13 @@ type EvalListResponseResultsEvaluationClassifyResults struct { } // Returns the unmodified JSON received from the API -func (r EvalListResponseResultsEvaluationClassifyResults) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponseResultsEvaluationClassifyResults) UnmarshalJSON(data []byte) error { +func (r EvalStatusResponseResultsEvaluationClassifyResults) RawJSON() string { return r.JSON.raw } +func (r *EvalStatusResponseResultsEvaluationClassifyResults) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalListResponseResultsEvaluationScoreResults struct { - AggregatedScores EvalListResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` +type EvalStatusResponseResultsEvaluationScoreResults struct { + AggregatedScores EvalStatusResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` // number of failed samples generated from model FailedSamples float64 `json:"failed_samples"` // Number of failed generations. @@ -564,12 +562,12 @@ type EvalListResponseResultsEvaluationScoreResults struct { } // Returns the unmodified JSON received from the API -func (r EvalListResponseResultsEvaluationScoreResults) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponseResultsEvaluationScoreResults) UnmarshalJSON(data []byte) error { +func (r EvalStatusResponseResultsEvaluationScoreResults) RawJSON() string { return r.JSON.raw } +func (r *EvalStatusResponseResultsEvaluationScoreResults) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalListResponseResultsEvaluationScoreResultsAggregatedScores struct { +type EvalStatusResponseResultsEvaluationScoreResultsAggregatedScores struct { MeanScore float64 `json:"mean_score"` PassPercentage float64 `json:"pass_percentage"` StdScore float64 `json:"std_score"` @@ -584,14 +582,14 @@ type EvalListResponseResultsEvaluationScoreResultsAggregatedScores struct { } // Returns the unmodified JSON received from the API -func (r EvalListResponseResultsEvaluationScoreResultsAggregatedScores) RawJSON() string { +func (r EvalStatusResponseResultsEvaluationScoreResultsAggregatedScores) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponseResultsEvaluationScoreResultsAggregatedScores) UnmarshalJSON(data []byte) error { +func (r *EvalStatusResponseResultsEvaluationScoreResultsAggregatedScores) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalListResponseResultsEvaluationCompareResults struct { +type EvalStatusResponseResultsEvaluationCompareResults struct { // Number of times model A won AWins int64 `json:"A_wins"` // Number of times model B won @@ -621,346 +619,804 @@ type EvalListResponseResultsEvaluationCompareResults struct { } // Returns the unmodified JSON received from the API -func (r EvalListResponseResultsEvaluationCompareResults) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponseResultsEvaluationCompareResults) UnmarshalJSON(data []byte) error { +func (r EvalStatusResponseResultsEvaluationCompareResults) RawJSON() string { return r.JSON.raw } +func (r *EvalStatusResponseResultsEvaluationCompareResults) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalListResponseResultsError struct { - Error string `json:"error"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Error respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +// The status of the evaluation job +type EvalStatusResponseStatus string + +const ( + EvalStatusResponseStatusCompleted EvalStatusResponseStatus = "completed" + EvalStatusResponseStatusError EvalStatusResponseStatus = "error" + EvalStatusResponseStatusUserError EvalStatusResponseStatus = "user_error" + EvalStatusResponseStatusRunning EvalStatusResponseStatus = "running" + EvalStatusResponseStatusQueued EvalStatusResponseStatus = "queued" + EvalStatusResponseStatusPending EvalStatusResponseStatus = "pending" +) + +type EvalNewParams struct { + // Type-specific parameters for the evaluation + Parameters EvalNewParamsParametersUnion `json:"parameters,omitzero,required"` + // The type of evaluation to perform + // + // Any of "classify", "score", "compare". + Type EvalNewParamsType `json:"type,omitzero,required"` + paramObj } -// Returns the unmodified JSON received from the API -func (r EvalListResponseResultsError) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponseResultsError) UnmarshalJSON(data []byte) error { +func (r EvalNewParams) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParams) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -// Current status of the job -type EvalListResponseStatus string +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type EvalNewParamsParametersUnion struct { + OfEvalNewsParametersEvaluationClassifyParameters *EvalNewParamsParametersEvaluationClassifyParameters `json:",omitzero,inline"` + OfEvalNewsParametersEvaluationScoreParameters *EvalNewParamsParametersEvaluationScoreParameters `json:",omitzero,inline"` + OfEvalNewsParametersEvaluationCompareParameters *EvalNewParamsParametersEvaluationCompareParameters `json:",omitzero,inline"` + paramUnion +} -const ( - EvalListResponseStatusPending EvalListResponseStatus = "pending" - EvalListResponseStatusQueued EvalListResponseStatus = "queued" - EvalListResponseStatusRunning EvalListResponseStatus = "running" - EvalListResponseStatusCompleted EvalListResponseStatus = "completed" - EvalListResponseStatusError EvalListResponseStatus = "error" - EvalListResponseStatusUserError EvalListResponseStatus = "user_error" -) +func (u EvalNewParamsParametersUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfEvalNewsParametersEvaluationClassifyParameters, u.OfEvalNewsParametersEvaluationScoreParameters, u.OfEvalNewsParametersEvaluationCompareParameters) +} +func (u *EvalNewParamsParametersUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} -type EvalListResponseStatusUpdate struct { - // Additional message for this update - Message string `json:"message"` - // The status at this update - Status string `json:"status"` - // When this update occurred - Timestamp time.Time `json:"timestamp" format:"date-time"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Message respjson.Field - Status respjson.Field - Timestamp respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +func (u *EvalNewParamsParametersUnion) asAny() any { + if !param.IsOmitted(u.OfEvalNewsParametersEvaluationClassifyParameters) { + return u.OfEvalNewsParametersEvaluationClassifyParameters + } else if !param.IsOmitted(u.OfEvalNewsParametersEvaluationScoreParameters) { + return u.OfEvalNewsParametersEvaluationScoreParameters + } else if !param.IsOmitted(u.OfEvalNewsParametersEvaluationCompareParameters) { + return u.OfEvalNewsParametersEvaluationCompareParameters + } + return nil } -// Returns the unmodified JSON received from the API -func (r EvalListResponseStatusUpdate) RawJSON() string { return r.JSON.raw } -func (r *EvalListResponseStatusUpdate) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetLabels() []string { + if vt := u.OfEvalNewsParametersEvaluationClassifyParameters; vt != nil { + return vt.Labels + } + return nil } -// The type of evaluation -type EvalListResponseType string +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetPassLabels() []string { + if vt := u.OfEvalNewsParametersEvaluationClassifyParameters; vt != nil { + return vt.PassLabels + } + return nil +} -const ( - EvalListResponseTypeClassify EvalListResponseType = "classify" - EvalListResponseTypeScore EvalListResponseType = "score" - EvalListResponseTypeCompare EvalListResponseType = "compare" -) +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetMaxScore() *float64 { + if vt := u.OfEvalNewsParametersEvaluationScoreParameters; vt != nil { + return &vt.MaxScore + } + return nil +} -type EvalGetAllowedModelsResponse struct { - ModelList []string `json:"model_list"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ModelList respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetMinScore() *float64 { + if vt := u.OfEvalNewsParametersEvaluationScoreParameters; vt != nil { + return &vt.MinScore + } + return nil } -// Returns the unmodified JSON received from the API -func (r EvalGetAllowedModelsResponse) RawJSON() string { return r.JSON.raw } -func (r *EvalGetAllowedModelsResponse) UnmarshalJSON(data []byte) error { +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetPassThreshold() *float64 { + if vt := u.OfEvalNewsParametersEvaluationScoreParameters; vt != nil { + return &vt.PassThreshold + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetModelA() *EvalNewParamsParametersEvaluationCompareParametersModelAUnion { + if vt := u.OfEvalNewsParametersEvaluationCompareParameters; vt != nil { + return &vt.ModelA + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetModelB() *EvalNewParamsParametersEvaluationCompareParametersModelBUnion { + if vt := u.OfEvalNewsParametersEvaluationCompareParameters; vt != nil { + return &vt.ModelB + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u EvalNewParamsParametersUnion) GetInputDataFilePath() *string { + if vt := u.OfEvalNewsParametersEvaluationClassifyParameters; vt != nil { + return (*string)(&vt.InputDataFilePath) + } else if vt := u.OfEvalNewsParametersEvaluationScoreParameters; vt != nil { + return (*string)(&vt.InputDataFilePath) + } else if vt := u.OfEvalNewsParametersEvaluationCompareParameters; vt != nil { + return (*string)(&vt.InputDataFilePath) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u EvalNewParamsParametersUnion) GetJudge() (res evalNewParamsParametersUnionJudge) { + if vt := u.OfEvalNewsParametersEvaluationClassifyParameters; vt != nil { + res.any = &vt.Judge + } else if vt := u.OfEvalNewsParametersEvaluationScoreParameters; vt != nil { + res.any = &vt.Judge + } else if vt := u.OfEvalNewsParametersEvaluationCompareParameters; vt != nil { + res.any = &vt.Judge + } + return +} + +// Can have the runtime types +// [*EvalNewParamsParametersEvaluationClassifyParametersJudge], +// [*EvalNewParamsParametersEvaluationScoreParametersJudge], +// [*EvalNewParamsParametersEvaluationCompareParametersJudge] +type evalNewParamsParametersUnionJudge struct{ any } + +// Use the following switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *together.EvalNewParamsParametersEvaluationClassifyParametersJudge: +// case *together.EvalNewParamsParametersEvaluationScoreParametersJudge: +// case *together.EvalNewParamsParametersEvaluationCompareParametersJudge: +// default: +// fmt.Errorf("not present") +// } +func (u evalNewParamsParametersUnionJudge) AsAny() any { return u.any } + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionJudge) GetModel() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersJudge: + return (*string)(&vt.Model) + case *EvalNewParamsParametersEvaluationScoreParametersJudge: + return (*string)(&vt.Model) + case *EvalNewParamsParametersEvaluationCompareParametersJudge: + return (*string)(&vt.Model) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionJudge) GetModelSource() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersJudge: + return (*string)(&vt.ModelSource) + case *EvalNewParamsParametersEvaluationScoreParametersJudge: + return (*string)(&vt.ModelSource) + case *EvalNewParamsParametersEvaluationCompareParametersJudge: + return (*string)(&vt.ModelSource) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionJudge) GetSystemTemplate() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersJudge: + return (*string)(&vt.SystemTemplate) + case *EvalNewParamsParametersEvaluationScoreParametersJudge: + return (*string)(&vt.SystemTemplate) + case *EvalNewParamsParametersEvaluationCompareParametersJudge: + return (*string)(&vt.SystemTemplate) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionJudge) GetExternalAPIToken() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersJudge: + return paramutil.AddrIfPresent(vt.ExternalAPIToken) + case *EvalNewParamsParametersEvaluationScoreParametersJudge: + return paramutil.AddrIfPresent(vt.ExternalAPIToken) + case *EvalNewParamsParametersEvaluationCompareParametersJudge: + return paramutil.AddrIfPresent(vt.ExternalAPIToken) + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionJudge) GetExternalBaseURL() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersJudge: + return paramutil.AddrIfPresent(vt.ExternalBaseURL) + case *EvalNewParamsParametersEvaluationScoreParametersJudge: + return paramutil.AddrIfPresent(vt.ExternalBaseURL) + case *EvalNewParamsParametersEvaluationCompareParametersJudge: + return paramutil.AddrIfPresent(vt.ExternalBaseURL) + } + return nil +} + +// Returns a subunion which exports methods to access subproperties +// +// Or use AsAny() to get the underlying value +func (u EvalNewParamsParametersUnion) GetModelToEvaluate() (res evalNewParamsParametersUnionModelToEvaluate) { + if vt := u.OfEvalNewsParametersEvaluationClassifyParameters; vt != nil { + res.any = vt.ModelToEvaluate.asAny() + } else if vt := u.OfEvalNewsParametersEvaluationScoreParameters; vt != nil { + res.any = vt.ModelToEvaluate.asAny() + } + return +} + +// Can have the runtime types [*string] +type evalNewParamsParametersUnionModelToEvaluate struct{ any } + +// Use the following switch statement to get the type of the union: +// +// switch u.AsAny().(type) { +// case *string: +// default: +// fmt.Errorf("not present") +// } +func (u evalNewParamsParametersUnionModelToEvaluate) AsAny() any { return u.any } + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetInputTemplate() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.InputTemplate) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.InputTemplate) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetMaxTokens() *int64 { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return (*int64)(&vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.MaxTokens) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return (*int64)(&vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.MaxTokens) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetModel() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.Model) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.Model) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetModelSource() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.ModelSource) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.ModelSource) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetSystemTemplate() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.SystemTemplate) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return (*string)(&vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.SystemTemplate) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetTemperature() *float64 { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return (*float64)(&vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.Temperature) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return (*float64)(&vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.Temperature) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetExternalAPIToken() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return paramutil.AddrIfPresent(vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.ExternalAPIToken) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return paramutil.AddrIfPresent(vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.ExternalAPIToken) + } + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u evalNewParamsParametersUnionModelToEvaluate) GetExternalBaseURL() *string { + switch vt := u.any.(type) { + case *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest != nil { + return paramutil.AddrIfPresent(vt.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest.ExternalBaseURL) + } + case *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion: + if vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest != nil { + return paramutil.AddrIfPresent(vt.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest.ExternalBaseURL) + } + } + return nil +} + +// The properties InputDataFilePath, Judge, Labels, PassLabels are required. +type EvalNewParamsParametersEvaluationClassifyParameters struct { + // Data file ID + InputDataFilePath string `json:"input_data_file_path,required"` + Judge EvalNewParamsParametersEvaluationClassifyParametersJudge `json:"judge,omitzero,required"` + // List of possible classification labels + Labels []string `json:"labels,omitzero,required"` + // List of labels that are considered passing + PassLabels []string `json:"pass_labels,omitzero,required"` + // Field name in the input data + ModelToEvaluate EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion `json:"model_to_evaluate,omitzero"` + paramObj +} + +func (r EvalNewParamsParametersEvaluationClassifyParameters) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationClassifyParameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationClassifyParameters) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponse struct { - Results EvalGetStatusResponseResultsUnion `json:"results,nullable"` - // Any of "pending", "queued", "running", "completed", "error", "user_error". - Status EvalGetStatusResponseStatus `json:"status"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Results respjson.Field - Status respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +// The properties Model, ModelSource, SystemTemplate are required. +type EvalNewParamsParametersEvaluationClassifyParametersJudge struct { + // Name of the judge model + Model string `json:"model,required"` + // Source of the judge model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template for the judge + SystemTemplate string `json:"system_template,required"` + // Bearer/API token for external judge models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external judge models. Must be OpenAI-compatible base URL. + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj } -// Returns the unmodified JSON received from the API -func (r EvalGetStatusResponse) RawJSON() string { return r.JSON.raw } -func (r *EvalGetStatusResponse) UnmarshalJSON(data []byte) error { +func (r EvalNewParamsParametersEvaluationClassifyParametersJudge) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationClassifyParametersJudge + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationClassifyParametersJudge) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -// EvalGetStatusResponseResultsUnion contains all possible properties and values -// from [EvalGetStatusResponseResultsEvaluationClassifyResults], -// [EvalGetStatusResponseResultsEvaluationScoreResults], -// [EvalGetStatusResponseResultsEvaluationCompareResults], -// [EvalGetStatusResponseResultsError]. +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationClassifyParametersJudge]( + "model_source", "serverless", "dedicated", "external", + ) +} + +// Only one field can be non-zero. // -// Use the methods beginning with 'As' to cast the union to one of its variants. -type EvalGetStatusResponseResultsUnion struct { - GenerationFailCount float64 `json:"generation_fail_count"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationClassifyResults]. - InvalidLabelCount float64 `json:"invalid_label_count"` - JudgeFailCount float64 `json:"judge_fail_count"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationClassifyResults]. - LabelCounts string `json:"label_counts"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationClassifyResults]. - PassPercentage float64 `json:"pass_percentage"` - ResultFileID string `json:"result_file_id"` - // This field is from variant [EvalGetStatusResponseResultsEvaluationScoreResults]. - AggregatedScores EvalGetStatusResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` - // This field is from variant [EvalGetStatusResponseResultsEvaluationScoreResults]. - FailedSamples float64 `json:"failed_samples"` - // This field is from variant [EvalGetStatusResponseResultsEvaluationScoreResults]. - InvalidScoreCount float64 `json:"invalid_score_count"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationCompareResults]. - AWins int64 `json:"A_wins"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationCompareResults]. - BWins int64 `json:"B_wins"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationCompareResults]. - NumSamples int64 `json:"num_samples"` - // This field is from variant - // [EvalGetStatusResponseResultsEvaluationCompareResults]. - Ties int64 `json:"Ties"` - // This field is from variant [EvalGetStatusResponseResultsError]. - Error string `json:"error"` - JSON struct { - GenerationFailCount respjson.Field - InvalidLabelCount respjson.Field - JudgeFailCount respjson.Field - LabelCounts respjson.Field - PassPercentage respjson.Field - ResultFileID respjson.Field - AggregatedScores respjson.Field - FailedSamples respjson.Field - InvalidScoreCount respjson.Field - AWins respjson.Field - BWins respjson.Field - NumSamples respjson.Field - Ties respjson.Field - Error respjson.Field - raw string - } `json:"-"` +// Use [param.IsOmitted] to confirm if a field is set. +type EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest `json:",omitzero,inline"` + paramUnion } -func (u EvalGetStatusResponseResultsUnion) AsEvalGetStatusResponseResultsEvaluationClassifyResults() (v EvalGetStatusResponseResultsEvaluationClassifyResults) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return +func (u EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest) +} +func (u *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) } -func (u EvalGetStatusResponseResultsUnion) AsEvalGetStatusResponseResultsEvaluationScoreResults() (v EvalGetStatusResponseResultsEvaluationScoreResults) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return +func (u *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest) { + return u.OfEvalNewsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest + } + return nil +} + +// The properties InputTemplate, MaxTokens, Model, ModelSource, SystemTemplate, +// Temperature are required. +type EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest struct { + // Input prompt template + InputTemplate string `json:"input_template,required"` + // Maximum number of tokens to generate + MaxTokens int64 `json:"max_tokens,required"` + // Name of the model to evaluate + Model string `json:"model,required"` + // Source of the model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template + SystemTemplate string `json:"system_template,required"` + // Sampling temperature + Temperature float64 `json:"temperature,required"` + // Bearer/API token for external models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external models. Must be OpenAI-compatible base URL + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj } -func (u EvalGetStatusResponseResultsUnion) AsEvalGetStatusResponseResultsEvaluationCompareResults() (v EvalGetStatusResponseResultsEvaluationCompareResults) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return +func (r EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) } -func (u EvalGetStatusResponseResultsUnion) AsEvalGetStatusResponseResultsError() (v EvalGetStatusResponseResultsError) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest]( + "model_source", "serverless", "dedicated", "external", + ) +} + +// The properties InputDataFilePath, Judge, MaxScore, MinScore, PassThreshold are +// required. +type EvalNewParamsParametersEvaluationScoreParameters struct { + // Data file ID + InputDataFilePath string `json:"input_data_file_path,required"` + Judge EvalNewParamsParametersEvaluationScoreParametersJudge `json:"judge,omitzero,required"` + // Maximum possible score + MaxScore float64 `json:"max_score,required"` + // Minimum possible score + MinScore float64 `json:"min_score,required"` + // Score threshold for passing + PassThreshold float64 `json:"pass_threshold,required"` + // Field name in the input data + ModelToEvaluate EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion `json:"model_to_evaluate,omitzero"` + paramObj } -// Returns the unmodified JSON received from the API -func (u EvalGetStatusResponseResultsUnion) RawJSON() string { return u.JSON.raw } +func (r EvalNewParamsParametersEvaluationScoreParameters) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationScoreParameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationScoreParameters) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties Model, ModelSource, SystemTemplate are required. +type EvalNewParamsParametersEvaluationScoreParametersJudge struct { + // Name of the judge model + Model string `json:"model,required"` + // Source of the judge model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template for the judge + SystemTemplate string `json:"system_template,required"` + // Bearer/API token for external judge models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external judge models. Must be OpenAI-compatible base URL. + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj +} -func (r *EvalGetStatusResponseResultsUnion) UnmarshalJSON(data []byte) error { +func (r EvalNewParamsParametersEvaluationScoreParametersJudge) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationScoreParametersJudge + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationScoreParametersJudge) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponseResultsEvaluationClassifyResults struct { - // Number of failed generations. - GenerationFailCount float64 `json:"generation_fail_count,nullable"` - // Number of invalid labels - InvalidLabelCount float64 `json:"invalid_label_count,nullable"` - // Number of failed judge generations - JudgeFailCount float64 `json:"judge_fail_count,nullable"` - // JSON string representing label counts - LabelCounts string `json:"label_counts"` - // Pecentage of pass labels. - PassPercentage float64 `json:"pass_percentage,nullable"` - // Data File ID - ResultFileID string `json:"result_file_id"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - GenerationFailCount respjson.Field - InvalidLabelCount respjson.Field - JudgeFailCount respjson.Field - LabelCounts respjson.Field - PassPercentage respjson.Field - ResultFileID respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationScoreParametersJudge]( + "model_source", "serverless", "dedicated", "external", + ) } -// Returns the unmodified JSON received from the API -func (r EvalGetStatusResponseResultsEvaluationClassifyResults) RawJSON() string { return r.JSON.raw } -func (r *EvalGetStatusResponseResultsEvaluationClassifyResults) UnmarshalJSON(data []byte) error { +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest `json:",omitzero,inline"` + paramUnion +} + +func (u EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest) +} +func (u *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest) { + return u.OfEvalNewsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest + } + return nil +} + +// The properties InputTemplate, MaxTokens, Model, ModelSource, SystemTemplate, +// Temperature are required. +type EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest struct { + // Input prompt template + InputTemplate string `json:"input_template,required"` + // Maximum number of tokens to generate + MaxTokens int64 `json:"max_tokens,required"` + // Name of the model to evaluate + Model string `json:"model,required"` + // Source of the model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template + SystemTemplate string `json:"system_template,required"` + // Sampling temperature + Temperature float64 `json:"temperature,required"` + // Bearer/API token for external models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external models. Must be OpenAI-compatible base URL + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj +} + +func (r EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponseResultsEvaluationScoreResults struct { - AggregatedScores EvalGetStatusResponseResultsEvaluationScoreResultsAggregatedScores `json:"aggregated_scores"` - // number of failed samples generated from model - FailedSamples float64 `json:"failed_samples"` - // Number of failed generations. - GenerationFailCount float64 `json:"generation_fail_count,nullable"` - // number of invalid scores generated from model - InvalidScoreCount float64 `json:"invalid_score_count"` - // Number of failed judge generations - JudgeFailCount float64 `json:"judge_fail_count,nullable"` - // Data File ID - ResultFileID string `json:"result_file_id"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - AggregatedScores respjson.Field - FailedSamples respjson.Field - GenerationFailCount respjson.Field - InvalidScoreCount respjson.Field - JudgeFailCount respjson.Field - ResultFileID respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest]( + "model_source", "serverless", "dedicated", "external", + ) } -// Returns the unmodified JSON received from the API -func (r EvalGetStatusResponseResultsEvaluationScoreResults) RawJSON() string { return r.JSON.raw } -func (r *EvalGetStatusResponseResultsEvaluationScoreResults) UnmarshalJSON(data []byte) error { +// The properties InputDataFilePath, Judge are required. +type EvalNewParamsParametersEvaluationCompareParameters struct { + // Data file name + InputDataFilePath string `json:"input_data_file_path,required"` + Judge EvalNewParamsParametersEvaluationCompareParametersJudge `json:"judge,omitzero,required"` + // Field name in the input data + ModelA EvalNewParamsParametersEvaluationCompareParametersModelAUnion `json:"model_a,omitzero"` + // Field name in the input data + ModelB EvalNewParamsParametersEvaluationCompareParametersModelBUnion `json:"model_b,omitzero"` + paramObj +} + +func (r EvalNewParamsParametersEvaluationCompareParameters) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationCompareParameters + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationCompareParameters) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponseResultsEvaluationScoreResultsAggregatedScores struct { - MeanScore float64 `json:"mean_score"` - PassPercentage float64 `json:"pass_percentage"` - StdScore float64 `json:"std_score"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - MeanScore respjson.Field - PassPercentage respjson.Field - StdScore respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +// The properties Model, ModelSource, SystemTemplate are required. +type EvalNewParamsParametersEvaluationCompareParametersJudge struct { + // Name of the judge model + Model string `json:"model,required"` + // Source of the judge model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template for the judge + SystemTemplate string `json:"system_template,required"` + // Bearer/API token for external judge models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external judge models. Must be OpenAI-compatible base URL. + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj } -// Returns the unmodified JSON received from the API -func (r EvalGetStatusResponseResultsEvaluationScoreResultsAggregatedScores) RawJSON() string { - return r.JSON.raw +func (r EvalNewParamsParametersEvaluationCompareParametersJudge) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationCompareParametersJudge + return param.MarshalObject(r, (*shadow)(&r)) } -func (r *EvalGetStatusResponseResultsEvaluationScoreResultsAggregatedScores) UnmarshalJSON(data []byte) error { +func (r *EvalNewParamsParametersEvaluationCompareParametersJudge) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponseResultsEvaluationCompareResults struct { - // Number of times model A won - AWins int64 `json:"A_wins"` - // Number of times model B won - BWins int64 `json:"B_wins"` - // Number of failed generations. - GenerationFailCount float64 `json:"generation_fail_count,nullable"` - // Number of failed judge generations - JudgeFailCount float64 `json:"judge_fail_count,nullable"` - // Total number of samples compared - NumSamples int64 `json:"num_samples"` - // Data File ID - ResultFileID string `json:"result_file_id"` - // Number of ties - Ties int64 `json:"Ties"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - AWins respjson.Field - BWins respjson.Field - GenerationFailCount respjson.Field - JudgeFailCount respjson.Field - NumSamples respjson.Field - ResultFileID respjson.Field - Ties respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationCompareParametersJudge]( + "model_source", "serverless", "dedicated", "external", + ) } -// Returns the unmodified JSON received from the API -func (r EvalGetStatusResponseResultsEvaluationCompareResults) RawJSON() string { return r.JSON.raw } -func (r *EvalGetStatusResponseResultsEvaluationCompareResults) UnmarshalJSON(data []byte) error { +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type EvalNewParamsParametersEvaluationCompareParametersModelAUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfEvalNewsParametersEvaluationCompareParametersModelAEvaluationModelRequest *EvalNewParamsParametersEvaluationCompareParametersModelAEvaluationModelRequest `json:",omitzero,inline"` + paramUnion +} + +func (u EvalNewParamsParametersEvaluationCompareParametersModelAUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfEvalNewsParametersEvaluationCompareParametersModelAEvaluationModelRequest) +} +func (u *EvalNewParamsParametersEvaluationCompareParametersModelAUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *EvalNewParamsParametersEvaluationCompareParametersModelAUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfEvalNewsParametersEvaluationCompareParametersModelAEvaluationModelRequest) { + return u.OfEvalNewsParametersEvaluationCompareParametersModelAEvaluationModelRequest + } + return nil +} + +// The properties InputTemplate, MaxTokens, Model, ModelSource, SystemTemplate, +// Temperature are required. +type EvalNewParamsParametersEvaluationCompareParametersModelAEvaluationModelRequest struct { + // Input prompt template + InputTemplate string `json:"input_template,required"` + // Maximum number of tokens to generate + MaxTokens int64 `json:"max_tokens,required"` + // Name of the model to evaluate + Model string `json:"model,required"` + // Source of the model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template + SystemTemplate string `json:"system_template,required"` + // Sampling temperature + Temperature float64 `json:"temperature,required"` + // Bearer/API token for external models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external models. Must be OpenAI-compatible base URL + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj +} + +func (r EvalNewParamsParametersEvaluationCompareParametersModelAEvaluationModelRequest) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationCompareParametersModelAEvaluationModelRequest + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationCompareParametersModelAEvaluationModelRequest) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponseResultsError struct { - Error string `json:"error"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Error respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationCompareParametersModelAEvaluationModelRequest]( + "model_source", "serverless", "dedicated", "external", + ) } -// Returns the unmodified JSON received from the API -func (r EvalGetStatusResponseResultsError) RawJSON() string { return r.JSON.raw } -func (r *EvalGetStatusResponseResultsError) UnmarshalJSON(data []byte) error { +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type EvalNewParamsParametersEvaluationCompareParametersModelBUnion struct { + OfString param.Opt[string] `json:",omitzero,inline"` + OfEvalNewsParametersEvaluationCompareParametersModelBEvaluationModelRequest *EvalNewParamsParametersEvaluationCompareParametersModelBEvaluationModelRequest `json:",omitzero,inline"` + paramUnion +} + +func (u EvalNewParamsParametersEvaluationCompareParametersModelBUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfString, u.OfEvalNewsParametersEvaluationCompareParametersModelBEvaluationModelRequest) +} +func (u *EvalNewParamsParametersEvaluationCompareParametersModelBUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *EvalNewParamsParametersEvaluationCompareParametersModelBUnion) asAny() any { + if !param.IsOmitted(u.OfString) { + return &u.OfString.Value + } else if !param.IsOmitted(u.OfEvalNewsParametersEvaluationCompareParametersModelBEvaluationModelRequest) { + return u.OfEvalNewsParametersEvaluationCompareParametersModelBEvaluationModelRequest + } + return nil +} + +// The properties InputTemplate, MaxTokens, Model, ModelSource, SystemTemplate, +// Temperature are required. +type EvalNewParamsParametersEvaluationCompareParametersModelBEvaluationModelRequest struct { + // Input prompt template + InputTemplate string `json:"input_template,required"` + // Maximum number of tokens to generate + MaxTokens int64 `json:"max_tokens,required"` + // Name of the model to evaluate + Model string `json:"model,required"` + // Source of the model. + // + // Any of "serverless", "dedicated", "external". + ModelSource string `json:"model_source,omitzero,required"` + // System prompt template + SystemTemplate string `json:"system_template,required"` + // Sampling temperature + Temperature float64 `json:"temperature,required"` + // Bearer/API token for external models. + ExternalAPIToken param.Opt[string] `json:"external_api_token,omitzero"` + // Base URL for external models. Must be OpenAI-compatible base URL + ExternalBaseURL param.Opt[string] `json:"external_base_url,omitzero"` + paramObj +} + +func (r EvalNewParamsParametersEvaluationCompareParametersModelBEvaluationModelRequest) MarshalJSON() (data []byte, err error) { + type shadow EvalNewParamsParametersEvaluationCompareParametersModelBEvaluationModelRequest + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *EvalNewParamsParametersEvaluationCompareParametersModelBEvaluationModelRequest) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type EvalGetStatusResponseStatus string +func init() { + apijson.RegisterFieldValidator[EvalNewParamsParametersEvaluationCompareParametersModelBEvaluationModelRequest]( + "model_source", "serverless", "dedicated", "external", + ) +} + +// The type of evaluation to perform +type EvalNewParamsType string const ( - EvalGetStatusResponseStatusPending EvalGetStatusResponseStatus = "pending" - EvalGetStatusResponseStatusQueued EvalGetStatusResponseStatus = "queued" - EvalGetStatusResponseStatusRunning EvalGetStatusResponseStatus = "running" - EvalGetStatusResponseStatusCompleted EvalGetStatusResponseStatus = "completed" - EvalGetStatusResponseStatusError EvalGetStatusResponseStatus = "error" - EvalGetStatusResponseStatusUserError EvalGetStatusResponseStatus = "user_error" + EvalNewParamsTypeClassify EvalNewParamsType = "classify" + EvalNewParamsTypeScore EvalNewParamsType = "score" + EvalNewParamsTypeCompare EvalNewParamsType = "compare" ) type EvalListParams struct { - // Maximum number of results to return (max 100) - Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` - // Filter by job status - // - // Any of "pending", "queued", "running", "completed", "error", "user_error". - Status EvalListParamsStatus `query:"status,omitzero" json:"-"` + Limit param.Opt[int64] `query:"limit,omitzero" json:"-"` + Status param.Opt[string] `query:"status,omitzero" json:"-"` + // Admin users can specify a user ID to filter jobs. Pass empty string to get all + // jobs. + UserID param.Opt[string] `query:"userId,omitzero" json:"-"` paramObj } @@ -971,15 +1427,3 @@ func (r EvalListParams) URLQuery() (v url.Values, err error) { NestedFormat: apiquery.NestedQueryFormatBrackets, }) } - -// Filter by job status -type EvalListParamsStatus string - -const ( - EvalListParamsStatusPending EvalListParamsStatus = "pending" - EvalListParamsStatusQueued EvalListParamsStatus = "queued" - EvalListParamsStatusRunning EvalListParamsStatus = "running" - EvalListParamsStatusCompleted EvalListParamsStatus = "completed" - EvalListParamsStatusError EvalListParamsStatus = "error" - EvalListParamsStatusUserError EvalListParamsStatus = "user_error" -) diff --git a/eval_test.go b/eval_test.go index 89dda912..1cb946ee 100644 --- a/eval_test.go +++ b/eval_test.go @@ -13,7 +13,7 @@ import ( "github.com/togethercomputer/together-go/option" ) -func TestEvalGet(t *testing.T) { +func TestEvalNewWithOptionalParams(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -25,7 +25,26 @@ func TestEvalGet(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.Evals.Get(context.TODO(), "id") + _, err := client.Evals.New(context.TODO(), together.EvalNewParams{ + Parameters: together.EvalNewParamsParametersUnion{ + OfEvalNewsParametersEvaluationClassifyParameters: &together.EvalNewParamsParametersEvaluationClassifyParameters{ + InputDataFilePath: "file-1234-aefd", + Judge: together.EvalNewParamsParametersEvaluationClassifyParametersJudge{ + Model: "meta-llama/Llama-3-70B-Instruct-Turbo", + ModelSource: "serverless", + SystemTemplate: "Imagine you are a helpful assistant", + ExternalAPIToken: together.String("external_api_token"), + ExternalBaseURL: together.String("external_base_url"), + }, + Labels: []string{"yes", "no"}, + PassLabels: []string{"yes"}, + ModelToEvaluate: together.EvalNewParamsParametersEvaluationClassifyParametersModelToEvaluateUnion{ + OfString: together.String("string"), + }, + }, + }, + Type: together.EvalNewParamsTypeClassify, + }) if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -35,7 +54,7 @@ func TestEvalGet(t *testing.T) { } } -func TestEvalListWithOptionalParams(t *testing.T) { +func TestEvalGet(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -47,10 +66,7 @@ func TestEvalListWithOptionalParams(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.Evals.List(context.TODO(), together.EvalListParams{ - Limit: together.Int(1), - Status: together.EvalListParamsStatusPending, - }) + _, err := client.Evals.Get(context.TODO(), "id") if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -60,7 +76,7 @@ func TestEvalListWithOptionalParams(t *testing.T) { } } -func TestEvalGetAllowedModels(t *testing.T) { +func TestEvalListWithOptionalParams(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -72,7 +88,11 @@ func TestEvalGetAllowedModels(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.Evals.GetAllowedModels(context.TODO()) + _, err := client.Evals.List(context.TODO(), together.EvalListParams{ + Limit: together.Int(0), + Status: together.String("status"), + UserID: together.String("userId"), + }) if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -82,7 +102,7 @@ func TestEvalGetAllowedModels(t *testing.T) { } } -func TestEvalGetStatus(t *testing.T) { +func TestEvalStatus(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -94,7 +114,7 @@ func TestEvalGetStatus(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.Evals.GetStatus(context.TODO(), "id") + _, err := client.Evals.Status(context.TODO(), "id") if err != nil { var apierr *together.Error if errors.As(err, &apierr) { diff --git a/file.go b/file.go index aee6f4c0..fd198d91 100644 --- a/file.go +++ b/file.go @@ -39,7 +39,7 @@ func NewFileService(opts ...option.RequestOption) (r FileService) { } // List the metadata for a single uploaded data file. -func (r *FileService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *FileGetResponse, err error) { +func (r *FileService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *FileResponse, err error) { opts = slices.Concat(r.Options, opts) if id == "" { err = errors.New("missing required id parameter") @@ -51,7 +51,7 @@ func (r *FileService) Get(ctx context.Context, id string, opts ...option.Request } // List the metadata for all uploaded data files. -func (r *FileService) List(ctx context.Context, opts ...option.RequestOption) (res *FileListResponse, err error) { +func (r *FileService) List(ctx context.Context, opts ...option.RequestOption) (res *FileList, err error) { opts = slices.Concat(r.Options, opts) path := "files" err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) @@ -84,13 +84,29 @@ func (r *FileService) Content(ctx context.Context, id string, opts ...option.Req } // Upload a file with specified purpose, file name, and file type. -func (r *FileService) Upload(ctx context.Context, body FileUploadParams, opts ...option.RequestOption) (res *FileUploadResponse, err error) { +func (r *FileService) Upload(ctx context.Context, body FileUploadParams, opts ...option.RequestOption) (res *FileResponse, err error) { opts = slices.Concat(r.Options, opts) path := "files/upload" err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) return } +type FileList struct { + Data []FileResponse `json:"data,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FileList) RawJSON() string { return r.JSON.raw } +func (r *FileList) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + // The purpose of the file type FilePurpose string @@ -104,16 +120,7 @@ const ( FilePurposeBatchAPI FilePurpose = "batch-api" ) -// The type of the file -type FileType string - -const ( - FileTypeCsv FileType = "csv" - FileTypeJSONL FileType = "jsonl" - FileTypeParquet FileType = "parquet" -) - -type FileGetResponse struct { +type FileResponse struct { ID string `json:"id,required"` Bytes int64 `json:"bytes,required"` CreatedAt int64 `json:"created_at,required"` @@ -147,65 +154,19 @@ type FileGetResponse struct { } // Returns the unmodified JSON received from the API -func (r FileGetResponse) RawJSON() string { return r.JSON.raw } -func (r *FileGetResponse) UnmarshalJSON(data []byte) error { +func (r FileResponse) RawJSON() string { return r.JSON.raw } +func (r *FileResponse) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type FileListResponse struct { - Data []FileListResponseData `json:"data,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Data respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FileListResponse) RawJSON() string { return r.JSON.raw } -func (r *FileListResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FileListResponseData struct { - ID string `json:"id,required"` - Bytes int64 `json:"bytes,required"` - CreatedAt int64 `json:"created_at,required"` - Filename string `json:"filename,required"` - // The type of the file - // - // Any of "csv", "jsonl", "parquet". - FileType FileType `json:"FileType,required"` - LineCount int64 `json:"LineCount,required"` - Object string `json:"object,required"` - Processed bool `json:"Processed,required"` - // The purpose of the file - // - // Any of "fine-tune", "eval", "eval-sample", "eval-output", "eval-summary", - // "batch-generated", "batch-api". - Purpose FilePurpose `json:"purpose,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - Bytes respjson.Field - CreatedAt respjson.Field - Filename respjson.Field - FileType respjson.Field - LineCount respjson.Field - Object respjson.Field - Processed respjson.Field - Purpose respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} +// The type of the file +type FileType string -// Returns the unmodified JSON received from the API -func (r FileListResponseData) RawJSON() string { return r.JSON.raw } -func (r *FileListResponseData) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} +const ( + FileTypeCsv FileType = "csv" + FileTypeJSONL FileType = "jsonl" + FileTypeParquet FileType = "parquet" +) type FileDeleteResponse struct { ID string `json:"id"` @@ -225,45 +186,6 @@ func (r *FileDeleteResponse) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type FileUploadResponse struct { - ID string `json:"id,required"` - Bytes int64 `json:"bytes,required"` - CreatedAt int64 `json:"created_at,required"` - Filename string `json:"filename,required"` - // The type of the file - // - // Any of "csv", "jsonl", "parquet". - FileType FileType `json:"FileType,required"` - LineCount int64 `json:"LineCount,required"` - Object string `json:"object,required"` - Processed bool `json:"Processed,required"` - // The purpose of the file - // - // Any of "fine-tune", "eval", "eval-sample", "eval-output", "eval-summary", - // "batch-generated", "batch-api". - Purpose FilePurpose `json:"purpose,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - Bytes respjson.Field - CreatedAt respjson.Field - Filename respjson.Field - FileType respjson.Field - LineCount respjson.Field - Object respjson.Field - Processed respjson.Field - Purpose respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FileUploadResponse) RawJSON() string { return r.JSON.raw } -func (r *FileUploadResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - type FileUploadParams struct { // The content of the file being uploaded File io.Reader `json:"file,omitzero,required" format:"binary"` diff --git a/finetune.go b/finetune.go deleted file mode 100644 index 3d76fd4b..00000000 --- a/finetune.go +++ /dev/null @@ -1,2029 +0,0 @@ -// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -package together - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "slices" - "time" - - "github.com/togethercomputer/together-go/internal/apijson" - "github.com/togethercomputer/together-go/internal/apiquery" - "github.com/togethercomputer/together-go/internal/requestconfig" - "github.com/togethercomputer/together-go/option" - "github.com/togethercomputer/together-go/packages/param" - "github.com/togethercomputer/together-go/packages/respjson" -) - -// FineTuneService contains methods and other services that help with interacting -// with the together API. -// -// Note, unlike clients, this service does not read variables from the environment -// automatically. You should not instantiate this service directly, and instead use -// the [NewFineTuneService] method instead. -type FineTuneService struct { - Options []option.RequestOption -} - -// NewFineTuneService generates a new service that applies the given options to -// each request. These options are applied after the parent client's options (if -// there is one), and before any request-specific options. -func NewFineTuneService(opts ...option.RequestOption) (r FineTuneService) { - r = FineTuneService{} - r.Options = opts - return -} - -// Create a fine-tuning job with the provided model and training data. -func (r *FineTuneService) New(ctx context.Context, body FineTuneNewParams, opts ...option.RequestOption) (res *FineTuneNewResponse, err error) { - opts = slices.Concat(r.Options, opts) - path := "fine-tunes" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) - return -} - -// List the metadata for a single fine-tuning job. -func (r *FineTuneService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTune, err error) { - opts = slices.Concat(r.Options, opts) - if id == "" { - err = errors.New("missing required id parameter") - return - } - path := fmt.Sprintf("fine-tunes/%s", id) - err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) - return -} - -// List the metadata for all fine-tuning jobs. Returns a list of -// FinetuneResponseTruncated objects. -func (r *FineTuneService) List(ctx context.Context, opts ...option.RequestOption) (res *FineTuneListResponse, err error) { - opts = slices.Concat(r.Options, opts) - path := "fine-tunes" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) - return -} - -// Cancel a currently running fine-tuning job. Returns a FinetuneResponseTruncated -// object. -func (r *FineTuneService) Cancel(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTuneCancelResponse, err error) { - opts = slices.Concat(r.Options, opts) - if id == "" { - err = errors.New("missing required id parameter") - return - } - path := fmt.Sprintf("fine-tunes/%s/cancel", id) - err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) - return -} - -// Download a compressed fine-tuned model or checkpoint to local disk. -func (r *FineTuneService) Download(ctx context.Context, query FineTuneDownloadParams, opts ...option.RequestOption) (res *FineTuneDownloadResponse, err error) { - opts = slices.Concat(r.Options, opts) - path := "finetune/download" - err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) - return -} - -// List the events for a single fine-tuning job. -func (r *FineTuneService) ListEvents(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTuneListEventsResponse, err error) { - opts = slices.Concat(r.Options, opts) - if id == "" { - err = errors.New("missing required id parameter") - return - } - path := fmt.Sprintf("fine-tunes/%s/events", id) - err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) - return -} - -// List the checkpoints for a single fine-tuning job. -func (r *FineTuneService) GetCheckpoints(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTuneGetCheckpointsResponse, err error) { - opts = slices.Concat(r.Options, opts) - if id == "" { - err = errors.New("missing required id parameter") - return - } - path := fmt.Sprintf("fine-tunes/%s/checkpoints", id) - err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) - return -} - -type CosineLrSchedulerArgs struct { - // The ratio of the final learning rate to the peak learning rate - MinLrRatio float64 `json:"min_lr_ratio,required"` - // Number or fraction of cycles for the cosine learning rate scheduler - NumCycles float64 `json:"num_cycles,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - MinLrRatio respjson.Field - NumCycles respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r CosineLrSchedulerArgs) RawJSON() string { return r.JSON.raw } -func (r *CosineLrSchedulerArgs) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this CosineLrSchedulerArgs to a CosineLrSchedulerArgsParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// CosineLrSchedulerArgsParam.Overrides() -func (r CosineLrSchedulerArgs) ToParam() CosineLrSchedulerArgsParam { - return param.Override[CosineLrSchedulerArgsParam](json.RawMessage(r.RawJSON())) -} - -// The properties MinLrRatio, NumCycles are required. -type CosineLrSchedulerArgsParam struct { - // The ratio of the final learning rate to the peak learning rate - MinLrRatio float64 `json:"min_lr_ratio,required"` - // Number or fraction of cycles for the cosine learning rate scheduler - NumCycles float64 `json:"num_cycles,required"` - paramObj -} - -func (r CosineLrSchedulerArgsParam) MarshalJSON() (data []byte, err error) { - type shadow CosineLrSchedulerArgsParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *CosineLrSchedulerArgsParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTune struct { - ID string `json:"id,required" format:"uuid"` - // Any of "pending", "queued", "running", "compressing", "uploading", - // "cancel_requested", "cancelled", "error", "completed". - Status FineTuneStatus `json:"status,required"` - BatchSize FineTuneBatchSizeUnion `json:"batch_size"` - CreatedAt string `json:"created_at"` - EpochsCompleted int64 `json:"epochs_completed"` - EvalSteps int64 `json:"eval_steps"` - Events []FineTuneEvent `json:"events"` - FromCheckpoint string `json:"from_checkpoint"` - FromHfModel string `json:"from_hf_model"` - HfModelRevision string `json:"hf_model_revision"` - JobID string `json:"job_id"` - LearningRate float64 `json:"learning_rate"` - LrScheduler LrScheduler `json:"lr_scheduler"` - MaxGradNorm float64 `json:"max_grad_norm"` - Model string `json:"model"` - ModelOutputName string `json:"model_output_name"` - ModelOutputPath string `json:"model_output_path"` - NCheckpoints int64 `json:"n_checkpoints"` - NEpochs int64 `json:"n_epochs"` - NEvals int64 `json:"n_evals"` - ParamCount int64 `json:"param_count"` - QueueDepth int64 `json:"queue_depth"` - TokenCount int64 `json:"token_count"` - TotalPrice int64 `json:"total_price"` - TrainOnInputs FineTuneTrainOnInputsUnion `json:"train_on_inputs"` - TrainingFile string `json:"training_file"` - TrainingMethod FineTuneTrainingMethodUnion `json:"training_method"` - TrainingType FineTuneTrainingTypeUnion `json:"training_type"` - TrainingfileNumlines int64 `json:"trainingfile_numlines"` - TrainingfileSize int64 `json:"trainingfile_size"` - UpdatedAt string `json:"updated_at"` - ValidationFile string `json:"validation_file"` - WandbProjectName string `json:"wandb_project_name"` - WandbURL string `json:"wandb_url"` - WarmupRatio float64 `json:"warmup_ratio"` - WeightDecay float64 `json:"weight_decay"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - Status respjson.Field - BatchSize respjson.Field - CreatedAt respjson.Field - EpochsCompleted respjson.Field - EvalSteps respjson.Field - Events respjson.Field - FromCheckpoint respjson.Field - FromHfModel respjson.Field - HfModelRevision respjson.Field - JobID respjson.Field - LearningRate respjson.Field - LrScheduler respjson.Field - MaxGradNorm respjson.Field - Model respjson.Field - ModelOutputName respjson.Field - ModelOutputPath respjson.Field - NCheckpoints respjson.Field - NEpochs respjson.Field - NEvals respjson.Field - ParamCount respjson.Field - QueueDepth respjson.Field - TokenCount respjson.Field - TotalPrice respjson.Field - TrainOnInputs respjson.Field - TrainingFile respjson.Field - TrainingMethod respjson.Field - TrainingType respjson.Field - TrainingfileNumlines respjson.Field - TrainingfileSize respjson.Field - UpdatedAt respjson.Field - ValidationFile respjson.Field - WandbProjectName respjson.Field - WandbURL respjson.Field - WarmupRatio respjson.Field - WeightDecay respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTune) RawJSON() string { return r.JSON.raw } -func (r *FineTune) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneStatus string - -const ( - FineTuneStatusPending FineTuneStatus = "pending" - FineTuneStatusQueued FineTuneStatus = "queued" - FineTuneStatusRunning FineTuneStatus = "running" - FineTuneStatusCompressing FineTuneStatus = "compressing" - FineTuneStatusUploading FineTuneStatus = "uploading" - FineTuneStatusCancelRequested FineTuneStatus = "cancel_requested" - FineTuneStatusCancelled FineTuneStatus = "cancelled" - FineTuneStatusError FineTuneStatus = "error" - FineTuneStatusCompleted FineTuneStatus = "completed" -) - -// FineTuneBatchSizeUnion contains all possible properties and values from [int64], -// [string]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -// -// If the underlying value is not a json object, one of the following properties -// will be valid: OfInt OfFineTuneBatchSizeString] -type FineTuneBatchSizeUnion struct { - // This field will be present if the value is a [int64] instead of an object. - OfInt int64 `json:",inline"` - // This field will be present if the value is a [string] instead of an object. - OfFineTuneBatchSizeString string `json:",inline"` - JSON struct { - OfInt respjson.Field - OfFineTuneBatchSizeString respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneBatchSizeUnion) AsInt() (v int64) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneBatchSizeUnion) AsFineTuneBatchSizeString() (v string) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneBatchSizeUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneBatchSizeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneBatchSizeString string - -const ( - FineTuneBatchSizeStringMax FineTuneBatchSizeString = "max" -) - -// FineTuneTrainOnInputsUnion contains all possible properties and values from -// [bool], [string]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -// -// If the underlying value is not a json object, one of the following properties -// will be valid: OfBool OfFineTuneTrainOnInputsString] -type FineTuneTrainOnInputsUnion struct { - // This field will be present if the value is a [bool] instead of an object. - OfBool bool `json:",inline"` - // This field will be present if the value is a [string] instead of an object. - OfFineTuneTrainOnInputsString string `json:",inline"` - JSON struct { - OfBool respjson.Field - OfFineTuneTrainOnInputsString respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneTrainOnInputsUnion) AsBool() (v bool) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneTrainOnInputsUnion) AsFineTuneTrainOnInputsString() (v string) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneTrainOnInputsUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneTrainOnInputsUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneTrainOnInputsString string - -const ( - FineTuneTrainOnInputsStringAuto FineTuneTrainOnInputsString = "auto" -) - -// FineTuneTrainingMethodUnion contains all possible properties and values from -// [TrainingMethodSft], [TrainingMethodDpo]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneTrainingMethodUnion struct { - Method string `json:"method"` - // This field is from variant [TrainingMethodSft]. - TrainOnInputs TrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` - // This field is from variant [TrainingMethodDpo]. - DpoBeta float64 `json:"dpo_beta"` - // This field is from variant [TrainingMethodDpo]. - DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` - // This field is from variant [TrainingMethodDpo]. - DpoReferenceFree bool `json:"dpo_reference_free"` - // This field is from variant [TrainingMethodDpo]. - RpoAlpha float64 `json:"rpo_alpha"` - // This field is from variant [TrainingMethodDpo]. - SimpoGamma float64 `json:"simpo_gamma"` - JSON struct { - Method respjson.Field - TrainOnInputs respjson.Field - DpoBeta respjson.Field - DpoNormalizeLogratiosByLength respjson.Field - DpoReferenceFree respjson.Field - RpoAlpha respjson.Field - SimpoGamma respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneTrainingMethodUnion) AsTrainingMethodSft() (v TrainingMethodSft) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneTrainingMethodUnion) AsTrainingMethodDpo() (v TrainingMethodDpo) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneTrainingMethodUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneTrainingMethodUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// FineTuneTrainingTypeUnion contains all possible properties and values from -// [FullTrainingType], [LoRaTrainingType]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneTrainingTypeUnion struct { - Type string `json:"type"` - // This field is from variant [LoRaTrainingType]. - LoraAlpha int64 `json:"lora_alpha"` - // This field is from variant [LoRaTrainingType]. - LoraR int64 `json:"lora_r"` - // This field is from variant [LoRaTrainingType]. - LoraDropout float64 `json:"lora_dropout"` - // This field is from variant [LoRaTrainingType]. - LoraTrainableModules string `json:"lora_trainable_modules"` - JSON struct { - Type respjson.Field - LoraAlpha respjson.Field - LoraR respjson.Field - LoraDropout respjson.Field - LoraTrainableModules respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneTrainingTypeUnion) AsFullTrainingType() (v FullTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneTrainingTypeUnion) AsLoRaTrainingType() (v LoRaTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneTrainingTypeUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneTrainingTypeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneEvent struct { - CheckpointPath string `json:"checkpoint_path,required"` - CreatedAt string `json:"created_at,required"` - Hash string `json:"hash,required"` - Message string `json:"message,required"` - ModelPath string `json:"model_path,required"` - // Any of "fine-tune-event". - Object FineTuneEventObject `json:"object,required"` - ParamCount int64 `json:"param_count,required"` - Step int64 `json:"step,required"` - TokenCount int64 `json:"token_count,required"` - TotalSteps int64 `json:"total_steps,required"` - TrainingOffset int64 `json:"training_offset,required"` - // Any of "job_pending", "job_start", "job_stopped", "model_downloading", - // "model_download_complete", "training_data_downloading", - // "training_data_download_complete", "validation_data_downloading", - // "validation_data_download_complete", "wandb_init", "training_start", - // "checkpoint_save", "billing_limit", "epoch_complete", "training_complete", - // "model_compressing", "model_compression_complete", "model_uploading", - // "model_upload_complete", "job_complete", "job_error", "cancel_requested", - // "job_restarted", "refund", "warning". - Type FineTuneEventType `json:"type,required"` - WandbURL string `json:"wandb_url,required"` - // Any of "info", "warning", "error", "legacy_info", "legacy_iwarning", - // "legacy_ierror". - Level FineTuneEventLevel `json:"level,nullable"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - CheckpointPath respjson.Field - CreatedAt respjson.Field - Hash respjson.Field - Message respjson.Field - ModelPath respjson.Field - Object respjson.Field - ParamCount respjson.Field - Step respjson.Field - TokenCount respjson.Field - TotalSteps respjson.Field - TrainingOffset respjson.Field - Type respjson.Field - WandbURL respjson.Field - Level respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneEvent) RawJSON() string { return r.JSON.raw } -func (r *FineTuneEvent) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneEventObject string - -const ( - FineTuneEventObjectFineTuneEvent FineTuneEventObject = "fine-tune-event" -) - -type FineTuneEventType string - -const ( - FineTuneEventTypeJobPending FineTuneEventType = "job_pending" - FineTuneEventTypeJobStart FineTuneEventType = "job_start" - FineTuneEventTypeJobStopped FineTuneEventType = "job_stopped" - FineTuneEventTypeModelDownloading FineTuneEventType = "model_downloading" - FineTuneEventTypeModelDownloadComplete FineTuneEventType = "model_download_complete" - FineTuneEventTypeTrainingDataDownloading FineTuneEventType = "training_data_downloading" - FineTuneEventTypeTrainingDataDownloadComplete FineTuneEventType = "training_data_download_complete" - FineTuneEventTypeValidationDataDownloading FineTuneEventType = "validation_data_downloading" - FineTuneEventTypeValidationDataDownloadComplete FineTuneEventType = "validation_data_download_complete" - FineTuneEventTypeWandbInit FineTuneEventType = "wandb_init" - FineTuneEventTypeTrainingStart FineTuneEventType = "training_start" - FineTuneEventTypeCheckpointSave FineTuneEventType = "checkpoint_save" - FineTuneEventTypeBillingLimit FineTuneEventType = "billing_limit" - FineTuneEventTypeEpochComplete FineTuneEventType = "epoch_complete" - FineTuneEventTypeTrainingComplete FineTuneEventType = "training_complete" - FineTuneEventTypeModelCompressing FineTuneEventType = "model_compressing" - FineTuneEventTypeModelCompressionComplete FineTuneEventType = "model_compression_complete" - FineTuneEventTypeModelUploading FineTuneEventType = "model_uploading" - FineTuneEventTypeModelUploadComplete FineTuneEventType = "model_upload_complete" - FineTuneEventTypeJobComplete FineTuneEventType = "job_complete" - FineTuneEventTypeJobError FineTuneEventType = "job_error" - FineTuneEventTypeCancelRequested FineTuneEventType = "cancel_requested" - FineTuneEventTypeJobRestarted FineTuneEventType = "job_restarted" - FineTuneEventTypeRefund FineTuneEventType = "refund" - FineTuneEventTypeWarning FineTuneEventType = "warning" -) - -type FineTuneEventLevel string - -const ( - FineTuneEventLevelInfo FineTuneEventLevel = "info" - FineTuneEventLevelWarning FineTuneEventLevel = "warning" - FineTuneEventLevelError FineTuneEventLevel = "error" - FineTuneEventLevelLegacyInfo FineTuneEventLevel = "legacy_info" - FineTuneEventLevelLegacyIwarning FineTuneEventLevel = "legacy_iwarning" - FineTuneEventLevelLegacyIerror FineTuneEventLevel = "legacy_ierror" -) - -type FullTrainingType struct { - // Any of "Full". - Type FullTrainingTypeType `json:"type,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Type respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FullTrainingType) RawJSON() string { return r.JSON.raw } -func (r *FullTrainingType) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this FullTrainingType to a FullTrainingTypeParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// FullTrainingTypeParam.Overrides() -func (r FullTrainingType) ToParam() FullTrainingTypeParam { - return param.Override[FullTrainingTypeParam](json.RawMessage(r.RawJSON())) -} - -type FullTrainingTypeType string - -const ( - FullTrainingTypeTypeFull FullTrainingTypeType = "Full" -) - -// The property Type is required. -type FullTrainingTypeParam struct { - // Any of "Full". - Type FullTrainingTypeType `json:"type,omitzero,required"` - paramObj -} - -func (r FullTrainingTypeParam) MarshalJSON() (data []byte, err error) { - type shadow FullTrainingTypeParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *FullTrainingTypeParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type LinearLrSchedulerArgs struct { - // The ratio of the final learning rate to the peak learning rate - MinLrRatio float64 `json:"min_lr_ratio"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - MinLrRatio respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r LinearLrSchedulerArgs) RawJSON() string { return r.JSON.raw } -func (r *LinearLrSchedulerArgs) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this LinearLrSchedulerArgs to a LinearLrSchedulerArgsParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// LinearLrSchedulerArgsParam.Overrides() -func (r LinearLrSchedulerArgs) ToParam() LinearLrSchedulerArgsParam { - return param.Override[LinearLrSchedulerArgsParam](json.RawMessage(r.RawJSON())) -} - -type LinearLrSchedulerArgsParam struct { - // The ratio of the final learning rate to the peak learning rate - MinLrRatio param.Opt[float64] `json:"min_lr_ratio,omitzero"` - paramObj -} - -func (r LinearLrSchedulerArgsParam) MarshalJSON() (data []byte, err error) { - type shadow LinearLrSchedulerArgsParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *LinearLrSchedulerArgsParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type LoRaTrainingType struct { - LoraAlpha int64 `json:"lora_alpha,required"` - LoraR int64 `json:"lora_r,required"` - // Any of "Lora". - Type LoRaTrainingTypeType `json:"type,required"` - LoraDropout float64 `json:"lora_dropout"` - LoraTrainableModules string `json:"lora_trainable_modules"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - LoraAlpha respjson.Field - LoraR respjson.Field - Type respjson.Field - LoraDropout respjson.Field - LoraTrainableModules respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r LoRaTrainingType) RawJSON() string { return r.JSON.raw } -func (r *LoRaTrainingType) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this LoRaTrainingType to a LoRaTrainingTypeParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// LoRaTrainingTypeParam.Overrides() -func (r LoRaTrainingType) ToParam() LoRaTrainingTypeParam { - return param.Override[LoRaTrainingTypeParam](json.RawMessage(r.RawJSON())) -} - -type LoRaTrainingTypeType string - -const ( - LoRaTrainingTypeTypeLora LoRaTrainingTypeType = "Lora" -) - -// The properties LoraAlpha, LoraR, Type are required. -type LoRaTrainingTypeParam struct { - LoraAlpha int64 `json:"lora_alpha,required"` - LoraR int64 `json:"lora_r,required"` - // Any of "Lora". - Type LoRaTrainingTypeType `json:"type,omitzero,required"` - LoraDropout param.Opt[float64] `json:"lora_dropout,omitzero"` - LoraTrainableModules param.Opt[string] `json:"lora_trainable_modules,omitzero"` - paramObj -} - -func (r LoRaTrainingTypeParam) MarshalJSON() (data []byte, err error) { - type shadow LoRaTrainingTypeParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *LoRaTrainingTypeParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type LrScheduler struct { - // Any of "linear", "cosine". - LrSchedulerType LrSchedulerLrSchedulerType `json:"lr_scheduler_type,required"` - LrSchedulerArgs LrSchedulerLrSchedulerArgsUnion `json:"lr_scheduler_args"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - LrSchedulerType respjson.Field - LrSchedulerArgs respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r LrScheduler) RawJSON() string { return r.JSON.raw } -func (r *LrScheduler) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this LrScheduler to a LrSchedulerParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// LrSchedulerParam.Overrides() -func (r LrScheduler) ToParam() LrSchedulerParam { - return param.Override[LrSchedulerParam](json.RawMessage(r.RawJSON())) -} - -type LrSchedulerLrSchedulerType string - -const ( - LrSchedulerLrSchedulerTypeLinear LrSchedulerLrSchedulerType = "linear" - LrSchedulerLrSchedulerTypeCosine LrSchedulerLrSchedulerType = "cosine" -) - -// LrSchedulerLrSchedulerArgsUnion contains all possible properties and values from -// [LinearLrSchedulerArgs], [CosineLrSchedulerArgs]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type LrSchedulerLrSchedulerArgsUnion struct { - MinLrRatio float64 `json:"min_lr_ratio"` - // This field is from variant [CosineLrSchedulerArgs]. - NumCycles float64 `json:"num_cycles"` - JSON struct { - MinLrRatio respjson.Field - NumCycles respjson.Field - raw string - } `json:"-"` -} - -func (u LrSchedulerLrSchedulerArgsUnion) AsLinearLrSchedulerArgs() (v LinearLrSchedulerArgs) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u LrSchedulerLrSchedulerArgsUnion) AsCosineLrSchedulerArgs() (v CosineLrSchedulerArgs) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u LrSchedulerLrSchedulerArgsUnion) RawJSON() string { return u.JSON.raw } - -func (r *LrSchedulerLrSchedulerArgsUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// The property LrSchedulerType is required. -type LrSchedulerParam struct { - // Any of "linear", "cosine". - LrSchedulerType LrSchedulerLrSchedulerType `json:"lr_scheduler_type,omitzero,required"` - LrSchedulerArgs LrSchedulerLrSchedulerArgsUnionParam `json:"lr_scheduler_args,omitzero"` - paramObj -} - -func (r LrSchedulerParam) MarshalJSON() (data []byte, err error) { - type shadow LrSchedulerParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *LrSchedulerParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// Only one field can be non-zero. -// -// Use [param.IsOmitted] to confirm if a field is set. -type LrSchedulerLrSchedulerArgsUnionParam struct { - OfLinearLrSchedulerArgs *LinearLrSchedulerArgsParam `json:",omitzero,inline"` - OfCosineLrSchedulerArgs *CosineLrSchedulerArgsParam `json:",omitzero,inline"` - paramUnion -} - -func (u LrSchedulerLrSchedulerArgsUnionParam) MarshalJSON() ([]byte, error) { - return param.MarshalUnion(u, u.OfLinearLrSchedulerArgs, u.OfCosineLrSchedulerArgs) -} -func (u *LrSchedulerLrSchedulerArgsUnionParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, u) -} - -func (u *LrSchedulerLrSchedulerArgsUnionParam) asAny() any { - if !param.IsOmitted(u.OfLinearLrSchedulerArgs) { - return u.OfLinearLrSchedulerArgs - } else if !param.IsOmitted(u.OfCosineLrSchedulerArgs) { - return u.OfCosineLrSchedulerArgs - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u LrSchedulerLrSchedulerArgsUnionParam) GetNumCycles() *float64 { - if vt := u.OfCosineLrSchedulerArgs; vt != nil { - return &vt.NumCycles - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u LrSchedulerLrSchedulerArgsUnionParam) GetMinLrRatio() *float64 { - if vt := u.OfLinearLrSchedulerArgs; vt != nil && vt.MinLrRatio.Valid() { - return &vt.MinLrRatio.Value - } else if vt := u.OfCosineLrSchedulerArgs; vt != nil { - return (*float64)(&vt.MinLrRatio) - } - return nil -} - -type TrainingMethodDpo struct { - // Any of "dpo". - Method TrainingMethodDpoMethod `json:"method,required"` - DpoBeta float64 `json:"dpo_beta"` - DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` - DpoReferenceFree bool `json:"dpo_reference_free"` - RpoAlpha float64 `json:"rpo_alpha"` - SimpoGamma float64 `json:"simpo_gamma"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Method respjson.Field - DpoBeta respjson.Field - DpoNormalizeLogratiosByLength respjson.Field - DpoReferenceFree respjson.Field - RpoAlpha respjson.Field - SimpoGamma respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r TrainingMethodDpo) RawJSON() string { return r.JSON.raw } -func (r *TrainingMethodDpo) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this TrainingMethodDpo to a TrainingMethodDpoParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// TrainingMethodDpoParam.Overrides() -func (r TrainingMethodDpo) ToParam() TrainingMethodDpoParam { - return param.Override[TrainingMethodDpoParam](json.RawMessage(r.RawJSON())) -} - -type TrainingMethodDpoMethod string - -const ( - TrainingMethodDpoMethodDpo TrainingMethodDpoMethod = "dpo" -) - -// The property Method is required. -type TrainingMethodDpoParam struct { - // Any of "dpo". - Method TrainingMethodDpoMethod `json:"method,omitzero,required"` - DpoBeta param.Opt[float64] `json:"dpo_beta,omitzero"` - DpoNormalizeLogratiosByLength param.Opt[bool] `json:"dpo_normalize_logratios_by_length,omitzero"` - DpoReferenceFree param.Opt[bool] `json:"dpo_reference_free,omitzero"` - RpoAlpha param.Opt[float64] `json:"rpo_alpha,omitzero"` - SimpoGamma param.Opt[float64] `json:"simpo_gamma,omitzero"` - paramObj -} - -func (r TrainingMethodDpoParam) MarshalJSON() (data []byte, err error) { - type shadow TrainingMethodDpoParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *TrainingMethodDpoParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type TrainingMethodSft struct { - // Any of "sft". - Method TrainingMethodSftMethod `json:"method,required"` - // Whether to mask the user messages in conversational data or prompts in - // instruction data. - TrainOnInputs TrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Method respjson.Field - TrainOnInputs respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r TrainingMethodSft) RawJSON() string { return r.JSON.raw } -func (r *TrainingMethodSft) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// ToParam converts this TrainingMethodSft to a TrainingMethodSftParam. -// -// Warning: the fields of the param type will not be present. ToParam should only -// be used at the last possible moment before sending a request. Test for this with -// TrainingMethodSftParam.Overrides() -func (r TrainingMethodSft) ToParam() TrainingMethodSftParam { - return param.Override[TrainingMethodSftParam](json.RawMessage(r.RawJSON())) -} - -type TrainingMethodSftMethod string - -const ( - TrainingMethodSftMethodSft TrainingMethodSftMethod = "sft" -) - -// TrainingMethodSftTrainOnInputsUnion contains all possible properties and values -// from [bool], [string]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -// -// If the underlying value is not a json object, one of the following properties -// will be valid: OfBool OfTrainingMethodSftTrainOnInputsString] -type TrainingMethodSftTrainOnInputsUnion struct { - // This field will be present if the value is a [bool] instead of an object. - OfBool bool `json:",inline"` - // This field will be present if the value is a [string] instead of an object. - OfTrainingMethodSftTrainOnInputsString string `json:",inline"` - JSON struct { - OfBool respjson.Field - OfTrainingMethodSftTrainOnInputsString respjson.Field - raw string - } `json:"-"` -} - -func (u TrainingMethodSftTrainOnInputsUnion) AsBool() (v bool) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u TrainingMethodSftTrainOnInputsUnion) AsTrainingMethodSftTrainOnInputsString() (v string) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u TrainingMethodSftTrainOnInputsUnion) RawJSON() string { return u.JSON.raw } - -func (r *TrainingMethodSftTrainOnInputsUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type TrainingMethodSftTrainOnInputsString string - -const ( - TrainingMethodSftTrainOnInputsStringAuto TrainingMethodSftTrainOnInputsString = "auto" -) - -// The properties Method, TrainOnInputs are required. -type TrainingMethodSftParam struct { - // Any of "sft". - Method TrainingMethodSftMethod `json:"method,omitzero,required"` - // Whether to mask the user messages in conversational data or prompts in - // instruction data. - TrainOnInputs TrainingMethodSftTrainOnInputsUnionParam `json:"train_on_inputs,omitzero,required"` - paramObj -} - -func (r TrainingMethodSftParam) MarshalJSON() (data []byte, err error) { - type shadow TrainingMethodSftParam - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *TrainingMethodSftParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// Only one field can be non-zero. -// -// Use [param.IsOmitted] to confirm if a field is set. -type TrainingMethodSftTrainOnInputsUnionParam struct { - OfBool param.Opt[bool] `json:",omitzero,inline"` - // Check if union is this variant with - // !param.IsOmitted(union.OfTrainingMethodSftTrainOnInputsString) - OfTrainingMethodSftTrainOnInputsString param.Opt[string] `json:",omitzero,inline"` - paramUnion -} - -func (u TrainingMethodSftTrainOnInputsUnionParam) MarshalJSON() ([]byte, error) { - return param.MarshalUnion(u, u.OfBool, u.OfTrainingMethodSftTrainOnInputsString) -} -func (u *TrainingMethodSftTrainOnInputsUnionParam) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, u) -} - -func (u *TrainingMethodSftTrainOnInputsUnionParam) asAny() any { - if !param.IsOmitted(u.OfBool) { - return &u.OfBool.Value - } else if !param.IsOmitted(u.OfTrainingMethodSftTrainOnInputsString) { - return &u.OfTrainingMethodSftTrainOnInputsString - } - return nil -} - -// A truncated version of the fine-tune response, used for POST /fine-tunes, GET -// /fine-tunes and POST /fine-tunes/{id}/cancel endpoints -type FineTuneNewResponse struct { - // Unique identifier for the fine-tune job - ID string `json:"id,required"` - // Creation timestamp of the fine-tune job - CreatedAt time.Time `json:"created_at,required" format:"date-time"` - // Any of "pending", "queued", "running", "compressing", "uploading", - // "cancel_requested", "cancelled", "error", "completed". - Status FineTuneNewResponseStatus `json:"status,required"` - // Last update timestamp of the fine-tune job - UpdatedAt time.Time `json:"updated_at,required" format:"date-time"` - // Batch size used for training - BatchSize int64 `json:"batch_size"` - // Events related to this fine-tune job - Events []FineTuneEvent `json:"events"` - // Checkpoint used to continue training - FromCheckpoint string `json:"from_checkpoint"` - // Hugging Face Hub repo to start training from - FromHfModel string `json:"from_hf_model"` - // The revision of the Hugging Face Hub model to continue training from - HfModelRevision string `json:"hf_model_revision"` - // Learning rate used for training - LearningRate float64 `json:"learning_rate"` - // Learning rate scheduler configuration - LrScheduler LrScheduler `json:"lr_scheduler"` - // Maximum gradient norm for clipping - MaxGradNorm float64 `json:"max_grad_norm"` - // Base model used for fine-tuning - Model string `json:"model"` - ModelOutputName string `json:"model_output_name"` - // Number of checkpoints saved during training - NCheckpoints int64 `json:"n_checkpoints"` - // Number of training epochs - NEpochs int64 `json:"n_epochs"` - // Number of evaluations during training - NEvals int64 `json:"n_evals"` - // Owner address information - OwnerAddress string `json:"owner_address"` - // Suffix added to the fine-tuned model name - Suffix string `json:"suffix"` - // Count of tokens processed - TokenCount int64 `json:"token_count"` - // Total price for the fine-tuning job - TotalPrice int64 `json:"total_price"` - // File-ID of the training file - TrainingFile string `json:"training_file"` - // Method of training used - TrainingMethod FineTuneNewResponseTrainingMethodUnion `json:"training_method"` - // Type of training used (full or LoRA) - TrainingType FineTuneNewResponseTrainingTypeUnion `json:"training_type"` - // Identifier for the user who created the job - UserID string `json:"user_id"` - // File-ID of the validation file - ValidationFile string `json:"validation_file"` - // Weights & Biases run name - WandbName string `json:"wandb_name"` - // Weights & Biases project name - WandbProjectName string `json:"wandb_project_name"` - // Ratio of warmup steps - WarmupRatio float64 `json:"warmup_ratio"` - // Weight decay value used - WeightDecay float64 `json:"weight_decay"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - CreatedAt respjson.Field - Status respjson.Field - UpdatedAt respjson.Field - BatchSize respjson.Field - Events respjson.Field - FromCheckpoint respjson.Field - FromHfModel respjson.Field - HfModelRevision respjson.Field - LearningRate respjson.Field - LrScheduler respjson.Field - MaxGradNorm respjson.Field - Model respjson.Field - ModelOutputName respjson.Field - NCheckpoints respjson.Field - NEpochs respjson.Field - NEvals respjson.Field - OwnerAddress respjson.Field - Suffix respjson.Field - TokenCount respjson.Field - TotalPrice respjson.Field - TrainingFile respjson.Field - TrainingMethod respjson.Field - TrainingType respjson.Field - UserID respjson.Field - ValidationFile respjson.Field - WandbName respjson.Field - WandbProjectName respjson.Field - WarmupRatio respjson.Field - WeightDecay respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneNewResponse) RawJSON() string { return r.JSON.raw } -func (r *FineTuneNewResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneNewResponseStatus string - -const ( - FineTuneNewResponseStatusPending FineTuneNewResponseStatus = "pending" - FineTuneNewResponseStatusQueued FineTuneNewResponseStatus = "queued" - FineTuneNewResponseStatusRunning FineTuneNewResponseStatus = "running" - FineTuneNewResponseStatusCompressing FineTuneNewResponseStatus = "compressing" - FineTuneNewResponseStatusUploading FineTuneNewResponseStatus = "uploading" - FineTuneNewResponseStatusCancelRequested FineTuneNewResponseStatus = "cancel_requested" - FineTuneNewResponseStatusCancelled FineTuneNewResponseStatus = "cancelled" - FineTuneNewResponseStatusError FineTuneNewResponseStatus = "error" - FineTuneNewResponseStatusCompleted FineTuneNewResponseStatus = "completed" -) - -// FineTuneNewResponseTrainingMethodUnion contains all possible properties and -// values from [TrainingMethodSft], [TrainingMethodDpo]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneNewResponseTrainingMethodUnion struct { - Method string `json:"method"` - // This field is from variant [TrainingMethodSft]. - TrainOnInputs TrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` - // This field is from variant [TrainingMethodDpo]. - DpoBeta float64 `json:"dpo_beta"` - // This field is from variant [TrainingMethodDpo]. - DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` - // This field is from variant [TrainingMethodDpo]. - DpoReferenceFree bool `json:"dpo_reference_free"` - // This field is from variant [TrainingMethodDpo]. - RpoAlpha float64 `json:"rpo_alpha"` - // This field is from variant [TrainingMethodDpo]. - SimpoGamma float64 `json:"simpo_gamma"` - JSON struct { - Method respjson.Field - TrainOnInputs respjson.Field - DpoBeta respjson.Field - DpoNormalizeLogratiosByLength respjson.Field - DpoReferenceFree respjson.Field - RpoAlpha respjson.Field - SimpoGamma respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneNewResponseTrainingMethodUnion) AsTrainingMethodSft() (v TrainingMethodSft) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneNewResponseTrainingMethodUnion) AsTrainingMethodDpo() (v TrainingMethodDpo) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneNewResponseTrainingMethodUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneNewResponseTrainingMethodUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// FineTuneNewResponseTrainingTypeUnion contains all possible properties and values -// from [FullTrainingType], [LoRaTrainingType]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneNewResponseTrainingTypeUnion struct { - Type string `json:"type"` - // This field is from variant [LoRaTrainingType]. - LoraAlpha int64 `json:"lora_alpha"` - // This field is from variant [LoRaTrainingType]. - LoraR int64 `json:"lora_r"` - // This field is from variant [LoRaTrainingType]. - LoraDropout float64 `json:"lora_dropout"` - // This field is from variant [LoRaTrainingType]. - LoraTrainableModules string `json:"lora_trainable_modules"` - JSON struct { - Type respjson.Field - LoraAlpha respjson.Field - LoraR respjson.Field - LoraDropout respjson.Field - LoraTrainableModules respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneNewResponseTrainingTypeUnion) AsFullTrainingType() (v FullTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneNewResponseTrainingTypeUnion) AsLoRaTrainingType() (v LoRaTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneNewResponseTrainingTypeUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneNewResponseTrainingTypeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneListResponse struct { - Data []FineTuneListResponseData `json:"data,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Data respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneListResponse) RawJSON() string { return r.JSON.raw } -func (r *FineTuneListResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// A truncated version of the fine-tune response, used for POST /fine-tunes, GET -// /fine-tunes and POST /fine-tunes/{id}/cancel endpoints -type FineTuneListResponseData struct { - // Unique identifier for the fine-tune job - ID string `json:"id,required"` - // Creation timestamp of the fine-tune job - CreatedAt time.Time `json:"created_at,required" format:"date-time"` - // Any of "pending", "queued", "running", "compressing", "uploading", - // "cancel_requested", "cancelled", "error", "completed". - Status string `json:"status,required"` - // Last update timestamp of the fine-tune job - UpdatedAt time.Time `json:"updated_at,required" format:"date-time"` - // Batch size used for training - BatchSize int64 `json:"batch_size"` - // Events related to this fine-tune job - Events []FineTuneEvent `json:"events"` - // Checkpoint used to continue training - FromCheckpoint string `json:"from_checkpoint"` - // Hugging Face Hub repo to start training from - FromHfModel string `json:"from_hf_model"` - // The revision of the Hugging Face Hub model to continue training from - HfModelRevision string `json:"hf_model_revision"` - // Learning rate used for training - LearningRate float64 `json:"learning_rate"` - // Learning rate scheduler configuration - LrScheduler LrScheduler `json:"lr_scheduler"` - // Maximum gradient norm for clipping - MaxGradNorm float64 `json:"max_grad_norm"` - // Base model used for fine-tuning - Model string `json:"model"` - ModelOutputName string `json:"model_output_name"` - // Number of checkpoints saved during training - NCheckpoints int64 `json:"n_checkpoints"` - // Number of training epochs - NEpochs int64 `json:"n_epochs"` - // Number of evaluations during training - NEvals int64 `json:"n_evals"` - // Owner address information - OwnerAddress string `json:"owner_address"` - // Suffix added to the fine-tuned model name - Suffix string `json:"suffix"` - // Count of tokens processed - TokenCount int64 `json:"token_count"` - // Total price for the fine-tuning job - TotalPrice int64 `json:"total_price"` - // File-ID of the training file - TrainingFile string `json:"training_file"` - // Method of training used - TrainingMethod FineTuneListResponseDataTrainingMethodUnion `json:"training_method"` - // Type of training used (full or LoRA) - TrainingType FineTuneListResponseDataTrainingTypeUnion `json:"training_type"` - // Identifier for the user who created the job - UserID string `json:"user_id"` - // File-ID of the validation file - ValidationFile string `json:"validation_file"` - // Weights & Biases run name - WandbName string `json:"wandb_name"` - // Weights & Biases project name - WandbProjectName string `json:"wandb_project_name"` - // Ratio of warmup steps - WarmupRatio float64 `json:"warmup_ratio"` - // Weight decay value used - WeightDecay float64 `json:"weight_decay"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - CreatedAt respjson.Field - Status respjson.Field - UpdatedAt respjson.Field - BatchSize respjson.Field - Events respjson.Field - FromCheckpoint respjson.Field - FromHfModel respjson.Field - HfModelRevision respjson.Field - LearningRate respjson.Field - LrScheduler respjson.Field - MaxGradNorm respjson.Field - Model respjson.Field - ModelOutputName respjson.Field - NCheckpoints respjson.Field - NEpochs respjson.Field - NEvals respjson.Field - OwnerAddress respjson.Field - Suffix respjson.Field - TokenCount respjson.Field - TotalPrice respjson.Field - TrainingFile respjson.Field - TrainingMethod respjson.Field - TrainingType respjson.Field - UserID respjson.Field - ValidationFile respjson.Field - WandbName respjson.Field - WandbProjectName respjson.Field - WarmupRatio respjson.Field - WeightDecay respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneListResponseData) RawJSON() string { return r.JSON.raw } -func (r *FineTuneListResponseData) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// FineTuneListResponseDataTrainingMethodUnion contains all possible properties and -// values from [TrainingMethodSft], [TrainingMethodDpo]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneListResponseDataTrainingMethodUnion struct { - Method string `json:"method"` - // This field is from variant [TrainingMethodSft]. - TrainOnInputs TrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` - // This field is from variant [TrainingMethodDpo]. - DpoBeta float64 `json:"dpo_beta"` - // This field is from variant [TrainingMethodDpo]. - DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` - // This field is from variant [TrainingMethodDpo]. - DpoReferenceFree bool `json:"dpo_reference_free"` - // This field is from variant [TrainingMethodDpo]. - RpoAlpha float64 `json:"rpo_alpha"` - // This field is from variant [TrainingMethodDpo]. - SimpoGamma float64 `json:"simpo_gamma"` - JSON struct { - Method respjson.Field - TrainOnInputs respjson.Field - DpoBeta respjson.Field - DpoNormalizeLogratiosByLength respjson.Field - DpoReferenceFree respjson.Field - RpoAlpha respjson.Field - SimpoGamma respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneListResponseDataTrainingMethodUnion) AsTrainingMethodSft() (v TrainingMethodSft) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneListResponseDataTrainingMethodUnion) AsTrainingMethodDpo() (v TrainingMethodDpo) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneListResponseDataTrainingMethodUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneListResponseDataTrainingMethodUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// FineTuneListResponseDataTrainingTypeUnion contains all possible properties and -// values from [FullTrainingType], [LoRaTrainingType]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneListResponseDataTrainingTypeUnion struct { - Type string `json:"type"` - // This field is from variant [LoRaTrainingType]. - LoraAlpha int64 `json:"lora_alpha"` - // This field is from variant [LoRaTrainingType]. - LoraR int64 `json:"lora_r"` - // This field is from variant [LoRaTrainingType]. - LoraDropout float64 `json:"lora_dropout"` - // This field is from variant [LoRaTrainingType]. - LoraTrainableModules string `json:"lora_trainable_modules"` - JSON struct { - Type respjson.Field - LoraAlpha respjson.Field - LoraR respjson.Field - LoraDropout respjson.Field - LoraTrainableModules respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneListResponseDataTrainingTypeUnion) AsFullTrainingType() (v FullTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneListResponseDataTrainingTypeUnion) AsLoRaTrainingType() (v LoRaTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneListResponseDataTrainingTypeUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneListResponseDataTrainingTypeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// A truncated version of the fine-tune response, used for POST /fine-tunes, GET -// /fine-tunes and POST /fine-tunes/{id}/cancel endpoints -type FineTuneCancelResponse struct { - // Unique identifier for the fine-tune job - ID string `json:"id,required"` - // Creation timestamp of the fine-tune job - CreatedAt time.Time `json:"created_at,required" format:"date-time"` - // Any of "pending", "queued", "running", "compressing", "uploading", - // "cancel_requested", "cancelled", "error", "completed". - Status FineTuneCancelResponseStatus `json:"status,required"` - // Last update timestamp of the fine-tune job - UpdatedAt time.Time `json:"updated_at,required" format:"date-time"` - // Batch size used for training - BatchSize int64 `json:"batch_size"` - // Events related to this fine-tune job - Events []FineTuneEvent `json:"events"` - // Checkpoint used to continue training - FromCheckpoint string `json:"from_checkpoint"` - // Hugging Face Hub repo to start training from - FromHfModel string `json:"from_hf_model"` - // The revision of the Hugging Face Hub model to continue training from - HfModelRevision string `json:"hf_model_revision"` - // Learning rate used for training - LearningRate float64 `json:"learning_rate"` - // Learning rate scheduler configuration - LrScheduler LrScheduler `json:"lr_scheduler"` - // Maximum gradient norm for clipping - MaxGradNorm float64 `json:"max_grad_norm"` - // Base model used for fine-tuning - Model string `json:"model"` - ModelOutputName string `json:"model_output_name"` - // Number of checkpoints saved during training - NCheckpoints int64 `json:"n_checkpoints"` - // Number of training epochs - NEpochs int64 `json:"n_epochs"` - // Number of evaluations during training - NEvals int64 `json:"n_evals"` - // Owner address information - OwnerAddress string `json:"owner_address"` - // Suffix added to the fine-tuned model name - Suffix string `json:"suffix"` - // Count of tokens processed - TokenCount int64 `json:"token_count"` - // Total price for the fine-tuning job - TotalPrice int64 `json:"total_price"` - // File-ID of the training file - TrainingFile string `json:"training_file"` - // Method of training used - TrainingMethod FineTuneCancelResponseTrainingMethodUnion `json:"training_method"` - // Type of training used (full or LoRA) - TrainingType FineTuneCancelResponseTrainingTypeUnion `json:"training_type"` - // Identifier for the user who created the job - UserID string `json:"user_id"` - // File-ID of the validation file - ValidationFile string `json:"validation_file"` - // Weights & Biases run name - WandbName string `json:"wandb_name"` - // Weights & Biases project name - WandbProjectName string `json:"wandb_project_name"` - // Ratio of warmup steps - WarmupRatio float64 `json:"warmup_ratio"` - // Weight decay value used - WeightDecay float64 `json:"weight_decay"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - CreatedAt respjson.Field - Status respjson.Field - UpdatedAt respjson.Field - BatchSize respjson.Field - Events respjson.Field - FromCheckpoint respjson.Field - FromHfModel respjson.Field - HfModelRevision respjson.Field - LearningRate respjson.Field - LrScheduler respjson.Field - MaxGradNorm respjson.Field - Model respjson.Field - ModelOutputName respjson.Field - NCheckpoints respjson.Field - NEpochs respjson.Field - NEvals respjson.Field - OwnerAddress respjson.Field - Suffix respjson.Field - TokenCount respjson.Field - TotalPrice respjson.Field - TrainingFile respjson.Field - TrainingMethod respjson.Field - TrainingType respjson.Field - UserID respjson.Field - ValidationFile respjson.Field - WandbName respjson.Field - WandbProjectName respjson.Field - WarmupRatio respjson.Field - WeightDecay respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneCancelResponse) RawJSON() string { return r.JSON.raw } -func (r *FineTuneCancelResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneCancelResponseStatus string - -const ( - FineTuneCancelResponseStatusPending FineTuneCancelResponseStatus = "pending" - FineTuneCancelResponseStatusQueued FineTuneCancelResponseStatus = "queued" - FineTuneCancelResponseStatusRunning FineTuneCancelResponseStatus = "running" - FineTuneCancelResponseStatusCompressing FineTuneCancelResponseStatus = "compressing" - FineTuneCancelResponseStatusUploading FineTuneCancelResponseStatus = "uploading" - FineTuneCancelResponseStatusCancelRequested FineTuneCancelResponseStatus = "cancel_requested" - FineTuneCancelResponseStatusCancelled FineTuneCancelResponseStatus = "cancelled" - FineTuneCancelResponseStatusError FineTuneCancelResponseStatus = "error" - FineTuneCancelResponseStatusCompleted FineTuneCancelResponseStatus = "completed" -) - -// FineTuneCancelResponseTrainingMethodUnion contains all possible properties and -// values from [TrainingMethodSft], [TrainingMethodDpo]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneCancelResponseTrainingMethodUnion struct { - Method string `json:"method"` - // This field is from variant [TrainingMethodSft]. - TrainOnInputs TrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` - // This field is from variant [TrainingMethodDpo]. - DpoBeta float64 `json:"dpo_beta"` - // This field is from variant [TrainingMethodDpo]. - DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` - // This field is from variant [TrainingMethodDpo]. - DpoReferenceFree bool `json:"dpo_reference_free"` - // This field is from variant [TrainingMethodDpo]. - RpoAlpha float64 `json:"rpo_alpha"` - // This field is from variant [TrainingMethodDpo]. - SimpoGamma float64 `json:"simpo_gamma"` - JSON struct { - Method respjson.Field - TrainOnInputs respjson.Field - DpoBeta respjson.Field - DpoNormalizeLogratiosByLength respjson.Field - DpoReferenceFree respjson.Field - RpoAlpha respjson.Field - SimpoGamma respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneCancelResponseTrainingMethodUnion) AsTrainingMethodSft() (v TrainingMethodSft) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneCancelResponseTrainingMethodUnion) AsTrainingMethodDpo() (v TrainingMethodDpo) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneCancelResponseTrainingMethodUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneCancelResponseTrainingMethodUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// FineTuneCancelResponseTrainingTypeUnion contains all possible properties and -// values from [FullTrainingType], [LoRaTrainingType]. -// -// Use the methods beginning with 'As' to cast the union to one of its variants. -type FineTuneCancelResponseTrainingTypeUnion struct { - Type string `json:"type"` - // This field is from variant [LoRaTrainingType]. - LoraAlpha int64 `json:"lora_alpha"` - // This field is from variant [LoRaTrainingType]. - LoraR int64 `json:"lora_r"` - // This field is from variant [LoRaTrainingType]. - LoraDropout float64 `json:"lora_dropout"` - // This field is from variant [LoRaTrainingType]. - LoraTrainableModules string `json:"lora_trainable_modules"` - JSON struct { - Type respjson.Field - LoraAlpha respjson.Field - LoraR respjson.Field - LoraDropout respjson.Field - LoraTrainableModules respjson.Field - raw string - } `json:"-"` -} - -func (u FineTuneCancelResponseTrainingTypeUnion) AsFullTrainingType() (v FullTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -func (u FineTuneCancelResponseTrainingTypeUnion) AsLoRaTrainingType() (v LoRaTrainingType) { - apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) - return -} - -// Returns the unmodified JSON received from the API -func (u FineTuneCancelResponseTrainingTypeUnion) RawJSON() string { return u.JSON.raw } - -func (r *FineTuneCancelResponseTrainingTypeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneDownloadResponse struct { - ID string `json:"id"` - CheckpointStep int64 `json:"checkpoint_step"` - Filename string `json:"filename"` - // Any of "local". - Object FineTuneDownloadResponseObject `json:"object,nullable"` - Size int64 `json:"size"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - CheckpointStep respjson.Field - Filename respjson.Field - Object respjson.Field - Size respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneDownloadResponse) RawJSON() string { return r.JSON.raw } -func (r *FineTuneDownloadResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneDownloadResponseObject string - -const ( - FineTuneDownloadResponseObjectLocal FineTuneDownloadResponseObject = "local" -) - -type FineTuneListEventsResponse struct { - Data []FineTuneEvent `json:"data,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Data respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneListEventsResponse) RawJSON() string { return r.JSON.raw } -func (r *FineTuneListEventsResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneGetCheckpointsResponse struct { - Data []FineTuneGetCheckpointsResponseData `json:"data,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - Data respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneGetCheckpointsResponse) RawJSON() string { return r.JSON.raw } -func (r *FineTuneGetCheckpointsResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneGetCheckpointsResponseData struct { - CheckpointType string `json:"checkpoint_type,required"` - CreatedAt string `json:"created_at,required"` - Path string `json:"path,required"` - Step int64 `json:"step,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - CheckpointType respjson.Field - CreatedAt respjson.Field - Path respjson.Field - Step respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r FineTuneGetCheckpointsResponseData) RawJSON() string { return r.JSON.raw } -func (r *FineTuneGetCheckpointsResponseData) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -type FineTuneNewParams struct { - // Name of the base model to run fine-tune job on - Model string `json:"model,required"` - // File-ID of a training file uploaded to the Together API - TrainingFile string `json:"training_file,required"` - // The checkpoint identifier to continue training from a previous fine-tuning job. - // Format is `{$JOB_ID}` or `{$OUTPUT_MODEL_NAME}` or `{$JOB_ID}:{$STEP}` or - // `{$OUTPUT_MODEL_NAME}:{$STEP}`. The step value is optional; without it, the - // final checkpoint will be used. - FromCheckpoint param.Opt[string] `json:"from_checkpoint,omitzero"` - // The Hugging Face Hub repo to start training from. Should be as close as possible - // to the base model (specified by the `model` argument) in terms of architecture - // and size. - FromHfModel param.Opt[string] `json:"from_hf_model,omitzero"` - // The API token for the Hugging Face Hub. - HfAPIToken param.Opt[string] `json:"hf_api_token,omitzero"` - // The revision of the Hugging Face Hub model to continue training from. E.g., - // hf_model_revision=main (default, used if the argument is not provided) or - // hf_model_revision='607a30d783dfa663caf39e06633721c8d4cfcd7e' (specific commit). - HfModelRevision param.Opt[string] `json:"hf_model_revision,omitzero"` - // The name of the Hugging Face repository to upload the fine-tuned model to. - HfOutputRepoName param.Opt[string] `json:"hf_output_repo_name,omitzero"` - // Controls how quickly the model adapts to new information (too high may cause - // instability, too low may slow convergence) - LearningRate param.Opt[float64] `json:"learning_rate,omitzero"` - // Max gradient norm to be used for gradient clipping. Set to 0 to disable. - MaxGradNorm param.Opt[float64] `json:"max_grad_norm,omitzero"` - // Number of intermediate model versions saved during training for evaluation - NCheckpoints param.Opt[int64] `json:"n_checkpoints,omitzero"` - // Number of complete passes through the training dataset (higher values may - // improve results but increase cost and risk of overfitting) - NEpochs param.Opt[int64] `json:"n_epochs,omitzero"` - // Number of evaluations to be run on a given validation set during training - NEvals param.Opt[int64] `json:"n_evals,omitzero"` - // Suffix that will be added to your fine-tuned model name - Suffix param.Opt[string] `json:"suffix,omitzero"` - // File-ID of a validation file uploaded to the Together API - ValidationFile param.Opt[string] `json:"validation_file,omitzero"` - // Integration key for tracking experiments and model metrics on W&B platform - WandbAPIKey param.Opt[string] `json:"wandb_api_key,omitzero"` - // The base URL of a dedicated Weights & Biases instance. - WandbBaseURL param.Opt[string] `json:"wandb_base_url,omitzero"` - // The Weights & Biases name for your run. - WandbName param.Opt[string] `json:"wandb_name,omitzero"` - // The Weights & Biases project for your run. If not specified, will use `together` - // as the project name. - WandbProjectName param.Opt[string] `json:"wandb_project_name,omitzero"` - // The percent of steps at the start of training to linearly increase the learning - // rate. - WarmupRatio param.Opt[float64] `json:"warmup_ratio,omitzero"` - // Weight decay. Regularization parameter for the optimizer. - WeightDecay param.Opt[float64] `json:"weight_decay,omitzero"` - // Number of training examples processed together (larger batches use more memory - // but may train faster). Defaults to "max". We use training optimizations like - // packing, so the effective batch size may be different than the value you set. - BatchSize FineTuneNewParamsBatchSizeUnion `json:"batch_size,omitzero"` - // The learning rate scheduler to use. It specifies how the learning rate is - // adjusted during training. - LrScheduler LrSchedulerParam `json:"lr_scheduler,omitzero"` - // Whether to mask the user messages in conversational data or prompts in - // instruction data. - TrainOnInputs FineTuneNewParamsTrainOnInputsUnion `json:"train_on_inputs,omitzero"` - // The training method to use. 'sft' for Supervised Fine-Tuning or 'dpo' for Direct - // Preference Optimization. - TrainingMethod FineTuneNewParamsTrainingMethodUnion `json:"training_method,omitzero"` - TrainingType FineTuneNewParamsTrainingTypeUnion `json:"training_type,omitzero"` - paramObj -} - -func (r FineTuneNewParams) MarshalJSON() (data []byte, err error) { - type shadow FineTuneNewParams - return param.MarshalObject(r, (*shadow)(&r)) -} -func (r *FineTuneNewParams) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - -// Only one field can be non-zero. -// -// Use [param.IsOmitted] to confirm if a field is set. -type FineTuneNewParamsBatchSizeUnion struct { - OfInt param.Opt[int64] `json:",omitzero,inline"` - // Check if union is this variant with - // !param.IsOmitted(union.OfFineTuneNewsBatchSizeString) - OfFineTuneNewsBatchSizeString param.Opt[string] `json:",omitzero,inline"` - paramUnion -} - -func (u FineTuneNewParamsBatchSizeUnion) MarshalJSON() ([]byte, error) { - return param.MarshalUnion(u, u.OfInt, u.OfFineTuneNewsBatchSizeString) -} -func (u *FineTuneNewParamsBatchSizeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, u) -} - -func (u *FineTuneNewParamsBatchSizeUnion) asAny() any { - if !param.IsOmitted(u.OfInt) { - return &u.OfInt.Value - } else if !param.IsOmitted(u.OfFineTuneNewsBatchSizeString) { - return &u.OfFineTuneNewsBatchSizeString - } - return nil -} - -type FineTuneNewParamsBatchSizeString string - -const ( - FineTuneNewParamsBatchSizeStringMax FineTuneNewParamsBatchSizeString = "max" -) - -// Only one field can be non-zero. -// -// Use [param.IsOmitted] to confirm if a field is set. -type FineTuneNewParamsTrainOnInputsUnion struct { - OfBool param.Opt[bool] `json:",omitzero,inline"` - // Check if union is this variant with - // !param.IsOmitted(union.OfFineTuneNewsTrainOnInputsString) - OfFineTuneNewsTrainOnInputsString param.Opt[string] `json:",omitzero,inline"` - paramUnion -} - -func (u FineTuneNewParamsTrainOnInputsUnion) MarshalJSON() ([]byte, error) { - return param.MarshalUnion(u, u.OfBool, u.OfFineTuneNewsTrainOnInputsString) -} -func (u *FineTuneNewParamsTrainOnInputsUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, u) -} - -func (u *FineTuneNewParamsTrainOnInputsUnion) asAny() any { - if !param.IsOmitted(u.OfBool) { - return &u.OfBool.Value - } else if !param.IsOmitted(u.OfFineTuneNewsTrainOnInputsString) { - return &u.OfFineTuneNewsTrainOnInputsString - } - return nil -} - -type FineTuneNewParamsTrainOnInputsString string - -const ( - FineTuneNewParamsTrainOnInputsStringAuto FineTuneNewParamsTrainOnInputsString = "auto" -) - -// Only one field can be non-zero. -// -// Use [param.IsOmitted] to confirm if a field is set. -type FineTuneNewParamsTrainingMethodUnion struct { - OfTrainingMethodSft *TrainingMethodSftParam `json:",omitzero,inline"` - OfTrainingMethodDpo *TrainingMethodDpoParam `json:",omitzero,inline"` - paramUnion -} - -func (u FineTuneNewParamsTrainingMethodUnion) MarshalJSON() ([]byte, error) { - return param.MarshalUnion(u, u.OfTrainingMethodSft, u.OfTrainingMethodDpo) -} -func (u *FineTuneNewParamsTrainingMethodUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, u) -} - -func (u *FineTuneNewParamsTrainingMethodUnion) asAny() any { - if !param.IsOmitted(u.OfTrainingMethodSft) { - return u.OfTrainingMethodSft - } else if !param.IsOmitted(u.OfTrainingMethodDpo) { - return u.OfTrainingMethodDpo - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetTrainOnInputs() *TrainingMethodSftTrainOnInputsUnionParam { - if vt := u.OfTrainingMethodSft; vt != nil { - return &vt.TrainOnInputs - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetDpoBeta() *float64 { - if vt := u.OfTrainingMethodDpo; vt != nil && vt.DpoBeta.Valid() { - return &vt.DpoBeta.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetDpoNormalizeLogratiosByLength() *bool { - if vt := u.OfTrainingMethodDpo; vt != nil && vt.DpoNormalizeLogratiosByLength.Valid() { - return &vt.DpoNormalizeLogratiosByLength.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetDpoReferenceFree() *bool { - if vt := u.OfTrainingMethodDpo; vt != nil && vt.DpoReferenceFree.Valid() { - return &vt.DpoReferenceFree.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetRpoAlpha() *float64 { - if vt := u.OfTrainingMethodDpo; vt != nil && vt.RpoAlpha.Valid() { - return &vt.RpoAlpha.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetSimpoGamma() *float64 { - if vt := u.OfTrainingMethodDpo; vt != nil && vt.SimpoGamma.Valid() { - return &vt.SimpoGamma.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingMethodUnion) GetMethod() *string { - if vt := u.OfTrainingMethodSft; vt != nil { - return (*string)(&vt.Method) - } else if vt := u.OfTrainingMethodDpo; vt != nil { - return (*string)(&vt.Method) - } - return nil -} - -// Only one field can be non-zero. -// -// Use [param.IsOmitted] to confirm if a field is set. -type FineTuneNewParamsTrainingTypeUnion struct { - OfFullTrainingType *FullTrainingTypeParam `json:",omitzero,inline"` - OfLoRaTrainingType *LoRaTrainingTypeParam `json:",omitzero,inline"` - paramUnion -} - -func (u FineTuneNewParamsTrainingTypeUnion) MarshalJSON() ([]byte, error) { - return param.MarshalUnion(u, u.OfFullTrainingType, u.OfLoRaTrainingType) -} -func (u *FineTuneNewParamsTrainingTypeUnion) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, u) -} - -func (u *FineTuneNewParamsTrainingTypeUnion) asAny() any { - if !param.IsOmitted(u.OfFullTrainingType) { - return u.OfFullTrainingType - } else if !param.IsOmitted(u.OfLoRaTrainingType) { - return u.OfLoRaTrainingType - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingTypeUnion) GetLoraAlpha() *int64 { - if vt := u.OfLoRaTrainingType; vt != nil { - return &vt.LoraAlpha - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingTypeUnion) GetLoraR() *int64 { - if vt := u.OfLoRaTrainingType; vt != nil { - return &vt.LoraR - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingTypeUnion) GetLoraDropout() *float64 { - if vt := u.OfLoRaTrainingType; vt != nil && vt.LoraDropout.Valid() { - return &vt.LoraDropout.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingTypeUnion) GetLoraTrainableModules() *string { - if vt := u.OfLoRaTrainingType; vt != nil && vt.LoraTrainableModules.Valid() { - return &vt.LoraTrainableModules.Value - } - return nil -} - -// Returns a pointer to the underlying variant's property, if present. -func (u FineTuneNewParamsTrainingTypeUnion) GetType() *string { - if vt := u.OfFullTrainingType; vt != nil { - return (*string)(&vt.Type) - } else if vt := u.OfLoRaTrainingType; vt != nil { - return (*string)(&vt.Type) - } - return nil -} - -type FineTuneDownloadParams struct { - // Fine-tune ID to download. A string that starts with `ft-`. - FtID string `query:"ft_id,required" json:"-"` - // Specifies step number for checkpoint to download. Ignores `checkpoint` value if - // set. - CheckpointStep param.Opt[int64] `query:"checkpoint_step,omitzero" json:"-"` - // Specifies output file name for downloaded model. Defaults to - // `$PWD/{model_name}.{extension}`. - Output param.Opt[string] `query:"output,omitzero" json:"-"` - // Specifies checkpoint type to download - `merged` vs `adapter`. This field is - // required if the checkpoint_step is not set. - // - // Any of "merged", "adapter". - Checkpoint FineTuneDownloadParamsCheckpoint `query:"checkpoint,omitzero" json:"-"` - paramObj -} - -// URLQuery serializes [FineTuneDownloadParams]'s query parameters as `url.Values`. -func (r FineTuneDownloadParams) URLQuery() (v url.Values, err error) { - return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ - ArrayFormat: apiquery.ArrayQueryFormatComma, - NestedFormat: apiquery.NestedQueryFormatBrackets, - }) -} - -// Specifies checkpoint type to download - `merged` vs `adapter`. This field is -// required if the checkpoint_step is not set. -type FineTuneDownloadParamsCheckpoint string - -const ( - FineTuneDownloadParamsCheckpointMerged FineTuneDownloadParamsCheckpoint = "merged" - FineTuneDownloadParamsCheckpointAdapter FineTuneDownloadParamsCheckpoint = "adapter" -) diff --git a/finetuning.go b/finetuning.go new file mode 100644 index 00000000..ba3cfaa7 --- /dev/null +++ b/finetuning.go @@ -0,0 +1,2728 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package together + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "time" + + "github.com/togethercomputer/together-go/internal/apijson" + "github.com/togethercomputer/together-go/internal/apiquery" + "github.com/togethercomputer/together-go/internal/requestconfig" + "github.com/togethercomputer/together-go/option" + "github.com/togethercomputer/together-go/packages/param" + "github.com/togethercomputer/together-go/packages/respjson" +) + +// FineTuningService contains methods and other services that help with interacting +// with the together API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFineTuningService] method instead. +type FineTuningService struct { + Options []option.RequestOption +} + +// NewFineTuningService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewFineTuningService(opts ...option.RequestOption) (r FineTuningService) { + r = FineTuningService{} + r.Options = opts + return +} + +// Create a fine-tuning job with the provided model and training data. +func (r *FineTuningService) New(ctx context.Context, body FineTuningNewParams, opts ...option.RequestOption) (res *FineTuningNewResponse, err error) { + opts = slices.Concat(r.Options, opts) + path := "fine-tunes" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// List the metadata for a single fine-tuning job. +func (r *FineTuningService) Get(ctx context.Context, id string, opts ...option.RequestOption) (res *FinetuneResponse, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("fine-tunes/%s", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// List the metadata for all fine-tuning jobs. Returns a list of +// FinetuneResponseTruncated objects. +func (r *FineTuningService) List(ctx context.Context, opts ...option.RequestOption) (res *FineTuningListResponse, err error) { + opts = slices.Concat(r.Options, opts) + path := "fine-tunes" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// Delete a fine-tuning job. +func (r *FineTuningService) Delete(ctx context.Context, id string, body FineTuningDeleteParams, opts ...option.RequestOption) (res *FineTuningDeleteResponse, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("fine-tunes/%s", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, body, &res, opts...) + return +} + +// Cancel a currently running fine-tuning job. Returns a FinetuneResponseTruncated +// object. +func (r *FineTuningService) Cancel(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTuningCancelResponse, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("fine-tunes/%s/cancel", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...) + return +} + +// Receive a compressed fine-tuned model or checkpoint. +func (r *FineTuningService) Content(ctx context.Context, query FineTuningContentParams, opts ...option.RequestOption) (res *http.Response, err error) { + opts = slices.Concat(r.Options, opts) + opts = append([]option.RequestOption{option.WithHeader("Accept", "application/octet-stream")}, opts...) + path := "finetune/download" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) + return +} + +// List the checkpoints for a single fine-tuning job. +func (r *FineTuningService) ListCheckpoints(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTuningListCheckpointsResponse, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("fine-tunes/%s/checkpoints", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +// List the events for a single fine-tuning job. +func (r *FineTuningService) ListEvents(ctx context.Context, id string, opts ...option.RequestOption) (res *FineTuningListEventsResponse, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("fine-tunes/%s/events", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) + return +} + +type FinetuneEvent struct { + CheckpointPath string `json:"checkpoint_path,required"` + CreatedAt string `json:"created_at,required"` + Hash string `json:"hash,required"` + Message string `json:"message,required"` + ModelPath string `json:"model_path,required"` + // Any of "fine-tune-event". + Object FinetuneEventObject `json:"object,required"` + ParamCount int64 `json:"param_count,required"` + Step int64 `json:"step,required"` + TokenCount int64 `json:"token_count,required"` + TotalSteps int64 `json:"total_steps,required"` + TrainingOffset int64 `json:"training_offset,required"` + // Any of "job_pending", "job_start", "job_stopped", "model_downloading", + // "model_download_complete", "training_data_downloading", + // "training_data_download_complete", "validation_data_downloading", + // "validation_data_download_complete", "wandb_init", "training_start", + // "checkpoint_save", "billing_limit", "epoch_complete", "training_complete", + // "model_compressing", "model_compression_complete", "model_uploading", + // "model_upload_complete", "job_complete", "job_error", "cancel_requested", + // "job_restarted", "refund", "warning". + Type FinetuneEventType `json:"type,required"` + WandbURL string `json:"wandb_url,required"` + // Any of "info", "warning", "error", "legacy_info", "legacy_iwarning", + // "legacy_ierror". + Level FinetuneEventLevel `json:"level,nullable"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + CheckpointPath respjson.Field + CreatedAt respjson.Field + Hash respjson.Field + Message respjson.Field + ModelPath respjson.Field + Object respjson.Field + ParamCount respjson.Field + Step respjson.Field + TokenCount respjson.Field + TotalSteps respjson.Field + TrainingOffset respjson.Field + Type respjson.Field + WandbURL respjson.Field + Level respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneEvent) RawJSON() string { return r.JSON.raw } +func (r *FinetuneEvent) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneEventObject string + +const ( + FinetuneEventObjectFineTuneEvent FinetuneEventObject = "fine-tune-event" +) + +type FinetuneEventLevel string + +const ( + FinetuneEventLevelInfo FinetuneEventLevel = "info" + FinetuneEventLevelWarning FinetuneEventLevel = "warning" + FinetuneEventLevelError FinetuneEventLevel = "error" + FinetuneEventLevelLegacyInfo FinetuneEventLevel = "legacy_info" + FinetuneEventLevelLegacyIwarning FinetuneEventLevel = "legacy_iwarning" + FinetuneEventLevelLegacyIerror FinetuneEventLevel = "legacy_ierror" +) + +type FinetuneEventType string + +const ( + FinetuneEventTypeJobPending FinetuneEventType = "job_pending" + FinetuneEventTypeJobStart FinetuneEventType = "job_start" + FinetuneEventTypeJobStopped FinetuneEventType = "job_stopped" + FinetuneEventTypeModelDownloading FinetuneEventType = "model_downloading" + FinetuneEventTypeModelDownloadComplete FinetuneEventType = "model_download_complete" + FinetuneEventTypeTrainingDataDownloading FinetuneEventType = "training_data_downloading" + FinetuneEventTypeTrainingDataDownloadComplete FinetuneEventType = "training_data_download_complete" + FinetuneEventTypeValidationDataDownloading FinetuneEventType = "validation_data_downloading" + FinetuneEventTypeValidationDataDownloadComplete FinetuneEventType = "validation_data_download_complete" + FinetuneEventTypeWandbInit FinetuneEventType = "wandb_init" + FinetuneEventTypeTrainingStart FinetuneEventType = "training_start" + FinetuneEventTypeCheckpointSave FinetuneEventType = "checkpoint_save" + FinetuneEventTypeBillingLimit FinetuneEventType = "billing_limit" + FinetuneEventTypeEpochComplete FinetuneEventType = "epoch_complete" + FinetuneEventTypeTrainingComplete FinetuneEventType = "training_complete" + FinetuneEventTypeModelCompressing FinetuneEventType = "model_compressing" + FinetuneEventTypeModelCompressionComplete FinetuneEventType = "model_compression_complete" + FinetuneEventTypeModelUploading FinetuneEventType = "model_uploading" + FinetuneEventTypeModelUploadComplete FinetuneEventType = "model_upload_complete" + FinetuneEventTypeJobComplete FinetuneEventType = "job_complete" + FinetuneEventTypeJobError FinetuneEventType = "job_error" + FinetuneEventTypeCancelRequested FinetuneEventType = "cancel_requested" + FinetuneEventTypeJobRestarted FinetuneEventType = "job_restarted" + FinetuneEventTypeRefund FinetuneEventType = "refund" + FinetuneEventTypeWarning FinetuneEventType = "warning" +) + +type FinetuneResponse struct { + ID string `json:"id,required" format:"uuid"` + // Any of "pending", "queued", "running", "compressing", "uploading", + // "cancel_requested", "cancelled", "error", "completed". + Status FinetuneResponseStatus `json:"status,required"` + BatchSize FinetuneResponseBatchSizeUnion `json:"batch_size"` + CreatedAt string `json:"created_at"` + EpochsCompleted int64 `json:"epochs_completed"` + EvalSteps int64 `json:"eval_steps"` + Events []FinetuneEvent `json:"events"` + FromCheckpoint string `json:"from_checkpoint"` + FromHfModel string `json:"from_hf_model"` + HfModelRevision string `json:"hf_model_revision"` + JobID string `json:"job_id"` + LearningRate float64 `json:"learning_rate"` + LrScheduler FinetuneResponseLrScheduler `json:"lr_scheduler"` + MaxGradNorm float64 `json:"max_grad_norm"` + Model string `json:"model"` + ModelOutputName string `json:"model_output_name"` + ModelOutputPath string `json:"model_output_path"` + NCheckpoints int64 `json:"n_checkpoints"` + NEpochs int64 `json:"n_epochs"` + NEvals int64 `json:"n_evals"` + ParamCount int64 `json:"param_count"` + QueueDepth int64 `json:"queue_depth"` + TokenCount int64 `json:"token_count"` + TotalPrice int64 `json:"total_price"` + TrainOnInputs FinetuneResponseTrainOnInputsUnion `json:"train_on_inputs"` + TrainingFile string `json:"training_file"` + TrainingMethod FinetuneResponseTrainingMethodUnion `json:"training_method"` + TrainingType FinetuneResponseTrainingTypeUnion `json:"training_type"` + TrainingfileNumlines int64 `json:"trainingfile_numlines"` + TrainingfileSize int64 `json:"trainingfile_size"` + UpdatedAt string `json:"updated_at"` + ValidationFile string `json:"validation_file"` + WandbProjectName string `json:"wandb_project_name"` + WandbURL string `json:"wandb_url"` + WarmupRatio float64 `json:"warmup_ratio"` + WeightDecay float64 `json:"weight_decay"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + Status respjson.Field + BatchSize respjson.Field + CreatedAt respjson.Field + EpochsCompleted respjson.Field + EvalSteps respjson.Field + Events respjson.Field + FromCheckpoint respjson.Field + FromHfModel respjson.Field + HfModelRevision respjson.Field + JobID respjson.Field + LearningRate respjson.Field + LrScheduler respjson.Field + MaxGradNorm respjson.Field + Model respjson.Field + ModelOutputName respjson.Field + ModelOutputPath respjson.Field + NCheckpoints respjson.Field + NEpochs respjson.Field + NEvals respjson.Field + ParamCount respjson.Field + QueueDepth respjson.Field + TokenCount respjson.Field + TotalPrice respjson.Field + TrainOnInputs respjson.Field + TrainingFile respjson.Field + TrainingMethod respjson.Field + TrainingType respjson.Field + TrainingfileNumlines respjson.Field + TrainingfileSize respjson.Field + UpdatedAt respjson.Field + ValidationFile respjson.Field + WandbProjectName respjson.Field + WandbURL respjson.Field + WarmupRatio respjson.Field + WeightDecay respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponse) RawJSON() string { return r.JSON.raw } +func (r *FinetuneResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseStatus string + +const ( + FinetuneResponseStatusPending FinetuneResponseStatus = "pending" + FinetuneResponseStatusQueued FinetuneResponseStatus = "queued" + FinetuneResponseStatusRunning FinetuneResponseStatus = "running" + FinetuneResponseStatusCompressing FinetuneResponseStatus = "compressing" + FinetuneResponseStatusUploading FinetuneResponseStatus = "uploading" + FinetuneResponseStatusCancelRequested FinetuneResponseStatus = "cancel_requested" + FinetuneResponseStatusCancelled FinetuneResponseStatus = "cancelled" + FinetuneResponseStatusError FinetuneResponseStatus = "error" + FinetuneResponseStatusCompleted FinetuneResponseStatus = "completed" +) + +// FinetuneResponseBatchSizeUnion contains all possible properties and values from +// [int64], [string]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfInt OfFinetuneResponseBatchSizeString] +type FinetuneResponseBatchSizeUnion struct { + // This field will be present if the value is a [int64] instead of an object. + OfInt int64 `json:",inline"` + // This field will be present if the value is a [string] instead of an object. + OfFinetuneResponseBatchSizeString string `json:",inline"` + JSON struct { + OfInt respjson.Field + OfFinetuneResponseBatchSizeString respjson.Field + raw string + } `json:"-"` +} + +func (u FinetuneResponseBatchSizeUnion) AsInt() (v int64) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FinetuneResponseBatchSizeUnion) AsFinetuneResponseBatchSizeString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FinetuneResponseBatchSizeUnion) RawJSON() string { return u.JSON.raw } + +func (r *FinetuneResponseBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseBatchSizeString string + +const ( + FinetuneResponseBatchSizeStringMax FinetuneResponseBatchSizeString = "max" +) + +type FinetuneResponseLrScheduler struct { + // Any of "linear", "cosine". + LrSchedulerType string `json:"lr_scheduler_type,required"` + LrSchedulerArgs FinetuneResponseLrSchedulerLrSchedulerArgsUnion `json:"lr_scheduler_args"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LrSchedulerType respjson.Field + LrSchedulerArgs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseLrScheduler) RawJSON() string { return r.JSON.raw } +func (r *FinetuneResponseLrScheduler) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FinetuneResponseLrSchedulerLrSchedulerArgsUnion contains all possible properties +// and values from +// [FinetuneResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs], +// [FinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FinetuneResponseLrSchedulerLrSchedulerArgsUnion struct { + MinLrRatio float64 `json:"min_lr_ratio"` + // This field is from variant + // [FinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. + NumCycles float64 `json:"num_cycles"` + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + raw string + } `json:"-"` +} + +func (u FinetuneResponseLrSchedulerLrSchedulerArgsUnion) AsFinetuneResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs() (v FinetuneResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FinetuneResponseLrSchedulerLrSchedulerArgsUnion) AsFinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs() (v FinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FinetuneResponseLrSchedulerLrSchedulerArgsUnion) RawJSON() string { return u.JSON.raw } + +func (r *FinetuneResponseLrSchedulerLrSchedulerArgsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FinetuneResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio,required"` + // Number or fraction of cycles for the cosine learning rate scheduler + NumCycles float64 `json:"num_cycles,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FinetuneResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FinetuneResponseTrainOnInputsUnion contains all possible properties and values +// from [bool], [string]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfBool OfFinetuneResponseTrainOnInputsString] +type FinetuneResponseTrainOnInputsUnion struct { + // This field will be present if the value is a [bool] instead of an object. + OfBool bool `json:",inline"` + // This field will be present if the value is a [string] instead of an object. + OfFinetuneResponseTrainOnInputsString string `json:",inline"` + JSON struct { + OfBool respjson.Field + OfFinetuneResponseTrainOnInputsString respjson.Field + raw string + } `json:"-"` +} + +func (u FinetuneResponseTrainOnInputsUnion) AsBool() (v bool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FinetuneResponseTrainOnInputsUnion) AsFinetuneResponseTrainOnInputsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FinetuneResponseTrainOnInputsUnion) RawJSON() string { return u.JSON.raw } + +func (r *FinetuneResponseTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseTrainOnInputsString string + +const ( + FinetuneResponseTrainOnInputsStringAuto FinetuneResponseTrainOnInputsString = "auto" +) + +// FinetuneResponseTrainingMethodUnion contains all possible properties and values +// from [FinetuneResponseTrainingMethodTrainingMethodSft], +// [FinetuneResponseTrainingMethodTrainingMethodDpo]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FinetuneResponseTrainingMethodUnion struct { + Method string `json:"method"` + // This field is from variant [FinetuneResponseTrainingMethodTrainingMethodSft]. + TrainOnInputs FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` + // This field is from variant [FinetuneResponseTrainingMethodTrainingMethodDpo]. + DpoBeta float64 `json:"dpo_beta"` + // This field is from variant [FinetuneResponseTrainingMethodTrainingMethodDpo]. + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + // This field is from variant [FinetuneResponseTrainingMethodTrainingMethodDpo]. + DpoReferenceFree bool `json:"dpo_reference_free"` + // This field is from variant [FinetuneResponseTrainingMethodTrainingMethodDpo]. + RpoAlpha float64 `json:"rpo_alpha"` + // This field is from variant [FinetuneResponseTrainingMethodTrainingMethodDpo]. + SimpoGamma float64 `json:"simpo_gamma"` + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + raw string + } `json:"-"` +} + +func (u FinetuneResponseTrainingMethodUnion) AsFinetuneResponseTrainingMethodTrainingMethodSft() (v FinetuneResponseTrainingMethodTrainingMethodSft) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FinetuneResponseTrainingMethodUnion) AsFinetuneResponseTrainingMethodTrainingMethodDpo() (v FinetuneResponseTrainingMethodTrainingMethodDpo) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FinetuneResponseTrainingMethodUnion) RawJSON() string { return u.JSON.raw } + +func (r *FinetuneResponseTrainingMethodUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseTrainingMethodTrainingMethodSft struct { + // Any of "sft". + Method string `json:"method,required"` + // Whether to mask the user messages in conversational data or prompts in + // instruction data. + TrainOnInputs FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseTrainingMethodTrainingMethodSft) RawJSON() string { return r.JSON.raw } +func (r *FinetuneResponseTrainingMethodTrainingMethodSft) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion contains all +// possible properties and values from [bool], [string]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfBool +// OfFinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsString] +type FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion struct { + // This field will be present if the value is a [bool] instead of an object. + OfBool bool `json:",inline"` + // This field will be present if the value is a [string] instead of an object. + OfFinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsString string `json:",inline"` + JSON struct { + OfBool respjson.Field + OfFinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsString respjson.Field + raw string + } `json:"-"` +} + +func (u FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsBool() (v bool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsFinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) RawJSON() string { + return u.JSON.raw +} + +func (r *FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsString string + +const ( + FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsStringAuto FinetuneResponseTrainingMethodTrainingMethodSftTrainOnInputsString = "auto" +) + +type FinetuneResponseTrainingMethodTrainingMethodDpo struct { + // Any of "dpo". + Method string `json:"method,required"` + DpoBeta float64 `json:"dpo_beta"` + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + DpoReferenceFree bool `json:"dpo_reference_free"` + RpoAlpha float64 `json:"rpo_alpha"` + SimpoGamma float64 `json:"simpo_gamma"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseTrainingMethodTrainingMethodDpo) RawJSON() string { return r.JSON.raw } +func (r *FinetuneResponseTrainingMethodTrainingMethodDpo) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FinetuneResponseTrainingTypeUnion contains all possible properties and values +// from [FinetuneResponseTrainingTypeFullTrainingType], +// [FinetuneResponseTrainingTypeLoRaTrainingType]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FinetuneResponseTrainingTypeUnion struct { + Type string `json:"type"` + // This field is from variant [FinetuneResponseTrainingTypeLoRaTrainingType]. + LoraAlpha int64 `json:"lora_alpha"` + // This field is from variant [FinetuneResponseTrainingTypeLoRaTrainingType]. + LoraR int64 `json:"lora_r"` + // This field is from variant [FinetuneResponseTrainingTypeLoRaTrainingType]. + LoraDropout float64 `json:"lora_dropout"` + // This field is from variant [FinetuneResponseTrainingTypeLoRaTrainingType]. + LoraTrainableModules string `json:"lora_trainable_modules"` + JSON struct { + Type respjson.Field + LoraAlpha respjson.Field + LoraR respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + raw string + } `json:"-"` +} + +func (u FinetuneResponseTrainingTypeUnion) AsFinetuneResponseTrainingTypeFullTrainingType() (v FinetuneResponseTrainingTypeFullTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FinetuneResponseTrainingTypeUnion) AsFinetuneResponseTrainingTypeLoRaTrainingType() (v FinetuneResponseTrainingTypeLoRaTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FinetuneResponseTrainingTypeUnion) RawJSON() string { return u.JSON.raw } + +func (r *FinetuneResponseTrainingTypeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseTrainingTypeFullTrainingType struct { + // Any of "Full". + Type string `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseTrainingTypeFullTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FinetuneResponseTrainingTypeFullTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FinetuneResponseTrainingTypeLoRaTrainingType struct { + LoraAlpha int64 `json:"lora_alpha,required"` + LoraR int64 `json:"lora_r,required"` + // Any of "Lora". + Type string `json:"type,required"` + LoraDropout float64 `json:"lora_dropout"` + LoraTrainableModules string `json:"lora_trainable_modules"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LoraAlpha respjson.Field + LoraR respjson.Field + Type respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FinetuneResponseTrainingTypeLoRaTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FinetuneResponseTrainingTypeLoRaTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A truncated version of the fine-tune response, used for POST /fine-tunes, GET +// /fine-tunes and POST /fine-tunes/{id}/cancel endpoints +type FineTuningNewResponse struct { + // Unique identifier for the fine-tune job + ID string `json:"id,required"` + // Creation timestamp of the fine-tune job + CreatedAt time.Time `json:"created_at,required" format:"date-time"` + // Any of "pending", "queued", "running", "compressing", "uploading", + // "cancel_requested", "cancelled", "error", "completed". + Status FineTuningNewResponseStatus `json:"status,required"` + // Last update timestamp of the fine-tune job + UpdatedAt time.Time `json:"updated_at,required" format:"date-time"` + // Batch size used for training + BatchSize int64 `json:"batch_size"` + // Events related to this fine-tune job + Events []FinetuneEvent `json:"events"` + // Checkpoint used to continue training + FromCheckpoint string `json:"from_checkpoint"` + // Hugging Face Hub repo to start training from + FromHfModel string `json:"from_hf_model"` + // The revision of the Hugging Face Hub model to continue training from + HfModelRevision string `json:"hf_model_revision"` + // Learning rate used for training + LearningRate float64 `json:"learning_rate"` + // Learning rate scheduler configuration + LrScheduler FineTuningNewResponseLrScheduler `json:"lr_scheduler"` + // Maximum gradient norm for clipping + MaxGradNorm float64 `json:"max_grad_norm"` + // Base model used for fine-tuning + Model string `json:"model"` + ModelOutputName string `json:"model_output_name"` + // Number of checkpoints saved during training + NCheckpoints int64 `json:"n_checkpoints"` + // Number of training epochs + NEpochs int64 `json:"n_epochs"` + // Number of evaluations during training + NEvals int64 `json:"n_evals"` + // Owner address information + OwnerAddress string `json:"owner_address"` + // Suffix added to the fine-tuned model name + Suffix string `json:"suffix"` + // Count of tokens processed + TokenCount int64 `json:"token_count"` + // Total price for the fine-tuning job + TotalPrice int64 `json:"total_price"` + // File-ID of the training file + TrainingFile string `json:"training_file"` + // Method of training used + TrainingMethod FineTuningNewResponseTrainingMethodUnion `json:"training_method"` + // Type of training used (full or LoRA) + TrainingType FineTuningNewResponseTrainingTypeUnion `json:"training_type"` + // Identifier for the user who created the job + UserID string `json:"user_id"` + // File-ID of the validation file + ValidationFile string `json:"validation_file"` + // Weights & Biases run name + WandbName string `json:"wandb_name"` + // Weights & Biases project name + WandbProjectName string `json:"wandb_project_name"` + // Ratio of warmup steps + WarmupRatio float64 `json:"warmup_ratio"` + // Weight decay value used + WeightDecay float64 `json:"weight_decay"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Status respjson.Field + UpdatedAt respjson.Field + BatchSize respjson.Field + Events respjson.Field + FromCheckpoint respjson.Field + FromHfModel respjson.Field + HfModelRevision respjson.Field + LearningRate respjson.Field + LrScheduler respjson.Field + MaxGradNorm respjson.Field + Model respjson.Field + ModelOutputName respjson.Field + NCheckpoints respjson.Field + NEpochs respjson.Field + NEvals respjson.Field + OwnerAddress respjson.Field + Suffix respjson.Field + TokenCount respjson.Field + TotalPrice respjson.Field + TrainingFile respjson.Field + TrainingMethod respjson.Field + TrainingType respjson.Field + UserID respjson.Field + ValidationFile respjson.Field + WandbName respjson.Field + WandbProjectName respjson.Field + WarmupRatio respjson.Field + WeightDecay respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningNewResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseStatus string + +const ( + FineTuningNewResponseStatusPending FineTuningNewResponseStatus = "pending" + FineTuningNewResponseStatusQueued FineTuningNewResponseStatus = "queued" + FineTuningNewResponseStatusRunning FineTuningNewResponseStatus = "running" + FineTuningNewResponseStatusCompressing FineTuningNewResponseStatus = "compressing" + FineTuningNewResponseStatusUploading FineTuningNewResponseStatus = "uploading" + FineTuningNewResponseStatusCancelRequested FineTuningNewResponseStatus = "cancel_requested" + FineTuningNewResponseStatusCancelled FineTuningNewResponseStatus = "cancelled" + FineTuningNewResponseStatusError FineTuningNewResponseStatus = "error" + FineTuningNewResponseStatusCompleted FineTuningNewResponseStatus = "completed" +) + +// Learning rate scheduler configuration +type FineTuningNewResponseLrScheduler struct { + // Any of "linear", "cosine". + LrSchedulerType string `json:"lr_scheduler_type,required"` + LrSchedulerArgs FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion `json:"lr_scheduler_args"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LrSchedulerType respjson.Field + LrSchedulerArgs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseLrScheduler) RawJSON() string { return r.JSON.raw } +func (r *FineTuningNewResponseLrScheduler) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion contains all possible +// properties and values from +// [FineTuningNewResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs], +// [FineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion struct { + MinLrRatio float64 `json:"min_lr_ratio"` + // This field is from variant + // [FineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. + NumCycles float64 `json:"num_cycles"` + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion) AsFineTuningNewResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs() (v FineTuningNewResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion) AsFineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs() (v FineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningNewResponseLrSchedulerLrSchedulerArgsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningNewResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio,required"` + // Number or fraction of cycles for the cosine learning rate scheduler + NumCycles float64 `json:"num_cycles,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningNewResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningNewResponseTrainingMethodUnion contains all possible properties and +// values from [FineTuningNewResponseTrainingMethodTrainingMethodSft], +// [FineTuningNewResponseTrainingMethodTrainingMethodDpo]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningNewResponseTrainingMethodUnion struct { + Method string `json:"method"` + // This field is from variant + // [FineTuningNewResponseTrainingMethodTrainingMethodSft]. + TrainOnInputs FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` + // This field is from variant + // [FineTuningNewResponseTrainingMethodTrainingMethodDpo]. + DpoBeta float64 `json:"dpo_beta"` + // This field is from variant + // [FineTuningNewResponseTrainingMethodTrainingMethodDpo]. + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + // This field is from variant + // [FineTuningNewResponseTrainingMethodTrainingMethodDpo]. + DpoReferenceFree bool `json:"dpo_reference_free"` + // This field is from variant + // [FineTuningNewResponseTrainingMethodTrainingMethodDpo]. + RpoAlpha float64 `json:"rpo_alpha"` + // This field is from variant + // [FineTuningNewResponseTrainingMethodTrainingMethodDpo]. + SimpoGamma float64 `json:"simpo_gamma"` + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningNewResponseTrainingMethodUnion) AsFineTuningNewResponseTrainingMethodTrainingMethodSft() (v FineTuningNewResponseTrainingMethodTrainingMethodSft) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningNewResponseTrainingMethodUnion) AsFineTuningNewResponseTrainingMethodTrainingMethodDpo() (v FineTuningNewResponseTrainingMethodTrainingMethodDpo) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningNewResponseTrainingMethodUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningNewResponseTrainingMethodUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseTrainingMethodTrainingMethodSft struct { + // Any of "sft". + Method string `json:"method,required"` + // Whether to mask the user messages in conversational data or prompts in + // instruction data. + TrainOnInputs FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseTrainingMethodTrainingMethodSft) RawJSON() string { return r.JSON.raw } +func (r *FineTuningNewResponseTrainingMethodTrainingMethodSft) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion contains +// all possible properties and values from [bool], [string]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfBool +// OfFineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsString] +type FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion struct { + // This field will be present if the value is a [bool] instead of an object. + OfBool bool `json:",inline"` + // This field will be present if the value is a [string] instead of an object. + OfFineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsString string `json:",inline"` + JSON struct { + OfBool respjson.Field + OfFineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsString respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsBool() (v bool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsFineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) RawJSON() string { + return u.JSON.raw +} + +func (r *FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsString string + +const ( + FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsStringAuto FineTuningNewResponseTrainingMethodTrainingMethodSftTrainOnInputsString = "auto" +) + +type FineTuningNewResponseTrainingMethodTrainingMethodDpo struct { + // Any of "dpo". + Method string `json:"method,required"` + DpoBeta float64 `json:"dpo_beta"` + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + DpoReferenceFree bool `json:"dpo_reference_free"` + RpoAlpha float64 `json:"rpo_alpha"` + SimpoGamma float64 `json:"simpo_gamma"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseTrainingMethodTrainingMethodDpo) RawJSON() string { return r.JSON.raw } +func (r *FineTuningNewResponseTrainingMethodTrainingMethodDpo) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningNewResponseTrainingTypeUnion contains all possible properties and +// values from [FineTuningNewResponseTrainingTypeFullTrainingType], +// [FineTuningNewResponseTrainingTypeLoRaTrainingType]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningNewResponseTrainingTypeUnion struct { + Type string `json:"type"` + // This field is from variant [FineTuningNewResponseTrainingTypeLoRaTrainingType]. + LoraAlpha int64 `json:"lora_alpha"` + // This field is from variant [FineTuningNewResponseTrainingTypeLoRaTrainingType]. + LoraR int64 `json:"lora_r"` + // This field is from variant [FineTuningNewResponseTrainingTypeLoRaTrainingType]. + LoraDropout float64 `json:"lora_dropout"` + // This field is from variant [FineTuningNewResponseTrainingTypeLoRaTrainingType]. + LoraTrainableModules string `json:"lora_trainable_modules"` + JSON struct { + Type respjson.Field + LoraAlpha respjson.Field + LoraR respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningNewResponseTrainingTypeUnion) AsFineTuningNewResponseTrainingTypeFullTrainingType() (v FineTuningNewResponseTrainingTypeFullTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningNewResponseTrainingTypeUnion) AsFineTuningNewResponseTrainingTypeLoRaTrainingType() (v FineTuningNewResponseTrainingTypeLoRaTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningNewResponseTrainingTypeUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningNewResponseTrainingTypeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseTrainingTypeFullTrainingType struct { + // Any of "Full". + Type string `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseTrainingTypeFullTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FineTuningNewResponseTrainingTypeFullTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewResponseTrainingTypeLoRaTrainingType struct { + LoraAlpha int64 `json:"lora_alpha,required"` + LoraR int64 `json:"lora_r,required"` + // Any of "Lora". + Type string `json:"type,required"` + LoraDropout float64 `json:"lora_dropout"` + LoraTrainableModules string `json:"lora_trainable_modules"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LoraAlpha respjson.Field + LoraR respjson.Field + Type respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningNewResponseTrainingTypeLoRaTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FineTuningNewResponseTrainingTypeLoRaTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponse struct { + Data []FineTuningListResponseData `json:"data,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A truncated version of the fine-tune response, used for POST /fine-tunes, GET +// /fine-tunes and POST /fine-tunes/{id}/cancel endpoints +type FineTuningListResponseData struct { + // Unique identifier for the fine-tune job + ID string `json:"id,required"` + // Creation timestamp of the fine-tune job + CreatedAt time.Time `json:"created_at,required" format:"date-time"` + // Any of "pending", "queued", "running", "compressing", "uploading", + // "cancel_requested", "cancelled", "error", "completed". + Status string `json:"status,required"` + // Last update timestamp of the fine-tune job + UpdatedAt time.Time `json:"updated_at,required" format:"date-time"` + // Batch size used for training + BatchSize int64 `json:"batch_size"` + // Events related to this fine-tune job + Events []FinetuneEvent `json:"events"` + // Checkpoint used to continue training + FromCheckpoint string `json:"from_checkpoint"` + // Hugging Face Hub repo to start training from + FromHfModel string `json:"from_hf_model"` + // The revision of the Hugging Face Hub model to continue training from + HfModelRevision string `json:"hf_model_revision"` + // Learning rate used for training + LearningRate float64 `json:"learning_rate"` + // Learning rate scheduler configuration + LrScheduler FineTuningListResponseDataLrScheduler `json:"lr_scheduler"` + // Maximum gradient norm for clipping + MaxGradNorm float64 `json:"max_grad_norm"` + // Base model used for fine-tuning + Model string `json:"model"` + ModelOutputName string `json:"model_output_name"` + // Number of checkpoints saved during training + NCheckpoints int64 `json:"n_checkpoints"` + // Number of training epochs + NEpochs int64 `json:"n_epochs"` + // Number of evaluations during training + NEvals int64 `json:"n_evals"` + // Owner address information + OwnerAddress string `json:"owner_address"` + // Suffix added to the fine-tuned model name + Suffix string `json:"suffix"` + // Count of tokens processed + TokenCount int64 `json:"token_count"` + // Total price for the fine-tuning job + TotalPrice int64 `json:"total_price"` + // File-ID of the training file + TrainingFile string `json:"training_file"` + // Method of training used + TrainingMethod FineTuningListResponseDataTrainingMethodUnion `json:"training_method"` + // Type of training used (full or LoRA) + TrainingType FineTuningListResponseDataTrainingTypeUnion `json:"training_type"` + // Identifier for the user who created the job + UserID string `json:"user_id"` + // File-ID of the validation file + ValidationFile string `json:"validation_file"` + // Weights & Biases run name + WandbName string `json:"wandb_name"` + // Weights & Biases project name + WandbProjectName string `json:"wandb_project_name"` + // Ratio of warmup steps + WarmupRatio float64 `json:"warmup_ratio"` + // Weight decay value used + WeightDecay float64 `json:"weight_decay"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Status respjson.Field + UpdatedAt respjson.Field + BatchSize respjson.Field + Events respjson.Field + FromCheckpoint respjson.Field + FromHfModel respjson.Field + HfModelRevision respjson.Field + LearningRate respjson.Field + LrScheduler respjson.Field + MaxGradNorm respjson.Field + Model respjson.Field + ModelOutputName respjson.Field + NCheckpoints respjson.Field + NEpochs respjson.Field + NEvals respjson.Field + OwnerAddress respjson.Field + Suffix respjson.Field + TokenCount respjson.Field + TotalPrice respjson.Field + TrainingFile respjson.Field + TrainingMethod respjson.Field + TrainingType respjson.Field + UserID respjson.Field + ValidationFile respjson.Field + WandbName respjson.Field + WandbProjectName respjson.Field + WarmupRatio respjson.Field + WeightDecay respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseData) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListResponseData) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Learning rate scheduler configuration +type FineTuningListResponseDataLrScheduler struct { + // Any of "linear", "cosine". + LrSchedulerType string `json:"lr_scheduler_type,required"` + LrSchedulerArgs FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion `json:"lr_scheduler_args"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LrSchedulerType respjson.Field + LrSchedulerArgs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataLrScheduler) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListResponseDataLrScheduler) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion contains all possible +// properties and values from +// [FineTuningListResponseDataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs], +// [FineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion struct { + MinLrRatio float64 `json:"min_lr_ratio"` + // This field is from variant + // [FineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. + NumCycles float64 `json:"num_cycles"` + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion) AsFineTuningListResponseDataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs() (v FineTuningListResponseDataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion) AsFineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs() (v FineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion) RawJSON() string { + return u.JSON.raw +} + +func (r *FineTuningListResponseDataLrSchedulerLrSchedulerArgsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponseDataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningListResponseDataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio,required"` + // Number or fraction of cycles for the cosine learning rate scheduler + NumCycles float64 `json:"num_cycles,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningListResponseDataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningListResponseDataTrainingMethodUnion contains all possible properties +// and values from [FineTuningListResponseDataTrainingMethodTrainingMethodSft], +// [FineTuningListResponseDataTrainingMethodTrainingMethodDpo]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningListResponseDataTrainingMethodUnion struct { + Method string `json:"method"` + // This field is from variant + // [FineTuningListResponseDataTrainingMethodTrainingMethodSft]. + TrainOnInputs FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` + // This field is from variant + // [FineTuningListResponseDataTrainingMethodTrainingMethodDpo]. + DpoBeta float64 `json:"dpo_beta"` + // This field is from variant + // [FineTuningListResponseDataTrainingMethodTrainingMethodDpo]. + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + // This field is from variant + // [FineTuningListResponseDataTrainingMethodTrainingMethodDpo]. + DpoReferenceFree bool `json:"dpo_reference_free"` + // This field is from variant + // [FineTuningListResponseDataTrainingMethodTrainingMethodDpo]. + RpoAlpha float64 `json:"rpo_alpha"` + // This field is from variant + // [FineTuningListResponseDataTrainingMethodTrainingMethodDpo]. + SimpoGamma float64 `json:"simpo_gamma"` + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningListResponseDataTrainingMethodUnion) AsFineTuningListResponseDataTrainingMethodTrainingMethodSft() (v FineTuningListResponseDataTrainingMethodTrainingMethodSft) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningListResponseDataTrainingMethodUnion) AsFineTuningListResponseDataTrainingMethodTrainingMethodDpo() (v FineTuningListResponseDataTrainingMethodTrainingMethodDpo) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningListResponseDataTrainingMethodUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningListResponseDataTrainingMethodUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponseDataTrainingMethodTrainingMethodSft struct { + // Any of "sft". + Method string `json:"method,required"` + // Whether to mask the user messages in conversational data or prompts in + // instruction data. + TrainOnInputs FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataTrainingMethodTrainingMethodSft) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningListResponseDataTrainingMethodTrainingMethodSft) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion +// contains all possible properties and values from [bool], [string]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfBool +// OfFineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsString] +type FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion struct { + // This field will be present if the value is a [bool] instead of an object. + OfBool bool `json:",inline"` + // This field will be present if the value is a [string] instead of an object. + OfFineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsString string `json:",inline"` + JSON struct { + OfBool respjson.Field + OfFineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsString respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsBool() (v bool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsFineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion) RawJSON() string { + return u.JSON.raw +} + +func (r *FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsString string + +const ( + FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsStringAuto FineTuningListResponseDataTrainingMethodTrainingMethodSftTrainOnInputsString = "auto" +) + +type FineTuningListResponseDataTrainingMethodTrainingMethodDpo struct { + // Any of "dpo". + Method string `json:"method,required"` + DpoBeta float64 `json:"dpo_beta"` + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + DpoReferenceFree bool `json:"dpo_reference_free"` + RpoAlpha float64 `json:"rpo_alpha"` + SimpoGamma float64 `json:"simpo_gamma"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataTrainingMethodTrainingMethodDpo) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningListResponseDataTrainingMethodTrainingMethodDpo) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningListResponseDataTrainingTypeUnion contains all possible properties and +// values from [FineTuningListResponseDataTrainingTypeFullTrainingType], +// [FineTuningListResponseDataTrainingTypeLoRaTrainingType]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningListResponseDataTrainingTypeUnion struct { + Type string `json:"type"` + // This field is from variant + // [FineTuningListResponseDataTrainingTypeLoRaTrainingType]. + LoraAlpha int64 `json:"lora_alpha"` + // This field is from variant + // [FineTuningListResponseDataTrainingTypeLoRaTrainingType]. + LoraR int64 `json:"lora_r"` + // This field is from variant + // [FineTuningListResponseDataTrainingTypeLoRaTrainingType]. + LoraDropout float64 `json:"lora_dropout"` + // This field is from variant + // [FineTuningListResponseDataTrainingTypeLoRaTrainingType]. + LoraTrainableModules string `json:"lora_trainable_modules"` + JSON struct { + Type respjson.Field + LoraAlpha respjson.Field + LoraR respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningListResponseDataTrainingTypeUnion) AsFineTuningListResponseDataTrainingTypeFullTrainingType() (v FineTuningListResponseDataTrainingTypeFullTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningListResponseDataTrainingTypeUnion) AsFineTuningListResponseDataTrainingTypeLoRaTrainingType() (v FineTuningListResponseDataTrainingTypeLoRaTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningListResponseDataTrainingTypeUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningListResponseDataTrainingTypeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponseDataTrainingTypeFullTrainingType struct { + // Any of "Full". + Type string `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataTrainingTypeFullTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListResponseDataTrainingTypeFullTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListResponseDataTrainingTypeLoRaTrainingType struct { + LoraAlpha int64 `json:"lora_alpha,required"` + LoraR int64 `json:"lora_r,required"` + // Any of "Lora". + Type string `json:"type,required"` + LoraDropout float64 `json:"lora_dropout"` + LoraTrainableModules string `json:"lora_trainable_modules"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LoraAlpha respjson.Field + LoraR respjson.Field + Type respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListResponseDataTrainingTypeLoRaTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListResponseDataTrainingTypeLoRaTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningDeleteResponse struct { + // Message indicating the result of the deletion + Message string `json:"message"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Message respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningDeleteResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningDeleteResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// A truncated version of the fine-tune response, used for POST /fine-tunes, GET +// /fine-tunes and POST /fine-tunes/{id}/cancel endpoints +type FineTuningCancelResponse struct { + // Unique identifier for the fine-tune job + ID string `json:"id,required"` + // Creation timestamp of the fine-tune job + CreatedAt time.Time `json:"created_at,required" format:"date-time"` + // Any of "pending", "queued", "running", "compressing", "uploading", + // "cancel_requested", "cancelled", "error", "completed". + Status FineTuningCancelResponseStatus `json:"status,required"` + // Last update timestamp of the fine-tune job + UpdatedAt time.Time `json:"updated_at,required" format:"date-time"` + // Batch size used for training + BatchSize int64 `json:"batch_size"` + // Events related to this fine-tune job + Events []FinetuneEvent `json:"events"` + // Checkpoint used to continue training + FromCheckpoint string `json:"from_checkpoint"` + // Hugging Face Hub repo to start training from + FromHfModel string `json:"from_hf_model"` + // The revision of the Hugging Face Hub model to continue training from + HfModelRevision string `json:"hf_model_revision"` + // Learning rate used for training + LearningRate float64 `json:"learning_rate"` + // Learning rate scheduler configuration + LrScheduler FineTuningCancelResponseLrScheduler `json:"lr_scheduler"` + // Maximum gradient norm for clipping + MaxGradNorm float64 `json:"max_grad_norm"` + // Base model used for fine-tuning + Model string `json:"model"` + ModelOutputName string `json:"model_output_name"` + // Number of checkpoints saved during training + NCheckpoints int64 `json:"n_checkpoints"` + // Number of training epochs + NEpochs int64 `json:"n_epochs"` + // Number of evaluations during training + NEvals int64 `json:"n_evals"` + // Owner address information + OwnerAddress string `json:"owner_address"` + // Suffix added to the fine-tuned model name + Suffix string `json:"suffix"` + // Count of tokens processed + TokenCount int64 `json:"token_count"` + // Total price for the fine-tuning job + TotalPrice int64 `json:"total_price"` + // File-ID of the training file + TrainingFile string `json:"training_file"` + // Method of training used + TrainingMethod FineTuningCancelResponseTrainingMethodUnion `json:"training_method"` + // Type of training used (full or LoRA) + TrainingType FineTuningCancelResponseTrainingTypeUnion `json:"training_type"` + // Identifier for the user who created the job + UserID string `json:"user_id"` + // File-ID of the validation file + ValidationFile string `json:"validation_file"` + // Weights & Biases run name + WandbName string `json:"wandb_name"` + // Weights & Biases project name + WandbProjectName string `json:"wandb_project_name"` + // Ratio of warmup steps + WarmupRatio float64 `json:"warmup_ratio"` + // Weight decay value used + WeightDecay float64 `json:"weight_decay"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ID respjson.Field + CreatedAt respjson.Field + Status respjson.Field + UpdatedAt respjson.Field + BatchSize respjson.Field + Events respjson.Field + FromCheckpoint respjson.Field + FromHfModel respjson.Field + HfModelRevision respjson.Field + LearningRate respjson.Field + LrScheduler respjson.Field + MaxGradNorm respjson.Field + Model respjson.Field + ModelOutputName respjson.Field + NCheckpoints respjson.Field + NEpochs respjson.Field + NEvals respjson.Field + OwnerAddress respjson.Field + Suffix respjson.Field + TokenCount respjson.Field + TotalPrice respjson.Field + TrainingFile respjson.Field + TrainingMethod respjson.Field + TrainingType respjson.Field + UserID respjson.Field + ValidationFile respjson.Field + WandbName respjson.Field + WandbProjectName respjson.Field + WarmupRatio respjson.Field + WeightDecay respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCancelResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseStatus string + +const ( + FineTuningCancelResponseStatusPending FineTuningCancelResponseStatus = "pending" + FineTuningCancelResponseStatusQueued FineTuningCancelResponseStatus = "queued" + FineTuningCancelResponseStatusRunning FineTuningCancelResponseStatus = "running" + FineTuningCancelResponseStatusCompressing FineTuningCancelResponseStatus = "compressing" + FineTuningCancelResponseStatusUploading FineTuningCancelResponseStatus = "uploading" + FineTuningCancelResponseStatusCancelRequested FineTuningCancelResponseStatus = "cancel_requested" + FineTuningCancelResponseStatusCancelled FineTuningCancelResponseStatus = "cancelled" + FineTuningCancelResponseStatusError FineTuningCancelResponseStatus = "error" + FineTuningCancelResponseStatusCompleted FineTuningCancelResponseStatus = "completed" +) + +// Learning rate scheduler configuration +type FineTuningCancelResponseLrScheduler struct { + // Any of "linear", "cosine". + LrSchedulerType string `json:"lr_scheduler_type,required"` + LrSchedulerArgs FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion `json:"lr_scheduler_args"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LrSchedulerType respjson.Field + LrSchedulerArgs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseLrScheduler) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCancelResponseLrScheduler) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion contains all possible +// properties and values from +// [FineTuningCancelResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs], +// [FineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion struct { + MinLrRatio float64 `json:"min_lr_ratio"` + // This field is from variant + // [FineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs]. + NumCycles float64 `json:"num_cycles"` + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion) AsFineTuningCancelResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs() (v FineTuningCancelResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion) AsFineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs() (v FineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningCancelResponseLrSchedulerLrSchedulerArgsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningCancelResponseLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio,required"` + // Number or fraction of cycles for the cosine learning rate scheduler + NumCycles float64 `json:"num_cycles,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + MinLrRatio respjson.Field + NumCycles respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) RawJSON() string { + return r.JSON.raw +} +func (r *FineTuningCancelResponseLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningCancelResponseTrainingMethodUnion contains all possible properties and +// values from [FineTuningCancelResponseTrainingMethodTrainingMethodSft], +// [FineTuningCancelResponseTrainingMethodTrainingMethodDpo]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningCancelResponseTrainingMethodUnion struct { + Method string `json:"method"` + // This field is from variant + // [FineTuningCancelResponseTrainingMethodTrainingMethodSft]. + TrainOnInputs FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs"` + // This field is from variant + // [FineTuningCancelResponseTrainingMethodTrainingMethodDpo]. + DpoBeta float64 `json:"dpo_beta"` + // This field is from variant + // [FineTuningCancelResponseTrainingMethodTrainingMethodDpo]. + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + // This field is from variant + // [FineTuningCancelResponseTrainingMethodTrainingMethodDpo]. + DpoReferenceFree bool `json:"dpo_reference_free"` + // This field is from variant + // [FineTuningCancelResponseTrainingMethodTrainingMethodDpo]. + RpoAlpha float64 `json:"rpo_alpha"` + // This field is from variant + // [FineTuningCancelResponseTrainingMethodTrainingMethodDpo]. + SimpoGamma float64 `json:"simpo_gamma"` + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningCancelResponseTrainingMethodUnion) AsFineTuningCancelResponseTrainingMethodTrainingMethodSft() (v FineTuningCancelResponseTrainingMethodTrainingMethodSft) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningCancelResponseTrainingMethodUnion) AsFineTuningCancelResponseTrainingMethodTrainingMethodDpo() (v FineTuningCancelResponseTrainingMethodTrainingMethodDpo) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningCancelResponseTrainingMethodUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningCancelResponseTrainingMethodUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseTrainingMethodTrainingMethodSft struct { + // Any of "sft". + Method string `json:"method,required"` + // Whether to mask the user messages in conversational data or prompts in + // instruction data. + TrainOnInputs FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + TrainOnInputs respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseTrainingMethodTrainingMethodSft) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCancelResponseTrainingMethodTrainingMethodSft) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion +// contains all possible properties and values from [bool], [string]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +// +// If the underlying value is not a json object, one of the following properties +// will be valid: OfBool +// OfFineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsString] +type FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion struct { + // This field will be present if the value is a [bool] instead of an object. + OfBool bool `json:",inline"` + // This field will be present if the value is a [string] instead of an object. + OfFineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsString string `json:",inline"` + JSON struct { + OfBool respjson.Field + OfFineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsString respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsBool() (v bool) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) AsFineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsString() (v string) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) RawJSON() string { + return u.JSON.raw +} + +func (r *FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsString string + +const ( + FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsStringAuto FineTuningCancelResponseTrainingMethodTrainingMethodSftTrainOnInputsString = "auto" +) + +type FineTuningCancelResponseTrainingMethodTrainingMethodDpo struct { + // Any of "dpo". + Method string `json:"method,required"` + DpoBeta float64 `json:"dpo_beta"` + DpoNormalizeLogratiosByLength bool `json:"dpo_normalize_logratios_by_length"` + DpoReferenceFree bool `json:"dpo_reference_free"` + RpoAlpha float64 `json:"rpo_alpha"` + SimpoGamma float64 `json:"simpo_gamma"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Method respjson.Field + DpoBeta respjson.Field + DpoNormalizeLogratiosByLength respjson.Field + DpoReferenceFree respjson.Field + RpoAlpha respjson.Field + SimpoGamma respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseTrainingMethodTrainingMethodDpo) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCancelResponseTrainingMethodTrainingMethodDpo) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// FineTuningCancelResponseTrainingTypeUnion contains all possible properties and +// values from [FineTuningCancelResponseTrainingTypeFullTrainingType], +// [FineTuningCancelResponseTrainingTypeLoRaTrainingType]. +// +// Use the methods beginning with 'As' to cast the union to one of its variants. +type FineTuningCancelResponseTrainingTypeUnion struct { + Type string `json:"type"` + // This field is from variant + // [FineTuningCancelResponseTrainingTypeLoRaTrainingType]. + LoraAlpha int64 `json:"lora_alpha"` + // This field is from variant + // [FineTuningCancelResponseTrainingTypeLoRaTrainingType]. + LoraR int64 `json:"lora_r"` + // This field is from variant + // [FineTuningCancelResponseTrainingTypeLoRaTrainingType]. + LoraDropout float64 `json:"lora_dropout"` + // This field is from variant + // [FineTuningCancelResponseTrainingTypeLoRaTrainingType]. + LoraTrainableModules string `json:"lora_trainable_modules"` + JSON struct { + Type respjson.Field + LoraAlpha respjson.Field + LoraR respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + raw string + } `json:"-"` +} + +func (u FineTuningCancelResponseTrainingTypeUnion) AsFineTuningCancelResponseTrainingTypeFullTrainingType() (v FineTuningCancelResponseTrainingTypeFullTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +func (u FineTuningCancelResponseTrainingTypeUnion) AsFineTuningCancelResponseTrainingTypeLoRaTrainingType() (v FineTuningCancelResponseTrainingTypeLoRaTrainingType) { + apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v) + return +} + +// Returns the unmodified JSON received from the API +func (u FineTuningCancelResponseTrainingTypeUnion) RawJSON() string { return u.JSON.raw } + +func (r *FineTuningCancelResponseTrainingTypeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseTrainingTypeFullTrainingType struct { + // Any of "Full". + Type string `json:"type,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Type respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseTrainingTypeFullTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCancelResponseTrainingTypeFullTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningCancelResponseTrainingTypeLoRaTrainingType struct { + LoraAlpha int64 `json:"lora_alpha,required"` + LoraR int64 `json:"lora_r,required"` + // Any of "Lora". + Type string `json:"type,required"` + LoraDropout float64 `json:"lora_dropout"` + LoraTrainableModules string `json:"lora_trainable_modules"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + LoraAlpha respjson.Field + LoraR respjson.Field + Type respjson.Field + LoraDropout respjson.Field + LoraTrainableModules respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningCancelResponseTrainingTypeLoRaTrainingType) RawJSON() string { return r.JSON.raw } +func (r *FineTuningCancelResponseTrainingTypeLoRaTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListCheckpointsResponse struct { + Data []FineTuningListCheckpointsResponseData `json:"data,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListCheckpointsResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListCheckpointsResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListCheckpointsResponseData struct { + CheckpointType string `json:"checkpoint_type,required"` + CreatedAt string `json:"created_at,required"` + Path string `json:"path,required"` + Step int64 `json:"step,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + CheckpointType respjson.Field + CreatedAt respjson.Field + Path respjson.Field + Step respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListCheckpointsResponseData) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListCheckpointsResponseData) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningListEventsResponse struct { + Data []FinetuneEvent `json:"data,required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Data respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r FineTuningListEventsResponse) RawJSON() string { return r.JSON.raw } +func (r *FineTuningListEventsResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +type FineTuningNewParams struct { + // Name of the base model to run fine-tune job on + Model string `json:"model,required"` + // File-ID of a training file uploaded to the Together API + TrainingFile string `json:"training_file,required"` + // The checkpoint identifier to continue training from a previous fine-tuning job. + // Format is `{$JOB_ID}` or `{$OUTPUT_MODEL_NAME}` or `{$JOB_ID}:{$STEP}` or + // `{$OUTPUT_MODEL_NAME}:{$STEP}`. The step value is optional; without it, the + // final checkpoint will be used. + FromCheckpoint param.Opt[string] `json:"from_checkpoint,omitzero"` + // The Hugging Face Hub repo to start training from. Should be as close as possible + // to the base model (specified by the `model` argument) in terms of architecture + // and size. + FromHfModel param.Opt[string] `json:"from_hf_model,omitzero"` + // The API token for the Hugging Face Hub. + HfAPIToken param.Opt[string] `json:"hf_api_token,omitzero"` + // The revision of the Hugging Face Hub model to continue training from. E.g., + // hf_model_revision=main (default, used if the argument is not provided) or + // hf_model_revision='607a30d783dfa663caf39e06633721c8d4cfcd7e' (specific commit). + HfModelRevision param.Opt[string] `json:"hf_model_revision,omitzero"` + // The name of the Hugging Face repository to upload the fine-tuned model to. + HfOutputRepoName param.Opt[string] `json:"hf_output_repo_name,omitzero"` + // Controls how quickly the model adapts to new information (too high may cause + // instability, too low may slow convergence) + LearningRate param.Opt[float64] `json:"learning_rate,omitzero"` + // Max gradient norm to be used for gradient clipping. Set to 0 to disable. + MaxGradNorm param.Opt[float64] `json:"max_grad_norm,omitzero"` + // Number of intermediate model versions saved during training for evaluation + NCheckpoints param.Opt[int64] `json:"n_checkpoints,omitzero"` + // Number of complete passes through the training dataset (higher values may + // improve results but increase cost and risk of overfitting) + NEpochs param.Opt[int64] `json:"n_epochs,omitzero"` + // Number of evaluations to be run on a given validation set during training + NEvals param.Opt[int64] `json:"n_evals,omitzero"` + // Suffix that will be added to your fine-tuned model name + Suffix param.Opt[string] `json:"suffix,omitzero"` + // File-ID of a validation file uploaded to the Together API + ValidationFile param.Opt[string] `json:"validation_file,omitzero"` + // Integration key for tracking experiments and model metrics on W&B platform + WandbAPIKey param.Opt[string] `json:"wandb_api_key,omitzero"` + // The base URL of a dedicated Weights & Biases instance. + WandbBaseURL param.Opt[string] `json:"wandb_base_url,omitzero"` + // The Weights & Biases name for your run. + WandbName param.Opt[string] `json:"wandb_name,omitzero"` + // The Weights & Biases project for your run. If not specified, will use `together` + // as the project name. + WandbProjectName param.Opt[string] `json:"wandb_project_name,omitzero"` + // The percent of steps at the start of training to linearly increase the learning + // rate. + WarmupRatio param.Opt[float64] `json:"warmup_ratio,omitzero"` + // Weight decay. Regularization parameter for the optimizer. + WeightDecay param.Opt[float64] `json:"weight_decay,omitzero"` + // Number of training examples processed together (larger batches use more memory + // but may train faster). Defaults to "max". We use training optimizations like + // packing, so the effective batch size may be different than the value you set. + BatchSize FineTuningNewParamsBatchSizeUnion `json:"batch_size,omitzero"` + // The learning rate scheduler to use. It specifies how the learning rate is + // adjusted during training. + LrScheduler FineTuningNewParamsLrScheduler `json:"lr_scheduler,omitzero"` + // Whether to mask the user messages in conversational data or prompts in + // instruction data. + TrainOnInputs FineTuningNewParamsTrainOnInputsUnion `json:"train_on_inputs,omitzero"` + // The training method to use. 'sft' for Supervised Fine-Tuning or 'dpo' for Direct + // Preference Optimization. + TrainingMethod FineTuningNewParamsTrainingMethodUnion `json:"training_method,omitzero"` + TrainingType FineTuningNewParamsTrainingTypeUnion `json:"training_type,omitzero"` + paramObj +} + +func (r FineTuningNewParams) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningNewParamsBatchSizeUnion struct { + OfInt param.Opt[int64] `json:",omitzero,inline"` + // Check if union is this variant with + // !param.IsOmitted(union.OfFineTuningNewsBatchSizeString) + OfFineTuningNewsBatchSizeString param.Opt[string] `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningNewParamsBatchSizeUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfInt, u.OfFineTuningNewsBatchSizeString) +} +func (u *FineTuningNewParamsBatchSizeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningNewParamsBatchSizeUnion) asAny() any { + if !param.IsOmitted(u.OfInt) { + return &u.OfInt.Value + } else if !param.IsOmitted(u.OfFineTuningNewsBatchSizeString) { + return &u.OfFineTuningNewsBatchSizeString + } + return nil +} + +type FineTuningNewParamsBatchSizeString string + +const ( + FineTuningNewParamsBatchSizeStringMax FineTuningNewParamsBatchSizeString = "max" +) + +// The learning rate scheduler to use. It specifies how the learning rate is +// adjusted during training. +// +// The property LrSchedulerType is required. +type FineTuningNewParamsLrScheduler struct { + // Any of "linear", "cosine". + LrSchedulerType string `json:"lr_scheduler_type,omitzero,required"` + LrSchedulerArgs FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion `json:"lr_scheduler_args,omitzero"` + paramObj +} + +func (r FineTuningNewParamsLrScheduler) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsLrScheduler + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsLrScheduler) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FineTuningNewParamsLrScheduler]( + "lr_scheduler_type", "linear", "cosine", + ) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion struct { + OfFineTuningNewsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs *FineTuningNewParamsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs `json:",omitzero,inline"` + OfFineTuningNewsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs *FineTuningNewParamsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfFineTuningNewsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs, u.OfFineTuningNewsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) +} +func (u *FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion) asAny() any { + if !param.IsOmitted(u.OfFineTuningNewsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) { + return u.OfFineTuningNewsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs + } else if !param.IsOmitted(u.OfFineTuningNewsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) { + return u.OfFineTuningNewsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion) GetNumCycles() *float64 { + if vt := u.OfFineTuningNewsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs; vt != nil { + return &vt.NumCycles + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion) GetMinLrRatio() *float64 { + if vt := u.OfFineTuningNewsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs; vt != nil && vt.MinLrRatio.Valid() { + return &vt.MinLrRatio.Value + } else if vt := u.OfFineTuningNewsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs; vt != nil { + return (*float64)(&vt.MinLrRatio) + } + return nil +} + +type FineTuningNewParamsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio param.Opt[float64] `json:"min_lr_ratio,omitzero"` + paramObj +} + +func (r FineTuningNewParamsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// The properties MinLrRatio, NumCycles are required. +type FineTuningNewParamsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs struct { + // The ratio of the final learning rate to the peak learning rate + MinLrRatio float64 `json:"min_lr_ratio,required"` + // Number or fraction of cycles for the cosine learning rate scheduler + NumCycles float64 `json:"num_cycles,required"` + paramObj +} + +func (r FineTuningNewParamsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningNewParamsTrainOnInputsUnion struct { + OfBool param.Opt[bool] `json:",omitzero,inline"` + // Check if union is this variant with + // !param.IsOmitted(union.OfFineTuningNewsTrainOnInputsString) + OfFineTuningNewsTrainOnInputsString param.Opt[string] `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningNewParamsTrainOnInputsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfBool, u.OfFineTuningNewsTrainOnInputsString) +} +func (u *FineTuningNewParamsTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningNewParamsTrainOnInputsUnion) asAny() any { + if !param.IsOmitted(u.OfBool) { + return &u.OfBool.Value + } else if !param.IsOmitted(u.OfFineTuningNewsTrainOnInputsString) { + return &u.OfFineTuningNewsTrainOnInputsString + } + return nil +} + +type FineTuningNewParamsTrainOnInputsString string + +const ( + FineTuningNewParamsTrainOnInputsStringAuto FineTuningNewParamsTrainOnInputsString = "auto" +) + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningNewParamsTrainingMethodUnion struct { + OfFineTuningNewsTrainingMethodTrainingMethodSft *FineTuningNewParamsTrainingMethodTrainingMethodSft `json:",omitzero,inline"` + OfFineTuningNewsTrainingMethodTrainingMethodDpo *FineTuningNewParamsTrainingMethodTrainingMethodDpo `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningNewParamsTrainingMethodUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfFineTuningNewsTrainingMethodTrainingMethodSft, u.OfFineTuningNewsTrainingMethodTrainingMethodDpo) +} +func (u *FineTuningNewParamsTrainingMethodUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningNewParamsTrainingMethodUnion) asAny() any { + if !param.IsOmitted(u.OfFineTuningNewsTrainingMethodTrainingMethodSft) { + return u.OfFineTuningNewsTrainingMethodTrainingMethodSft + } else if !param.IsOmitted(u.OfFineTuningNewsTrainingMethodTrainingMethodDpo) { + return u.OfFineTuningNewsTrainingMethodTrainingMethodDpo + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetTrainOnInputs() *FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodSft; vt != nil { + return &vt.TrainOnInputs + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetDpoBeta() *float64 { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodDpo; vt != nil && vt.DpoBeta.Valid() { + return &vt.DpoBeta.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetDpoNormalizeLogratiosByLength() *bool { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodDpo; vt != nil && vt.DpoNormalizeLogratiosByLength.Valid() { + return &vt.DpoNormalizeLogratiosByLength.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetDpoReferenceFree() *bool { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodDpo; vt != nil && vt.DpoReferenceFree.Valid() { + return &vt.DpoReferenceFree.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetRpoAlpha() *float64 { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodDpo; vt != nil && vt.RpoAlpha.Valid() { + return &vt.RpoAlpha.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetSimpoGamma() *float64 { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodDpo; vt != nil && vt.SimpoGamma.Valid() { + return &vt.SimpoGamma.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingMethodUnion) GetMethod() *string { + if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodSft; vt != nil { + return (*string)(&vt.Method) + } else if vt := u.OfFineTuningNewsTrainingMethodTrainingMethodDpo; vt != nil { + return (*string)(&vt.Method) + } + return nil +} + +// The properties Method, TrainOnInputs are required. +type FineTuningNewParamsTrainingMethodTrainingMethodSft struct { + // Any of "sft". + Method string `json:"method,omitzero,required"` + // Whether to mask the user messages in conversational data or prompts in + // instruction data. + TrainOnInputs FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion `json:"train_on_inputs,omitzero,required"` + paramObj +} + +func (r FineTuningNewParamsTrainingMethodTrainingMethodSft) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsTrainingMethodTrainingMethodSft + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsTrainingMethodTrainingMethodSft) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FineTuningNewParamsTrainingMethodTrainingMethodSft]( + "method", "sft", + ) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion struct { + OfBool param.Opt[bool] `json:",omitzero,inline"` + // Check if union is this variant with + // !param.IsOmitted(union.OfFineTuningNewsTrainingMethodTrainingMethodSftTrainOnInputsString) + OfFineTuningNewsTrainingMethodTrainingMethodSftTrainOnInputsString param.Opt[string] `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfBool, u.OfFineTuningNewsTrainingMethodTrainingMethodSftTrainOnInputsString) +} +func (u *FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion) asAny() any { + if !param.IsOmitted(u.OfBool) { + return &u.OfBool.Value + } else if !param.IsOmitted(u.OfFineTuningNewsTrainingMethodTrainingMethodSftTrainOnInputsString) { + return &u.OfFineTuningNewsTrainingMethodTrainingMethodSftTrainOnInputsString + } + return nil +} + +type FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsString string + +const ( + FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsStringAuto FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsString = "auto" +) + +// The property Method is required. +type FineTuningNewParamsTrainingMethodTrainingMethodDpo struct { + // Any of "dpo". + Method string `json:"method,omitzero,required"` + DpoBeta param.Opt[float64] `json:"dpo_beta,omitzero"` + DpoNormalizeLogratiosByLength param.Opt[bool] `json:"dpo_normalize_logratios_by_length,omitzero"` + DpoReferenceFree param.Opt[bool] `json:"dpo_reference_free,omitzero"` + RpoAlpha param.Opt[float64] `json:"rpo_alpha,omitzero"` + SimpoGamma param.Opt[float64] `json:"simpo_gamma,omitzero"` + paramObj +} + +func (r FineTuningNewParamsTrainingMethodTrainingMethodDpo) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsTrainingMethodTrainingMethodDpo + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsTrainingMethodTrainingMethodDpo) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FineTuningNewParamsTrainingMethodTrainingMethodDpo]( + "method", "dpo", + ) +} + +// Only one field can be non-zero. +// +// Use [param.IsOmitted] to confirm if a field is set. +type FineTuningNewParamsTrainingTypeUnion struct { + OfFineTuningNewsTrainingTypeFullTrainingType *FineTuningNewParamsTrainingTypeFullTrainingType `json:",omitzero,inline"` + OfFineTuningNewsTrainingTypeLoRaTrainingType *FineTuningNewParamsTrainingTypeLoRaTrainingType `json:",omitzero,inline"` + paramUnion +} + +func (u FineTuningNewParamsTrainingTypeUnion) MarshalJSON() ([]byte, error) { + return param.MarshalUnion(u, u.OfFineTuningNewsTrainingTypeFullTrainingType, u.OfFineTuningNewsTrainingTypeLoRaTrainingType) +} +func (u *FineTuningNewParamsTrainingTypeUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, u) +} + +func (u *FineTuningNewParamsTrainingTypeUnion) asAny() any { + if !param.IsOmitted(u.OfFineTuningNewsTrainingTypeFullTrainingType) { + return u.OfFineTuningNewsTrainingTypeFullTrainingType + } else if !param.IsOmitted(u.OfFineTuningNewsTrainingTypeLoRaTrainingType) { + return u.OfFineTuningNewsTrainingTypeLoRaTrainingType + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingTypeUnion) GetLoraAlpha() *int64 { + if vt := u.OfFineTuningNewsTrainingTypeLoRaTrainingType; vt != nil { + return &vt.LoraAlpha + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingTypeUnion) GetLoraR() *int64 { + if vt := u.OfFineTuningNewsTrainingTypeLoRaTrainingType; vt != nil { + return &vt.LoraR + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingTypeUnion) GetLoraDropout() *float64 { + if vt := u.OfFineTuningNewsTrainingTypeLoRaTrainingType; vt != nil && vt.LoraDropout.Valid() { + return &vt.LoraDropout.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingTypeUnion) GetLoraTrainableModules() *string { + if vt := u.OfFineTuningNewsTrainingTypeLoRaTrainingType; vt != nil && vt.LoraTrainableModules.Valid() { + return &vt.LoraTrainableModules.Value + } + return nil +} + +// Returns a pointer to the underlying variant's property, if present. +func (u FineTuningNewParamsTrainingTypeUnion) GetType() *string { + if vt := u.OfFineTuningNewsTrainingTypeFullTrainingType; vt != nil { + return (*string)(&vt.Type) + } else if vt := u.OfFineTuningNewsTrainingTypeLoRaTrainingType; vt != nil { + return (*string)(&vt.Type) + } + return nil +} + +// The property Type is required. +type FineTuningNewParamsTrainingTypeFullTrainingType struct { + // Any of "Full". + Type string `json:"type,omitzero,required"` + paramObj +} + +func (r FineTuningNewParamsTrainingTypeFullTrainingType) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsTrainingTypeFullTrainingType + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsTrainingTypeFullTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FineTuningNewParamsTrainingTypeFullTrainingType]( + "type", "Full", + ) +} + +// The properties LoraAlpha, LoraR, Type are required. +type FineTuningNewParamsTrainingTypeLoRaTrainingType struct { + LoraAlpha int64 `json:"lora_alpha,required"` + LoraR int64 `json:"lora_r,required"` + // Any of "Lora". + Type string `json:"type,omitzero,required"` + LoraDropout param.Opt[float64] `json:"lora_dropout,omitzero"` + LoraTrainableModules param.Opt[string] `json:"lora_trainable_modules,omitzero"` + paramObj +} + +func (r FineTuningNewParamsTrainingTypeLoRaTrainingType) MarshalJSON() (data []byte, err error) { + type shadow FineTuningNewParamsTrainingTypeLoRaTrainingType + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *FineTuningNewParamsTrainingTypeLoRaTrainingType) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func init() { + apijson.RegisterFieldValidator[FineTuningNewParamsTrainingTypeLoRaTrainingType]( + "type", "Lora", + ) +} + +type FineTuningDeleteParams struct { + Force param.Opt[bool] `query:"force,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FineTuningDeleteParams]'s query parameters as `url.Values`. +func (r FineTuningDeleteParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatComma, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +type FineTuningContentParams struct { + // Fine-tune ID to download. A string that starts with `ft-`. + FtID string `query:"ft_id,required" json:"-"` + // Specifies step number for checkpoint to download. Ignores `checkpoint` value if + // set. + CheckpointStep param.Opt[int64] `query:"checkpoint_step,omitzero" json:"-"` + // Specifies checkpoint type to download - `merged` vs `adapter`. This field is + // required if the checkpoint_step is not set. + // + // Any of "merged", "adapter", "model_output_path". + Checkpoint FineTuningContentParamsCheckpoint `query:"checkpoint,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [FineTuningContentParams]'s query parameters as +// `url.Values`. +func (r FineTuningContentParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatComma, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} + +// Specifies checkpoint type to download - `merged` vs `adapter`. This field is +// required if the checkpoint_step is not set. +type FineTuningContentParamsCheckpoint string + +const ( + FineTuningContentParamsCheckpointMerged FineTuningContentParamsCheckpoint = "merged" + FineTuningContentParamsCheckpointAdapter FineTuningContentParamsCheckpoint = "adapter" + FineTuningContentParamsCheckpointModelOutputPath FineTuningContentParamsCheckpoint = "model_output_path" +) diff --git a/finetune_test.go b/finetuning_test.go similarity index 62% rename from finetune_test.go rename to finetuning_test.go index 2421507e..ab5ef8a4 100644 --- a/finetune_test.go +++ b/finetuning_test.go @@ -3,8 +3,12 @@ package together_test import ( + "bytes" "context" "errors" + "io" + "net/http" + "net/http/httptest" "os" "testing" @@ -13,7 +17,7 @@ import ( "github.com/togethercomputer/together-go/option" ) -func TestFineTuneNewWithOptionalParams(t *testing.T) { +func TestFineTuningNewWithOptionalParams(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -25,10 +29,10 @@ func TestFineTuneNewWithOptionalParams(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.New(context.TODO(), together.FineTuneNewParams{ + _, err := client.FineTuning.New(context.TODO(), together.FineTuningNewParams{ Model: "model", TrainingFile: "training_file", - BatchSize: together.FineTuneNewParamsBatchSizeUnion{ + BatchSize: together.FineTuningNewParamsBatchSizeUnion{ OfInt: together.Int(0), }, FromCheckpoint: together.String("from_checkpoint"), @@ -37,10 +41,10 @@ func TestFineTuneNewWithOptionalParams(t *testing.T) { HfModelRevision: together.String("hf_model_revision"), HfOutputRepoName: together.String("hf_output_repo_name"), LearningRate: together.Float(0), - LrScheduler: together.LrSchedulerParam{ - LrSchedulerType: together.LrSchedulerLrSchedulerTypeLinear, - LrSchedulerArgs: together.LrSchedulerLrSchedulerArgsUnionParam{ - OfLinearLrSchedulerArgs: &together.LinearLrSchedulerArgsParam{ + LrScheduler: together.FineTuningNewParamsLrScheduler{ + LrSchedulerType: "linear", + LrSchedulerArgs: together.FineTuningNewParamsLrSchedulerLrSchedulerArgsUnion{ + OfFineTuningNewsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs: &together.FineTuningNewParamsLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs{ MinLrRatio: together.Float(0), }, }, @@ -50,20 +54,20 @@ func TestFineTuneNewWithOptionalParams(t *testing.T) { NEpochs: together.Int(0), NEvals: together.Int(0), Suffix: together.String("suffix"), - TrainOnInputs: together.FineTuneNewParamsTrainOnInputsUnion{ + TrainOnInputs: together.FineTuningNewParamsTrainOnInputsUnion{ OfBool: together.Bool(true), }, - TrainingMethod: together.FineTuneNewParamsTrainingMethodUnion{ - OfTrainingMethodSft: &together.TrainingMethodSftParam{ - Method: together.TrainingMethodSftMethodSft, - TrainOnInputs: together.TrainingMethodSftTrainOnInputsUnionParam{ + TrainingMethod: together.FineTuningNewParamsTrainingMethodUnion{ + OfFineTuningNewsTrainingMethodTrainingMethodSft: &together.FineTuningNewParamsTrainingMethodTrainingMethodSft{ + Method: "sft", + TrainOnInputs: together.FineTuningNewParamsTrainingMethodTrainingMethodSftTrainOnInputsUnion{ OfBool: together.Bool(true), }, }, }, - TrainingType: together.FineTuneNewParamsTrainingTypeUnion{ - OfFullTrainingType: &together.FullTrainingTypeParam{ - Type: together.FullTrainingTypeTypeFull, + TrainingType: together.FineTuningNewParamsTrainingTypeUnion{ + OfFineTuningNewsTrainingTypeFullTrainingType: &together.FineTuningNewParamsTrainingTypeFullTrainingType{ + Type: "Full", }, }, ValidationFile: together.String("validation_file"), @@ -83,7 +87,7 @@ func TestFineTuneNewWithOptionalParams(t *testing.T) { } } -func TestFineTuneGet(t *testing.T) { +func TestFineTuningGet(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -95,7 +99,7 @@ func TestFineTuneGet(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.Get(context.TODO(), "id") + _, err := client.FineTuning.Get(context.TODO(), "id") if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -105,7 +109,7 @@ func TestFineTuneGet(t *testing.T) { } } -func TestFineTuneList(t *testing.T) { +func TestFineTuningList(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -117,7 +121,7 @@ func TestFineTuneList(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.List(context.TODO()) + _, err := client.FineTuning.List(context.TODO()) if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -127,7 +131,7 @@ func TestFineTuneList(t *testing.T) { } } -func TestFineTuneCancel(t *testing.T) { +func TestFineTuningDeleteWithOptionalParams(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -139,7 +143,13 @@ func TestFineTuneCancel(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.Cancel(context.TODO(), "id") + _, err := client.FineTuning.Delete( + context.TODO(), + "id", + together.FineTuningDeleteParams{ + Force: together.Bool(true), + }, + ) if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -149,7 +159,7 @@ func TestFineTuneCancel(t *testing.T) { } } -func TestFineTuneDownloadWithOptionalParams(t *testing.T) { +func TestFineTuningCancel(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -161,11 +171,31 @@ func TestFineTuneDownloadWithOptionalParams(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.Download(context.TODO(), together.FineTuneDownloadParams{ + _, err := client.FineTuning.Cancel(context.TODO(), "id") + if err != nil { + var apierr *together.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} + +func TestFineTuningContentWithOptionalParams(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("abc")) + })) + defer server.Close() + baseURL := server.URL + client := together.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + resp, err := client.FineTuning.Content(context.TODO(), together.FineTuningContentParams{ FtID: "ft_id", - Checkpoint: together.FineTuneDownloadParamsCheckpointMerged, + Checkpoint: together.FineTuningContentParamsCheckpointMerged, CheckpointStep: together.Int(0), - Output: together.String("output"), }) if err != nil { var apierr *together.Error @@ -174,9 +204,22 @@ func TestFineTuneDownloadWithOptionalParams(t *testing.T) { } t.Fatalf("err should be nil: %s", err.Error()) } + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + var apierr *together.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } + if !bytes.Equal(b, []byte("abc")) { + t.Fatalf("return value not %s: %s", "abc", b) + } } -func TestFineTuneListEvents(t *testing.T) { +func TestFineTuningListCheckpoints(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -188,7 +231,7 @@ func TestFineTuneListEvents(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.ListEvents(context.TODO(), "id") + _, err := client.FineTuning.ListCheckpoints(context.TODO(), "id") if err != nil { var apierr *together.Error if errors.As(err, &apierr) { @@ -198,7 +241,7 @@ func TestFineTuneListEvents(t *testing.T) { } } -func TestFineTuneGetCheckpoints(t *testing.T) { +func TestFineTuningListEvents(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -210,7 +253,7 @@ func TestFineTuneGetCheckpoints(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.FineTune.GetCheckpoints(context.TODO(), "id") + _, err := client.FineTuning.ListEvents(context.TODO(), "id") if err != nil { var apierr *together.Error if errors.As(err, &apierr) { diff --git a/go.mod b/go.mod index ceb67072..1781c21e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/togethercomputer/together-go go 1.22 require ( - github.com/tidwall/gjson v1.14.4 + github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/go.sum b/go.sum index a70a5e0a..32ba293d 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= -github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= diff --git a/image.go b/image.go index 5e8aff9f..8dd0cdc4 100644 --- a/image.go +++ b/image.go @@ -35,7 +35,7 @@ func NewImageService(opts ...option.RequestOption) (r ImageService) { } // Use an image model to generate an image for a given prompt. -func (r *ImageService) New(ctx context.Context, body ImageNewParams, opts ...option.RequestOption) (res *ImageFile, err error) { +func (r *ImageService) Generate(ctx context.Context, body ImageGenerateParams, opts ...option.RequestOption) (res *ImageFile, err error) { opts = slices.Concat(r.Options, opts) path := "images/generations" err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) @@ -192,11 +192,11 @@ const ( ImageFileObjectList ImageFileObject = "list" ) -type ImageNewParams struct { +type ImageGenerateParams struct { // The model to use for image generation. // // [See all of Together AI's image models](https://docs.together.ai/docs/serverless-models#image-models) - Model ImageNewParamsModel `json:"model,omitzero,required"` + Model ImageGenerateParamsModel `json:"model,omitzero,required"` // A description of the desired images. Maximum length varies by model. Prompt string `json:"prompt,required"` // If true, disables the safety checker for image generation. @@ -221,40 +221,40 @@ type ImageNewParams struct { Width param.Opt[int64] `json:"width,omitzero"` // An array of objects that define LoRAs (Low-Rank Adaptations) to influence the // generated image. - ImageLoras []ImageNewParamsImageLora `json:"image_loras,omitzero"` + ImageLoras []ImageGenerateParamsImageLora `json:"image_loras,omitzero"` // The format of the image response. Can be either be `jpeg` or `png`. Defaults to // `jpeg`. // // Any of "jpeg", "png". - OutputFormat ImageNewParamsOutputFormat `json:"output_format,omitzero"` + OutputFormat ImageGenerateParamsOutputFormat `json:"output_format,omitzero"` // Format of the image response. Can be either a base64 string or a URL. // // Any of "base64", "url". - ResponseFormat ImageNewParamsResponseFormat `json:"response_format,omitzero"` + ResponseFormat ImageGenerateParamsResponseFormat `json:"response_format,omitzero"` paramObj } -func (r ImageNewParams) MarshalJSON() (data []byte, err error) { - type shadow ImageNewParams +func (r ImageGenerateParams) MarshalJSON() (data []byte, err error) { + type shadow ImageGenerateParams return param.MarshalObject(r, (*shadow)(&r)) } -func (r *ImageNewParams) UnmarshalJSON(data []byte) error { +func (r *ImageGenerateParams) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // The model to use for image generation. // // [See all of Together AI's image models](https://docs.together.ai/docs/serverless-models#image-models) -type ImageNewParamsModel string +type ImageGenerateParamsModel string const ( - ImageNewParamsModelBlackForestLabsFlux1SchnellFree ImageNewParamsModel = "black-forest-labs/FLUX.1-schnell-Free" - ImageNewParamsModelBlackForestLabsFlux1Schnell ImageNewParamsModel = "black-forest-labs/FLUX.1-schnell" - ImageNewParamsModelBlackForestLabsFlux1_1Pro ImageNewParamsModel = "black-forest-labs/FLUX.1.1-pro" + ImageGenerateParamsModelBlackForestLabsFlux1SchnellFree ImageGenerateParamsModel = "black-forest-labs/FLUX.1-schnell-Free" + ImageGenerateParamsModelBlackForestLabsFlux1Schnell ImageGenerateParamsModel = "black-forest-labs/FLUX.1-schnell" + ImageGenerateParamsModelBlackForestLabsFlux1_1Pro ImageGenerateParamsModel = "black-forest-labs/FLUX.1.1-pro" ) // The properties Path, Scale are required. -type ImageNewParamsImageLora struct { +type ImageGenerateParamsImageLora struct { // The URL of the LoRA to apply (e.g. // https://huggingface.co/strangerzonehf/Flux-Midjourney-Mix2-LoRA). Path string `json:"path,required"` @@ -263,27 +263,27 @@ type ImageNewParamsImageLora struct { paramObj } -func (r ImageNewParamsImageLora) MarshalJSON() (data []byte, err error) { - type shadow ImageNewParamsImageLora +func (r ImageGenerateParamsImageLora) MarshalJSON() (data []byte, err error) { + type shadow ImageGenerateParamsImageLora return param.MarshalObject(r, (*shadow)(&r)) } -func (r *ImageNewParamsImageLora) UnmarshalJSON(data []byte) error { +func (r *ImageGenerateParamsImageLora) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // The format of the image response. Can be either be `jpeg` or `png`. Defaults to // `jpeg`. -type ImageNewParamsOutputFormat string +type ImageGenerateParamsOutputFormat string const ( - ImageNewParamsOutputFormatJpeg ImageNewParamsOutputFormat = "jpeg" - ImageNewParamsOutputFormatPng ImageNewParamsOutputFormat = "png" + ImageGenerateParamsOutputFormatJpeg ImageGenerateParamsOutputFormat = "jpeg" + ImageGenerateParamsOutputFormatPng ImageGenerateParamsOutputFormat = "png" ) // Format of the image response. Can be either a base64 string or a URL. -type ImageNewParamsResponseFormat string +type ImageGenerateParamsResponseFormat string const ( - ImageNewParamsResponseFormatBase64 ImageNewParamsResponseFormat = "base64" - ImageNewParamsResponseFormatURL ImageNewParamsResponseFormat = "url" + ImageGenerateParamsResponseFormatBase64 ImageGenerateParamsResponseFormat = "base64" + ImageGenerateParamsResponseFormatURL ImageGenerateParamsResponseFormat = "url" ) diff --git a/image_test.go b/image_test.go index a7a1a273..7bba4713 100644 --- a/image_test.go +++ b/image_test.go @@ -13,7 +13,7 @@ import ( "github.com/togethercomputer/together-go/option" ) -func TestImageNewWithOptionalParams(t *testing.T) { +func TestImageGenerateWithOptionalParams(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -25,21 +25,21 @@ func TestImageNewWithOptionalParams(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.Images.New(context.TODO(), together.ImageNewParams{ - Model: together.ImageNewParamsModelBlackForestLabsFlux1SchnellFree, + _, err := client.Images.Generate(context.TODO(), together.ImageGenerateParams{ + Model: together.ImageGenerateParamsModelBlackForestLabsFlux1SchnellFree, Prompt: "cat floating in space, cinematic", DisableSafetyChecker: together.Bool(true), GuidanceScale: together.Float(0), Height: together.Int(0), - ImageLoras: []together.ImageNewParamsImageLora{{ + ImageLoras: []together.ImageGenerateParamsImageLora{{ Path: "path", Scale: 0, }}, ImageURL: together.String("image_url"), N: together.Int(0), NegativePrompt: together.String("negative_prompt"), - OutputFormat: together.ImageNewParamsOutputFormatJpeg, - ResponseFormat: together.ImageNewParamsResponseFormatBase64, + OutputFormat: together.ImageGenerateParamsOutputFormatJpeg, + ResponseFormat: together.ImageGenerateParamsResponseFormatBase64, Seed: together.Int(0), Steps: together.Int(0), Width: together.Int(0), diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go index 3ab6a5c3..2d676174 100644 --- a/internal/requestconfig/requestconfig.go +++ b/internal/requestconfig/requestconfig.go @@ -164,7 +164,7 @@ func NewRequestConfig(ctx context.Context, method string, u string, body any, ds req.Header.Add(k, v) } cfg := RequestConfig{ - MaxRetries: 5, + MaxRetries: 2, Context: ctx, Request: req, HTTPClient: http.DefaultClient, diff --git a/internal/version.go b/internal/version.go index d6f40b32..2d1d85e1 100644 --- a/internal/version.go +++ b/internal/version.go @@ -2,4 +2,4 @@ package internal -const PackageVersion = "0.1.0-alpha.2" // x-release-please-version +const PackageVersion = "0.1.0-alpha.3" // x-release-please-version diff --git a/model.go b/model.go index 4a910111..2d042834 100644 --- a/model.go +++ b/model.go @@ -34,7 +34,7 @@ func NewModelService(opts ...option.RequestOption) (r ModelService) { } // Lists all of Together's open-source models -func (r *ModelService) List(ctx context.Context, opts ...option.RequestOption) (res *[]ModelListResponse, err error) { +func (r *ModelService) List(ctx context.Context, opts ...option.RequestOption) (res *[]ModelObject, err error) { opts = slices.Concat(r.Options, opts) path := "models" err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...) @@ -49,18 +49,18 @@ func (r *ModelService) Upload(ctx context.Context, body ModelUploadParams, opts return } -type ModelListResponse struct { +type ModelObject struct { ID string `json:"id,required"` Created int64 `json:"created,required"` Object string `json:"object,required"` // Any of "chat", "language", "code", "image", "embedding", "moderation", "rerank". - Type ModelListResponseType `json:"type,required"` - ContextLength int64 `json:"context_length"` - DisplayName string `json:"display_name"` - License string `json:"license"` - Link string `json:"link"` - Organization string `json:"organization"` - Pricing ModelListResponsePricing `json:"pricing"` + Type ModelObjectType `json:"type,required"` + ContextLength int64 `json:"context_length"` + DisplayName string `json:"display_name"` + License string `json:"license"` + Link string `json:"link"` + Organization string `json:"organization"` + Pricing ModelObjectPricing `json:"pricing"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { ID respjson.Field @@ -79,24 +79,24 @@ type ModelListResponse struct { } // Returns the unmodified JSON received from the API -func (r ModelListResponse) RawJSON() string { return r.JSON.raw } -func (r *ModelListResponse) UnmarshalJSON(data []byte) error { +func (r ModelObject) RawJSON() string { return r.JSON.raw } +func (r *ModelObject) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type ModelListResponseType string +type ModelObjectType string const ( - ModelListResponseTypeChat ModelListResponseType = "chat" - ModelListResponseTypeLanguage ModelListResponseType = "language" - ModelListResponseTypeCode ModelListResponseType = "code" - ModelListResponseTypeImage ModelListResponseType = "image" - ModelListResponseTypeEmbedding ModelListResponseType = "embedding" - ModelListResponseTypeModeration ModelListResponseType = "moderation" - ModelListResponseTypeRerank ModelListResponseType = "rerank" + ModelObjectTypeChat ModelObjectType = "chat" + ModelObjectTypeLanguage ModelObjectType = "language" + ModelObjectTypeCode ModelObjectType = "code" + ModelObjectTypeImage ModelObjectType = "image" + ModelObjectTypeEmbedding ModelObjectType = "embedding" + ModelObjectTypeModeration ModelObjectType = "moderation" + ModelObjectTypeRerank ModelObjectType = "rerank" ) -type ModelListResponsePricing struct { +type ModelObjectPricing struct { Base float64 `json:"base,required"` Finetune float64 `json:"finetune,required"` Hourly float64 `json:"hourly,required"` @@ -115,8 +115,8 @@ type ModelListResponsePricing struct { } // Returns the unmodified JSON received from the API -func (r ModelListResponsePricing) RawJSON() string { return r.JSON.raw } -func (r *ModelListResponsePricing) UnmarshalJSON(data []byte) error { +func (r ModelObjectPricing) RawJSON() string { return r.JSON.raw } +func (r *ModelObjectPricing) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } diff --git a/packages/respjson/respjson.go b/packages/respjson/respjson.go index cc0088ca..9e61c5ca 100644 --- a/packages/respjson/respjson.go +++ b/packages/respjson/respjson.go @@ -5,7 +5,7 @@ package respjson // Use [Field.Valid] to check if an optional value was null or omitted. // // A Field will always occur in the following structure, where it -// mirrors the original field in it's parent struct: +// mirrors the original field in its parent struct: // // type ExampleObject struct { // Foo bool `json:"foo"` diff --git a/together.go b/rerank.go similarity index 52% rename from together.go rename to rerank.go index db9a8f1d..02f204a5 100644 --- a/together.go +++ b/rerank.go @@ -3,19 +3,52 @@ package together import ( + "context" + "net/http" + "slices" + "github.com/togethercomputer/together-go/internal/apijson" + "github.com/togethercomputer/together-go/internal/requestconfig" + "github.com/togethercomputer/together-go/option" "github.com/togethercomputer/together-go/packages/param" "github.com/togethercomputer/together-go/packages/respjson" ) -type RerankResponse struct { +// RerankService contains methods and other services that help with interacting +// with the together API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewRerankService] method instead. +type RerankService struct { + Options []option.RequestOption +} + +// NewRerankService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewRerankService(opts ...option.RequestOption) (r RerankService) { + r = RerankService{} + r.Options = opts + return +} + +// Query a reranker model +func (r *RerankService) New(ctx context.Context, body RerankNewParams, opts ...option.RequestOption) (res *RerankNewResponse, err error) { + opts = slices.Concat(r.Options, opts) + path := "rerank" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +type RerankNewResponse struct { // The model to be used for the rerank request. Model string `json:"model,required"` // Object type // // Any of "rerank". - Object RerankResponseObject `json:"object,required"` - Results []RerankResponseResult `json:"results,required"` + Object RerankNewResponseObject `json:"object,required"` + Results []RerankNewResponseResult `json:"results,required"` // Request ID ID string `json:"id"` Usage ChatCompletionUsage `json:"usage,nullable"` @@ -32,22 +65,22 @@ type RerankResponse struct { } // Returns the unmodified JSON received from the API -func (r RerankResponse) RawJSON() string { return r.JSON.raw } -func (r *RerankResponse) UnmarshalJSON(data []byte) error { +func (r RerankNewResponse) RawJSON() string { return r.JSON.raw } +func (r *RerankNewResponse) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // Object type -type RerankResponseObject string +type RerankNewResponseObject string const ( - RerankResponseObjectRerank RerankResponseObject = "rerank" + RerankNewResponseObjectRerank RerankNewResponseObject = "rerank" ) -type RerankResponseResult struct { - Document RerankResponseResultDocument `json:"document,required"` - Index int64 `json:"index,required"` - RelevanceScore float64 `json:"relevance_score,required"` +type RerankNewResponseResult struct { + Document RerankNewResponseResultDocument `json:"document,required"` + Index int64 `json:"index,required"` + RelevanceScore float64 `json:"relevance_score,required"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { Document respjson.Field @@ -59,12 +92,12 @@ type RerankResponseResult struct { } // Returns the unmodified JSON received from the API -func (r RerankResponseResult) RawJSON() string { return r.JSON.raw } -func (r *RerankResponseResult) UnmarshalJSON(data []byte) error { +func (r RerankNewResponseResult) RawJSON() string { return r.JSON.raw } +func (r *RerankNewResponseResult) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type RerankResponseResultDocument struct { +type RerankNewResponseResultDocument struct { Text string `json:"text,nullable"` // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. JSON struct { @@ -75,18 +108,18 @@ type RerankResponseResultDocument struct { } // Returns the unmodified JSON received from the API -func (r RerankResponseResultDocument) RawJSON() string { return r.JSON.raw } -func (r *RerankResponseResultDocument) UnmarshalJSON(data []byte) error { +func (r RerankNewResponseResultDocument) RawJSON() string { return r.JSON.raw } +func (r *RerankNewResponseResultDocument) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type RerankParams struct { +type RerankNewParams struct { // List of documents, which can be either strings or objects. - Documents RerankParamsDocumentsUnion `json:"documents,omitzero,required"` + Documents RerankNewParamsDocumentsUnion `json:"documents,omitzero,required"` // The model to be used for the rerank request. // // [See all of Together AI's rerank models](https://docs.together.ai/docs/serverless-models#rerank-models) - Model RerankParamsModel `json:"model,omitzero,required"` + Model RerankNewParamsModel `json:"model,omitzero,required"` // The search query to be used for ranking. Query string `json:"query,required"` // Whether to return supplied documents with the response. @@ -99,31 +132,31 @@ type RerankParams struct { paramObj } -func (r RerankParams) MarshalJSON() (data []byte, err error) { - type shadow RerankParams +func (r RerankNewParams) MarshalJSON() (data []byte, err error) { + type shadow RerankNewParams return param.MarshalObject(r, (*shadow)(&r)) } -func (r *RerankParams) UnmarshalJSON(data []byte) error { +func (r *RerankNewParams) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } // Only one field can be non-zero. // // Use [param.IsOmitted] to confirm if a field is set. -type RerankParamsDocumentsUnion struct { +type RerankNewParamsDocumentsUnion struct { OfMapOfAnyMap []map[string]any `json:",omitzero,inline"` OfStringArray []string `json:",omitzero,inline"` paramUnion } -func (u RerankParamsDocumentsUnion) MarshalJSON() ([]byte, error) { +func (u RerankNewParamsDocumentsUnion) MarshalJSON() ([]byte, error) { return param.MarshalUnion(u, u.OfMapOfAnyMap, u.OfStringArray) } -func (u *RerankParamsDocumentsUnion) UnmarshalJSON(data []byte) error { +func (u *RerankNewParamsDocumentsUnion) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, u) } -func (u *RerankParamsDocumentsUnion) asAny() any { +func (u *RerankNewParamsDocumentsUnion) asAny() any { if !param.IsOmitted(u.OfMapOfAnyMap) { return &u.OfMapOfAnyMap } else if !param.IsOmitted(u.OfStringArray) { @@ -135,8 +168,8 @@ func (u *RerankParamsDocumentsUnion) asAny() any { // The model to be used for the rerank request. // // [See all of Together AI's rerank models](https://docs.together.ai/docs/serverless-models#rerank-models) -type RerankParamsModel string +type RerankNewParamsModel string const ( - RerankParamsModelSalesforceLlamaRankV1 RerankParamsModel = "Salesforce/Llama-Rank-v1" + RerankNewParamsModelSalesforceLlamaRankV1 RerankNewParamsModel = "Salesforce/Llama-Rank-v1" ) diff --git a/together_test.go b/rerank_test.go similarity index 82% rename from together_test.go rename to rerank_test.go index f44e5e0d..4d0c7fad 100644 --- a/together_test.go +++ b/rerank_test.go @@ -13,7 +13,7 @@ import ( "github.com/togethercomputer/together-go/option" ) -func TestTogetherRerankWithOptionalParams(t *testing.T) { +func TestRerankNewWithOptionalParams(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { baseURL = envURL @@ -25,8 +25,8 @@ func TestTogetherRerankWithOptionalParams(t *testing.T) { option.WithBaseURL(baseURL), option.WithAPIKey("My API Key"), ) - _, err := client.Rerank(context.TODO(), together.RerankParams{ - Documents: together.RerankParamsDocumentsUnion{ + _, err := client.Rerank.New(context.TODO(), together.RerankNewParams{ + Documents: together.RerankNewParamsDocumentsUnion{ OfMapOfAnyMap: []map[string]any{{ "title": "bar", "text": "bar", @@ -41,7 +41,7 @@ func TestTogetherRerankWithOptionalParams(t *testing.T) { "text": "bar", }}, }, - Model: together.RerankParamsModelSalesforceLlamaRankV1, + Model: together.RerankNewParamsModelSalesforceLlamaRankV1, Query: "What animals can I find near Peru?", RankFields: []string{"title", "text"}, ReturnDocuments: together.Bool(true), diff --git a/video.go b/video.go index a9dd20b2..8ca8c25f 100644 --- a/video.go +++ b/video.go @@ -36,7 +36,7 @@ func NewVideoService(opts ...option.RequestOption) (r VideoService) { } // Create a video -func (r *VideoService) New(ctx context.Context, body VideoNewParams, opts ...option.RequestOption) (res *VideoNewResponse, err error) { +func (r *VideoService) New(ctx context.Context, body VideoNewParams, opts ...option.RequestOption) (res *VideoJob, err error) { opts = slices.Concat(r.Options, opts) opts = append([]option.RequestOption{option.WithBaseURL("https://api.together.xyz/v2/")}, opts...) path := "videos" @@ -164,23 +164,6 @@ func (r *VideoJobOutputs) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } -type VideoNewResponse struct { - // Unique identifier for the video job. - ID string `json:"id,required"` - // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. - JSON struct { - ID respjson.Field - ExtraFields map[string]respjson.Field - raw string - } `json:"-"` -} - -// Returns the unmodified JSON received from the API -func (r VideoNewResponse) RawJSON() string { return r.JSON.raw } -func (r *VideoNewResponse) UnmarshalJSON(data []byte) error { - return apijson.UnmarshalRoot(data, r) -} - type VideoNewParams struct { // The model to be used for the video creation request. Model string `json:"model,required"`