diff --git a/common/testutil.go b/common/testutil.go index 9f0dc69f90..1b4155efba 100644 --- a/common/testutil.go +++ b/common/testutil.go @@ -93,11 +93,23 @@ func IgnoreRoutines() []goleak.Option { "github.com/livepeer/go-livepeer/core.(*Balances).StartCleanup", "internal/synctest.Run", "testing/synctest.testingSynctestTest", + "github.com/livepeer/go-livepeer/core.(*Balances).StartCleanup", "github.com/livepeer/go-livepeer/server.startTrickleSubscribe.func2", "github.com/livepeer/go-livepeer/server.startTrickleSubscribe", + "net/http.(*persistConn).writeLoop", "net/http.(*persistConn).readLoop", "io.(*pipe).read", + "github.com/livepeer/go-livepeer/media.gatherIncomingTracks", } - res := make([]goleak.Option, 0, len(funcs2ignore)) + // Functions that might have other functions on top of their stack (like time.Sleep) + // These need to be ignored with IgnoreAnyFunction instead of IgnoreTopFunction + funcsAnyIgnore := []string{ + "github.com/livepeer/go-livepeer/server.ffmpegOutput", + } + + res := make([]goleak.Option, 0, len(funcs2ignore)+len(funcsAnyIgnore)) for _, f := range funcs2ignore { res = append(res, goleak.IgnoreTopFunction(f)) } + for _, f := range funcsAnyIgnore { + res = append(res, goleak.IgnoreAnyFunction(f)) + } return res } diff --git a/core/ai_orchestrator.go b/core/ai_orchestrator.go index 6801ff5c7d..84379a3b1e 100644 --- a/core/ai_orchestrator.go +++ b/core/ai_orchestrator.go @@ -1200,6 +1200,12 @@ func (orch *orchestrator) JobPriceInfo(sender ethcommon.Address, jobCapability s return nil, err } + //ensure price numerator and denominator can be int64 + jobPrice, err = common.PriceToInt64(jobPrice) + if err != nil { + return nil, fmt.Errorf("invalid job price: %w", err) + } + return &net.PriceInfo{ PricePerUnit: jobPrice.Num().Int64(), PixelsPerUnit: jobPrice.Denom().Int64(), diff --git a/core/external_capabilities.go b/core/external_capabilities.go index f5802db1d8..69c6ddafde 100644 --- a/core/external_capabilities.go +++ b/core/external_capabilities.go @@ -1,15 +1,33 @@ package core import ( + "context" "encoding/json" "fmt" "math/big" "sync" + ethcommon "github.com/ethereum/go-ethereum/common" "github.com/golang/glog" + "github.com/livepeer/go-livepeer/net" + "github.com/livepeer/go-livepeer/trickle" ) +type JobToken struct { + SenderAddress *JobSender `json:"sender_address,omitempty"` + TicketParams *net.TicketParams `json:"ticket_params,omitempty"` + Balance int64 `json:"balance,omitempty"` + Price *net.PriceInfo `json:"price,omitempty"` + ServiceAddr string `json:"service_addr,omitempty"` + + LastNonce uint32 +} +type JobSender struct { + Addr string `json:"addr"` + Sig string `json:"sig"` +} + type ExternalCapability struct { Name string `json:"name"` Description string `json:"description"` @@ -25,13 +43,144 @@ type ExternalCapability struct { Load int } +type StreamInfo struct { + StreamID string + Capability string + + //Orchestrator fields + Sender ethcommon.Address + StreamRequest []byte + pubChannel *trickle.TrickleLocalPublisher + subChannel *trickle.TrickleLocalPublisher + controlChannel *trickle.TrickleLocalPublisher + eventsChannel *trickle.TrickleLocalPublisher + dataChannel *trickle.TrickleLocalPublisher + //Stream fields + JobParams string + StreamCtx context.Context + CancelStream context.CancelFunc + + sdm sync.Mutex +} + +func (sd *StreamInfo) IsActive() bool { + sd.sdm.Lock() + defer sd.sdm.Unlock() + if sd.StreamCtx.Err() != nil { + return false + } + + if sd.controlChannel == nil { + return false + } + + return true +} + +func (sd 
*StreamInfo) UpdateParams(params string) { + sd.sdm.Lock() + defer sd.sdm.Unlock() + sd.JobParams = params +} + +func (sd *StreamInfo) SetChannels(pub, sub, control, events, data *trickle.TrickleLocalPublisher) { + sd.sdm.Lock() + defer sd.sdm.Unlock() + sd.pubChannel = pub + sd.subChannel = sub + sd.controlChannel = control + sd.eventsChannel = events + sd.dataChannel = data +} + type ExternalCapabilities struct { capm sync.Mutex Capabilities map[string]*ExternalCapability + Streams map[string]*StreamInfo } func NewExternalCapabilities() *ExternalCapabilities { - return &ExternalCapabilities{Capabilities: make(map[string]*ExternalCapability)} + return &ExternalCapabilities{Capabilities: make(map[string]*ExternalCapability), + Streams: make(map[string]*StreamInfo), + } +} + +func (extCaps *ExternalCapabilities) AddStream(streamID string, capability string, streamReq []byte) (*StreamInfo, error) { + extCaps.capm.Lock() + defer extCaps.capm.Unlock() + _, ok := extCaps.Streams[streamID] + if ok { + return nil, fmt.Errorf("stream already exists: %s", streamID) + } + + //add to streams + ctx, cancel := context.WithCancel(context.Background()) + stream := StreamInfo{ + StreamID: streamID, + Capability: capability, + StreamRequest: streamReq, + StreamCtx: ctx, + CancelStream: cancel, + } + extCaps.Streams[streamID] = &stream + + //clean up when stream ends + go func() { + <-ctx.Done() + + //orchestrator channels shutdown + if stream.pubChannel != nil { + if err := stream.pubChannel.Close(); err != nil { + glog.Errorf("error closing pubChannel for stream=%s: %v", streamID, err) + } + } + if stream.subChannel != nil { + if err := stream.subChannel.Close(); err != nil { + glog.Errorf("error closing subChannel for stream=%s: %v", streamID, err) + } + } + if stream.controlChannel != nil { + if err := stream.controlChannel.Close(); err != nil { + glog.Errorf("error closing controlChannel for stream=%s: %v", streamID, err) + } + } + if stream.eventsChannel != nil { + if err := stream.eventsChannel.Close(); err != nil { + glog.Errorf("error closing eventsChannel for stream=%s: %v", streamID, err) + } + } + if stream.dataChannel != nil { + if err := stream.dataChannel.Close(); err != nil { + glog.Errorf("error closing dataChannel for stream=%s: %v", streamID, err) + } + } + return + }() + + return &stream, nil +} + +func (extCaps *ExternalCapabilities) RemoveStream(streamID string) { + extCaps.capm.Lock() + defer extCaps.capm.Unlock() + + streamInfo, ok := extCaps.Streams[streamID] + if ok { + //confirm stream context is canceled before deleting + if streamInfo.StreamCtx.Err() == nil { + streamInfo.CancelStream() + } + } + + delete(extCaps.Streams, streamID) +} + +func (extCaps *ExternalCapabilities) StreamExists(streamID string) bool { + extCaps.capm.Lock() + defer extCaps.capm.Unlock() + + _, ok := extCaps.Streams[streamID] + return ok } func (extCaps *ExternalCapabilities) RemoveCapability(extCap string) { diff --git a/core/livepeernode.go b/core/livepeernode.go index fdb501f0ca..b4a2bd1075 100644 --- a/core/livepeernode.go +++ b/core/livepeernode.go @@ -10,6 +10,7 @@ orchestrator.go: Code that is called only when the node is in orchestrator mode. 
package core import ( + "context" "errors" "math/big" "math/rand" @@ -19,6 +20,7 @@ import ( "time" "github.com/golang/glog" + "github.com/livepeer/go-livepeer/media" "github.com/livepeer/go-livepeer/pm" "github.com/livepeer/go-livepeer/trickle" @@ -183,6 +185,54 @@ type LivePipeline struct { ControlPub *trickle.TricklePublisher StopControl func() ReportUpdate func([]byte) + DataWriter *media.SegmentWriter + + StreamCtx context.Context + streamCancel context.CancelCauseFunc + streamParams interface{} + streamRequest []byte +} + +func (n *LivepeerNode) NewLivePipeline(requestID, streamID, pipeline string, streamParams interface{}, streamRequest []byte) *LivePipeline { + streamCtx, streamCancel := context.WithCancelCause(context.Background()) + n.LiveMu.Lock() + defer n.LiveMu.Unlock() + n.LivePipelines[streamID] = &LivePipeline{ + RequestID: requestID, + StreamID: streamID, + Pipeline: pipeline, + StreamCtx: streamCtx, + streamParams: streamParams, + streamCancel: streamCancel, + streamRequest: streamRequest, + } + return n.LivePipelines[streamID] +} + +func (n *LivepeerNode) RemoveLivePipeline(streamID string) { + n.LiveMu.Lock() + defer n.LiveMu.Unlock() + delete(n.LivePipelines, streamID) +} + +func (p *LivePipeline) StreamParams() interface{} { + return p.streamParams +} + +func (p *LivePipeline) UpdateStreamParams(newParams interface{}) { + p.streamParams = newParams +} + +func (p *LivePipeline) StreamRequest() []byte { + return p.streamRequest +} + +func (p *LivePipeline) StopStream(err error) { + if p.StopControl != nil { + p.StopControl() + } + + p.streamCancel(err) } // NewLivepeerNode creates a new Livepeer Node. Eth can be nil. diff --git a/server/ai_live_video.go b/server/ai_live_video.go index 12e1898c5d..e20529ae3a 100644 --- a/server/ai_live_video.go +++ b/server/ai_live_video.go @@ -1,6 +1,7 @@ package server import ( + "bufio" "bytes" "context" "encoding/json" @@ -81,7 +82,7 @@ func startTricklePublish(ctx context.Context, url *url.URL, params aiRequestPara ctx, cancel := context.WithCancel(ctx) priceInfo := sess.OrchestratorInfo.PriceInfo var paymentProcessor *LivePaymentProcessor - if priceInfo != nil && priceInfo.PricePerUnit != 0 { + if priceInfo != nil && priceInfo.PricePerUnit != 0 && sess.OrchestratorInfo.AuthToken != nil { paymentSender := livePaymentSender{} sendPaymentFunc := func(inPixels int64) error { return paymentSender.SendPayment(context.Background(), &SegmentInfoSender{ @@ -199,23 +200,26 @@ func suspendOrchestrator(ctx context.Context, params aiRequestParams) { // If the ingest was closed, then do not suspend the orchestrator return } - sel, err := params.sessManager.getSelector(ctx, core.Capability_LiveVideoToVideo, params.liveParams.pipeline) - if err != nil { - clog.Warningf(ctx, "Error suspending orchestrator: %v", err) - return - } - if sel == nil || sel.suspender == nil || params.liveParams == nil || params.liveParams.sess == nil || params.liveParams.sess.OrchestratorInfo == nil { - clog.Warningf(ctx, "Error suspending orchestrator: selector or suspender is nil") - return + //live-video-to-video + if params.sessManager != nil { + sel, err := params.sessManager.getSelector(ctx, core.Capability_LiveVideoToVideo, params.liveParams.pipeline) + if err != nil { + clog.Warningf(ctx, "Error suspending orchestrator: %v", err) + return + } + if sel == nil || sel.suspender == nil || params.liveParams == nil || params.liveParams.sess == nil || params.liveParams.sess.OrchestratorInfo == nil { + clog.Warningf(ctx, "Error suspending orchestrator: selector or 
suspender is nil") + return + } + // Remove the session from the current pool + sel.Remove(params.liveParams.sess) + sel.warmPool.mu.Lock() + sel.warmPool.selector.Remove(params.liveParams.sess.BroadcastSession) + sel.warmPool.mu.Unlock() + // We do selection every 6 min, so it effectively means the Orchestrator won't be selected for the next 30 min (unless there is no other O available) + clog.Infof(ctx, "Suspending orchestrator %s with penalty %d", params.liveParams.sess.Transcoder(), aiLiveVideoToVideoPenalty) + sel.suspender.suspend(params.liveParams.sess.Transcoder(), aiLiveVideoToVideoPenalty) } - // Remove the session from the current pool - sel.Remove(params.liveParams.sess) - sel.warmPool.mu.Lock() - sel.warmPool.selector.Remove(params.liveParams.sess.BroadcastSession) - sel.warmPool.mu.Unlock() - // We do selection every 6 min, so it effectively means the Orchestrator won't be selected for the next 30 min (unless there is no other O available) - clog.Infof(ctx, "Suspending orchestrator %s with penalty %d", params.liveParams.sess.Transcoder(), aiLiveVideoToVideoPenalty) - sel.suspender.suspend(params.liveParams.sess.Transcoder(), aiLiveVideoToVideoPenalty) } func startTrickleSubscribe(ctx context.Context, url *url.URL, params aiRequestParams, sess *AISession) { @@ -510,9 +514,10 @@ func registerControl(ctx context.Context, params aiRequestParams) { } params.node.LivePipelines[stream] = &core.LivePipeline{ - RequestID: params.liveParams.requestID, - Pipeline: params.liveParams.pipeline, - StreamID: params.liveParams.streamID, + RequestID: params.liveParams.requestID, + Pipeline: params.liveParams.pipeline, + DataWriter: params.liveParams.dataWriter, + StreamID: params.liveParams.streamID, } } @@ -784,6 +789,130 @@ func startEventsSubscribe(ctx context.Context, url *url.URL, params aiRequestPar }() } +func startDataSubscribe(ctx context.Context, url *url.URL, params aiRequestParams, sess *AISession) { + //only start DataSubscribe if enabled + if params.liveParams.dataWriter == nil { + return + } + + // subscribe to the outputs + subscriber, err := trickle.NewTrickleSubscriber(trickle.TrickleSubscriberConfig{ + URL: url.String(), + Ctx: ctx, + }) + if err != nil { + clog.Infof(ctx, "Failed to create data subscriber: %s", err) + return + } + + dataWriter := params.liveParams.dataWriter + + // read segments from trickle subscription + go func() { + defer dataWriter.Close() + + var err error + firstSegment := true + + retries := 0 + // we're trying to keep (retryPause x maxRetries) duration to fall within one output GOP length + const retryPause = 300 * time.Millisecond + const maxRetries = 5 + for { + select { + case <-ctx.Done(): + clog.Info(ctx, "data subscribe done") + return + default: + } + if !params.inputStreamExists() { + clog.Infof(ctx, "data subscribe stopping, input stream does not exist.") + break + } + var segment *http.Response + readBytes, readMessages := 0, 0 + clog.V(8).Infof(ctx, "data subscribe await") + segment, err = subscriber.Read() + if err != nil { + if errors.Is(err, trickle.EOS) || errors.Is(err, trickle.StreamNotFoundErr) { + stopProcessing(ctx, params, fmt.Errorf("data subscribe stopping, stream not found, err=%w", err)) + return + } + var sequenceNonexistent *trickle.SequenceNonexistent + if errors.As(err, &sequenceNonexistent) { + // stream exists but segment doesn't, so skip to leading edge + subscriber.SetSeq(sequenceNonexistent.Latest) + } + // TODO if not EOS then signal a new orchestrator is needed + err = fmt.Errorf("data subscribe error reading: 
%w", err)
+				clog.Infof(ctx, "%s", err)
+				if retries > maxRetries {
+					stopProcessing(ctx, params, errors.New("data subscribe stopping, retries exceeded"))
+					return
+				}
+				retries++
+				params.liveParams.sendErrorEvent(err)
+				time.Sleep(retryPause)
+				continue
+			}
+			retries = 0
+			seq := trickle.GetSeq(segment)
+			clog.V(8).Infof(ctx, "data subscribe received seq=%d", seq)
+			copyStartTime := time.Now()
+
+			scanner := bufio.NewScanner(segment.Body)
+			for scanner.Scan() {
+				writer, err := dataWriter.Next()
+				clog.V(8).Infof(ctx, "data subscribe writing seq=%d", seq)
+				if err != nil {
+					if err != io.EOF {
+						stopProcessing(ctx, params, fmt.Errorf("data subscribe could not get next: %w", err))
+					}
+					segment.Body.Close()
+					return
+				}
+				n, err := writer.Write(scanner.Bytes())
+				if err != nil {
+					stopProcessing(ctx, params, fmt.Errorf("data subscribe could not write: %w", err))
+					segment.Body.Close()
+					return
+				}
+				readBytes += n
+				readMessages += 1
+
+				writer.Close()
+			}
+			// close the segment body explicitly; a deferred close would not run until the
+			// goroutine exits and would accumulate open bodies across segments
+			segment.Body.Close()
+			if err := scanner.Err(); err != nil {
+				clog.InfofErr(ctx, "data subscribe error reading seq=%d", seq, err)
+				subscriber.SetSeq(seq)
+				retries++
+				continue
+			}
+
+			if firstSegment {
+				firstSegment = false
+				delayMs := time.Since(params.liveParams.startTime).Milliseconds()
+				if monitor.Enabled {
+					//monitor.AIFirstSegmentDelay(delayMs, params.liveParams.sess.OrchestratorInfo)
+					monitor.SendQueueEventAsync("stream_trace", map[string]interface{}{
+						"type":        "gateway_receive_first_data_segment",
+						"timestamp":   time.Now().UnixMilli(),
+						"stream_id":   params.liveParams.streamID,
+						"pipeline_id": params.liveParams.pipelineID,
+						"request_id":  params.liveParams.requestID,
+						"orchestrator_info": map[string]interface{}{
+							"address": sess.Address(),
+							"url":     sess.Transcoder(),
+						},
+					})
+				}
+
+				clog.V(common.VERBOSE).Infof(ctx, "First Data Segment delay=%dms streamID=%s", delayMs, params.liveParams.streamID)
+			}
+
+			clog.V(8).Info(ctx, "data subscribe read completed", "seq", seq, "bytes", humanize.Bytes(uint64(readBytes)), "messages", readMessages, "took", time.Since(copyStartTime))
+		}
+	}()
+}
+
 func (a aiRequestParams) inputStreamExists() bool {
 	if a.node == nil {
 		return false
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index f368f9feb8..0d1c12df1d 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -98,8 +98,9 @@ func startAIMediaServer(ctx context.Context, ls *LivepeerServer) error {
 
 	// Configure WHIP ingest only if an addr is specified.
// TODO use a proper cli flag + var whipServer *media.WHIPServer if os.Getenv("LIVE_AI_WHIP_ADDR") != "" { - whipServer := media.NewWHIPServer() + whipServer = media.NewWHIPServer() ls.HTTPMux.Handle("POST /live/video-to-video/{stream}/whip", ls.CreateWhip(whipServer)) ls.HTTPMux.Handle("HEAD /live/video-to-video/{stream}/whip", ls.WithCode(http.StatusMethodNotAllowed)) ls.HTTPMux.Handle("OPTIONS /live/video-to-video/{stream}/whip", ls.WithCode(http.StatusNoContent)) @@ -112,6 +113,17 @@ func startAIMediaServer(ctx context.Context, ls *LivepeerServer) error { //API for dynamic capabilities ls.HTTPMux.Handle("/process/request/", ls.SubmitJob()) + ls.HTTPMux.Handle("OPTIONS /ai/stream/", ls.WithCode(http.StatusNoContent)) + ls.HTTPMux.Handle("POST /ai/stream/start", ls.StartStream()) + ls.HTTPMux.Handle("POST /ai/stream/{streamId}/stop", ls.StopStream()) + if os.Getenv("LIVE_AI_WHIP_ADDR") != "" { + ls.HTTPMux.Handle("POST /ai/stream/{streamId}/whip", ls.StartStreamWhipIngest(whipServer)) + } + ls.HTTPMux.Handle("POST /ai/stream/{streamId}/rtmp", ls.StartStreamRTMPIngest()) + ls.HTTPMux.Handle("POST /ai/stream/{streamId}/update", ls.UpdateStream()) + ls.HTTPMux.Handle("GET /ai/stream/{streamId}/status", ls.GetStreamStatus()) + ls.HTTPMux.Handle("GET /ai/stream/{streamId}/data", ls.GetStreamData()) + media.StartFileCleanup(ctx, ls.LivepeerNode.WorkDir) startHearbeats(ctx, ls.LivepeerNode) diff --git a/server/ai_process.go b/server/ai_process.go index cc50b380dd..464a6e96a3 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -96,6 +96,7 @@ type aiRequestParams struct { // For live video pipelines type liveRequestParams struct { segmentReader *media.SwitchableSegmentReader + dataWriter *media.SegmentWriter stream string requestID string streamID string @@ -131,6 +132,12 @@ type liveRequestParams struct { // when the write for the last segment started lastSegmentTime time.Time + + orchPublishUrl string + orchSubscribeUrl string + orchControlUrl string + orchEventsUrl string + orchDataUrl string } // CalculateTextToImageLatencyScore computes the time taken per pixel for an text-to-image request. 
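For reference, a minimal standalone sketch of the per-segment fan-out performed by startDataSubscribe above: each trickle segment body is treated as newline-delimited messages, and every message is written out individually while byte and message counts are tracked. The fanOutSegment and writeNext names are illustrative stand-ins for the dataWriter.Next()/Write flow; note that bufio.Scanner's default 64KiB token limit applies per message.

package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

// fanOutSegment splits body into newline-delimited messages and hands each one
// to writeNext, returning the total bytes and message count written.
func fanOutSegment(body io.Reader, writeNext func([]byte) (int, error)) (int, int, error) {
	scanner := bufio.NewScanner(body)
	bytesWritten, messages := 0, 0
	for scanner.Scan() {
		n, err := writeNext(scanner.Bytes())
		if err != nil {
			return bytesWritten, messages, fmt.Errorf("could not write message: %w", err)
		}
		bytesWritten += n
		messages++
	}
	return bytesWritten, messages, scanner.Err()
}

func main() {
	segment := strings.NewReader("{\"event\":\"start\"}\n{\"event\":\"stop\"}\n")
	b, m, err := fanOutSegment(segment, func(msg []byte) (int, error) {
		fmt.Printf("message: %s\n", msg)
		return len(msg), nil
	})
	fmt.Println(b, m, err)
}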
diff --git a/server/job_rpc.go b/server/job_rpc.go index 88d27b33e7..eeb119cc1c 100644 --- a/server/job_rpc.go +++ b/server/job_rpc.go @@ -44,20 +44,12 @@ const jobOrchSearchTimeoutDefault = 1 * time.Second const jobOrchSearchRespTimeoutDefault = 500 * time.Millisecond var errNoTimeoutSet = errors.New("no timeout_seconds set with request, timeout_seconds is required") -var sendJobReqWithTimeout = sendReqWithTimeout - -type JobSender struct { - Addr string `json:"addr"` - Sig string `json:"sig"` -} +var errNoCapabilityCapacity = errors.New("No capacity available for capability") +var errNoJobCreds = errors.New("Could not verify job creds") +var errPaymentError = errors.New("Could not parse payment") +var errInsufficientBalance = errors.New("Insufficient balance for request") -type JobToken struct { - SenderAddress *JobSender `json:"sender_address,omitempty"` - TicketParams *net.TicketParams `json:"ticket_params,omitempty"` - Balance int64 `json:"balance,omitempty"` - Price *net.PriceInfo `json:"price,omitempty"` - ServiceAddr string `json:"service_addr,omitempty"` -} +var sendJobReqWithTimeout = sendReqWithTimeout type JobRequest struct { ID string `json:"id"` @@ -69,19 +61,63 @@ type JobRequest struct { Sig string `json:"sig"` Timeout int `json:"timeout_seconds"` - orchSearchTimeout time.Duration - orchSearchRespTimeout time.Duration + OrchSearchTimeout time.Duration + OrchSearchRespTimeout time.Duration +} +type JobRequestDetails struct { + StreamId string `json:"stream_id"` } - type JobParameters struct { + //Gateway Orchestrators JobOrchestratorsFilter `json:"orchestrators,omitempty"` //list of orchestrators to use for the job -} + //Orchestrator + EnableVideoIngress bool `json:"enable_video_ingress,omitempty"` + EnableVideoEgress bool `json:"enable_video_egress,omitempty"` + EnableDataOutput bool `json:"enable_data_output,omitempty"` +} type JobOrchestratorsFilter struct { Exclude []string `json:"exclude,omitempty"` Include []string `json:"include,omitempty"` } +type orchJob struct { + Req *JobRequest + Details *JobRequestDetails + Params *JobParameters + + //Orchestrator fields + Sender ethcommon.Address + JobPrice *net.PriceInfo +} +type gatewayJob struct { + Job *orchJob + Orchs []core.JobToken + SignedJobReq string + + node *core.LivepeerNode +} + +func (g *gatewayJob) sign() error { + //sign the request + gateway := g.node.OrchestratorPool.Broadcaster() + sig, err := gateway.Sign([]byte(g.Job.Req.Request + g.Job.Req.Parameters)) + if err != nil { + return errors.New(fmt.Sprintf("Unable to sign request err=%v", err)) + } + g.Job.Req.Sender = gateway.Address().Hex() + g.Job.Req.Sig = "0x" + hex.EncodeToString(sig) + + //create the job request header with the signature + jobReqEncoded, err := json.Marshal(g.Job.Req) + if err != nil { + return errors.New(fmt.Sprintf("Unable to encode job request err=%v", err)) + } + g.SignedJobReq = base64.StdEncoding.EncodeToString(jobReqEncoded) + + return nil +} + // worker registers to Orchestrator func (h *lphttp) RegisterCapability(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -193,7 +229,7 @@ func (h *lphttp) GetJobToken(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - jobToken := JobToken{SenderAddress: nil, TicketParams: nil, Balance: 0, Price: nil} + jobToken := core.JobToken{SenderAddress: nil, TicketParams: nil, Balance: 0, Price: nil} if !orch.CheckExternalCapabilityCapacity(jobCapsHdr) { //send response indicating no capacity available @@ -238,7 +274,7 @@ 
func (h *lphttp) GetJobToken(w http.ResponseWriter, r *http.Request) { capBalInt = capBalInt / 1000 } - jobToken = JobToken{ + jobToken = core.JobToken{ SenderAddress: jobSenderAddr, TicketParams: ticketParams, Balance: capBalInt, @@ -253,6 +289,53 @@ func (h *lphttp) GetJobToken(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(jobToken) } +func (ls *LivepeerServer) setupGatewayJob(ctx context.Context, r *http.Request, skipOrchSearch bool) (*gatewayJob, error) { + + var orchs []core.JobToken + + jobReqHdr := r.Header.Get(jobRequestHdr) + clog.Infof(ctx, "processing job request req=%v", jobReqHdr) + jobReq, err := verifyJobCreds(ctx, nil, jobReqHdr, true) + if err != nil { + return nil, errors.New(fmt.Sprintf("Unable to parse job request, err=%v", err)) + } + + var jobDetails JobRequestDetails + if err := json.Unmarshal([]byte(jobReq.Request), &jobDetails); err != nil { + return nil, errors.New(fmt.Sprintf("Unable to unmarshal job request err=%v", err)) + } + + var jobParams JobParameters + if err := json.Unmarshal([]byte(jobReq.Parameters), &jobParams); err != nil { + return nil, errors.New(fmt.Sprintf("Unable to unmarshal job parameters err=%v", err)) + } + + // get list of Orchestrators that can do the job if needed + // (e.g. stop requests don't need new list of orchestrators) + if !skipOrchSearch { + searchTimeout, respTimeout := getOrchSearchTimeouts(ctx, r.Header.Get(jobOrchSearchTimeoutHdr), r.Header.Get(jobOrchSearchRespTimeoutHdr)) + jobReq.OrchSearchTimeout = searchTimeout + jobReq.OrchSearchRespTimeout = respTimeout + + //get pool of Orchestrators that can do the job + orchs, err = getJobOrchestrators(ctx, ls.LivepeerNode, jobReq.Capability, jobParams, jobReq.OrchSearchTimeout, jobReq.OrchSearchRespTimeout) + if err != nil { + return nil, errors.New(fmt.Sprintf("Unable to find orchestrators for capability %v err=%v", jobReq.Capability, err)) + } + + if len(orchs) == 0 { + return nil, errors.New(fmt.Sprintf("No orchestrators found for capability %v", jobReq.Capability)) + } + } + + job := orchJob{Req: jobReq, + Details: &jobDetails, + Params: &jobParams, + } + + return &gatewayJob{Job: &job, Orchs: orchs, node: ls.LivepeerNode}, nil +} + func (h *lphttp) ProcessJob(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -280,42 +363,22 @@ func (ls *LivepeerServer) SubmitJob() http.Handler { } func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, r *http.Request) { - jobReqHdr := r.Header.Get(jobRequestHdr) - jobReq, err := verifyJobCreds(ctx, nil, jobReqHdr) + + gatewayJob, err := ls.setupGatewayJob(ctx, r, false) if err != nil { - clog.Errorf(ctx, "Unable to verify job creds err=%v", err) - http.Error(w, fmt.Sprintf("Unable to parse job request, err=%v", err), http.StatusBadRequest) + clog.Errorf(ctx, "Error setting up job: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) return } - ctx = clog.AddVal(ctx, "job_id", jobReq.ID) - ctx = clog.AddVal(ctx, "capability", jobReq.Capability) - clog.Infof(ctx, "processing job request") - searchTimeout, respTimeout := getOrchSearchTimeouts(ctx, r.Header.Get(jobOrchSearchTimeoutHdr), r.Header.Get(jobOrchSearchRespTimeoutHdr)) - jobReq.orchSearchTimeout = searchTimeout - jobReq.orchSearchRespTimeout = respTimeout - - var params JobParameters - if err := json.Unmarshal([]byte(jobReq.Parameters), ¶ms); err != nil { - clog.Errorf(ctx, "Unable to unmarshal job parameters err=%v", err) - http.Error(w, fmt.Sprintf("Unable to unmarshal job parameters err=%v", err), 
http.StatusBadRequest) - return - } + clog.Infof(ctx, "Job request setup complete details=%v params=%v", gatewayJob.Job.Details, gatewayJob.Job.Params) - //get pool of Orchestrators that can do the job - orchs, err := getJobOrchestrators(ctx, ls.LivepeerNode, jobReq.Capability, params, jobReq.orchSearchTimeout, jobReq.orchSearchRespTimeout) if err != nil { - clog.Errorf(ctx, "Unable to find orchestrators for capability %v err=%v", jobReq.Capability, err) - http.Error(w, fmt.Sprintf("Unable to find orchestrators for capability %v err=%v", jobReq.Capability, err), http.StatusBadRequest) - return - } - - if len(orchs) == 0 { - clog.Errorf(ctx, "No orchestrators found for capability %v", jobReq.Capability) - http.Error(w, fmt.Sprintf("No orchestrators found for capability %v", jobReq.Capability), http.StatusServiceUnavailable) + http.Error(w, fmt.Sprintf("Unable to setup job err=%v", err), http.StatusBadRequest) return } - + ctx = clog.AddVal(ctx, "job_id", gatewayJob.Job.Req.ID) + ctx = clog.AddVal(ctx, "capability", gatewayJob.Job.Req.Capability) // Read the original request body body, err := io.ReadAll(r.Body) if err != nil { @@ -323,29 +386,10 @@ func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, return } r.Body.Close() - //sign the request - gateway := ls.LivepeerNode.OrchestratorPool.Broadcaster() - sig, err := gateway.Sign([]byte(jobReq.Request + jobReq.Parameters)) - if err != nil { - clog.Errorf(ctx, "Unable to sign request err=%v", err) - http.Error(w, fmt.Sprintf("Unable to sign request err=%v", err), http.StatusInternalServerError) - return - } - jobReq.Sender = gateway.Address().Hex() - jobReq.Sig = "0x" + hex.EncodeToString(sig) - - //create the job request header with the signature - jobReqEncoded, err := json.Marshal(jobReq) - if err != nil { - clog.Errorf(ctx, "Unable to encode job request err=%v", err) - http.Error(w, fmt.Sprintf("Unable to encode job request err=%v", err), http.StatusInternalServerError) - return - } - jobReqHdr = base64.StdEncoding.EncodeToString(jobReqEncoded) //send the request to the Orchestrator(s) //the loop ends on Gateway error and bad request errors - for _, orchToken := range orchs { + for _, orchToken := range gatewayJob.Orchs { // Extract the worker resource route from the URL path // The prefix is "/process/request/" @@ -360,35 +404,21 @@ func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, workerRoute = workerRoute + "/" + workerResourceRoute } - req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, bytes.NewBuffer(body)) + err := gatewayJob.sign() if err != nil { - clog.Errorf(ctx, "Unable to create request err=%v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + clog.Errorf(ctx, "Error signing job, exiting stream processing request: %v", err) return } - // set the headers - req.Header.Add("Content-Length", r.Header.Get("Content-Length")) - req.Header.Add("Content-Type", r.Header.Get("Content-Type")) - - req.Header.Add(jobRequestHdr, jobReqHdr) - if orchToken.Price.PricePerUnit > 0 { - paymentHdr, err := createPayment(ctx, jobReq, orchToken, ls.LivepeerNode) - if err != nil { - clog.Errorf(ctx, "Unable to create payment err=%v", err) - http.Error(w, fmt.Sprintf("Unable to create payment err=%v", err), http.StatusBadRequest) - return - } - req.Header.Add(jobPaymentHeaderHdr, paymentHdr) - } start := time.Now() - resp, err := sendJobReqWithTimeout(req, time.Duration(jobReq.Timeout+5)*time.Second) //include 5 second buffer + resp, code, err := 
ls.sendJobToOrch(ctx, r, gatewayJob.Job.Req, gatewayJob.SignedJobReq, orchToken, workerResourceRoute, body) if err != nil { clog.Errorf(ctx, "job not able to be processed by Orchestrator %v err=%v ", orchToken.ServiceAddr, err.Error()) continue } + //error response from Orchestrator - if resp.StatusCode > 399 { + if code > 399 { defer resp.Body.Close() data, err := io.ReadAll(resp.Body) if err != nil { @@ -398,10 +428,10 @@ func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, } clog.Errorf(ctx, "error processing request err=%v ", string(data)) //nonretryable error - if resp.StatusCode < 500 { + if code < 500 { //assume non retryable bad request //return error response from the worker - http.Error(w, string(data), resp.StatusCode) + http.Error(w, string(data), code) return } //retryable error, continue to next orchestrator @@ -427,7 +457,7 @@ func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, continue } - gatewayBalance := updateGatewayBalance(ls.LivepeerNode, orchToken, jobReq.Capability, time.Since(start)) + gatewayBalance := updateGatewayBalance(ls.LivepeerNode, orchToken, gatewayJob.Job.Req.Capability, time.Since(start)) clog.V(common.SHORT).Infof(ctx, "Job processed successfully took=%v balance=%v balance_from_orch=%v", time.Since(start), gatewayBalance.FloatString(0), orchBalance) w.Write(data) return @@ -450,7 +480,7 @@ func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, w.WriteHeader(http.StatusOK) // Read from upstream and forward to client respChan := make(chan string, 100) - respCtx, _ := context.WithTimeout(ctx, time.Duration(jobReq.Timeout+10)*time.Second) //include a small buffer to let Orchestrator close the connection on the timeout + respCtx, _ := context.WithTimeout(ctx, time.Duration(gatewayJob.Job.Req.Timeout+10)*time.Second) //include a small buffer to let Orchestrator close the connection on the timeout go func() { defer resp.Body.Close() @@ -491,12 +521,70 @@ func (ls *LivepeerServer) submitJob(ctx context.Context, w http.ResponseWriter, } } - gatewayBalance := updateGatewayBalance(ls.LivepeerNode, orchToken, jobReq.Capability, time.Since(start)) + gatewayBalance := updateGatewayBalance(ls.LivepeerNode, orchToken, gatewayJob.Job.Req.Capability, time.Since(start)) clog.V(common.SHORT).Infof(ctx, "Job processed successfully took=%v balance=%v balance_from_orch=%v", time.Since(start), gatewayBalance.FloatString(0), orchBalance.FloatString(0)) } + } +} + +func (ls *LivepeerServer) sendJobToOrch(ctx context.Context, r *http.Request, jobReq *JobRequest, signedReqHdr string, orchToken core.JobToken, route string, body []byte) (*http.Response, int, error) { + orchUrl := orchToken.ServiceAddr + route + req, err := http.NewRequestWithContext(ctx, "POST", orchUrl, bytes.NewBuffer(body)) + if err != nil { + clog.Errorf(ctx, "Unable to create request err=%v", err) + return nil, http.StatusInternalServerError, err + } + + // set the headers + if r != nil { + req.Header.Add("Content-Length", r.Header.Get("Content-Length")) + req.Header.Add("Content-Type", r.Header.Get("Content-Type")) + } else { + //this is for live requests which will be json to start stream + // update requests should include the content type/length + req.Header.Add("Content-Type", "application/json") + } + + req.Header.Add(jobRequestHdr, signedReqHdr) + if orchToken.Price.PricePerUnit > 0 { + paymentHdr, err := createPayment(ctx, jobReq, &orchToken, ls.LivepeerNode) + if err != nil { + clog.Errorf(ctx, "Unable to create payment 
err=%v", err) + return nil, http.StatusInternalServerError, fmt.Errorf("Unable to create payment err=%v", err) + } + if paymentHdr != "" { + req.Header.Add(jobPaymentHeaderHdr, paymentHdr) + } + } + resp, err := sendJobReqWithTimeout(req, time.Duration(jobReq.Timeout+5)*time.Second) //include 5 second buffer + if err != nil { + clog.Errorf(ctx, "job not able to be processed by Orchestrator %v err=%v ", orchToken.ServiceAddr, err.Error()) + return nil, http.StatusBadRequest, err } + + return resp, resp.StatusCode, nil +} + +func (ls *LivepeerServer) sendPayment(ctx context.Context, orchPmtUrl, capability, jobReq, payment string) (int, error) { + req, err := http.NewRequestWithContext(ctx, "POST", orchPmtUrl, nil) + if err != nil { + clog.Errorf(ctx, "Unable to create request err=%v", err) + return http.StatusBadRequest, err + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add(jobRequestHdr, jobReq) + req.Header.Add(jobPaymentHeaderHdr, payment) + + resp, err := sendJobReqWithTimeout(req, 10*time.Second) + if err != nil { + clog.Errorf(ctx, "job payment not able to be processed by Orchestrator %v err=%v ", orchPmtUrl, err.Error()) + return http.StatusBadRequest, err + } + + return resp.StatusCode, nil } func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.Request) { @@ -505,77 +593,20 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R orch := h.orchestrator // check the prompt sig from the request // confirms capacity available before processing payment info - job := r.Header.Get(jobRequestHdr) - jobReq, err := verifyJobCreds(ctx, orch, job) + orchJob, err := h.setupOrchJob(ctx, r, true) if err != nil { - if err == errZeroCapacity { - clog.Errorf(ctx, "No capacity available for capability err=%q", err) + if err == errNoCapabilityCapacity { http.Error(w, err.Error(), http.StatusServiceUnavailable) - } else if err == errNoTimeoutSet { - clog.Errorf(ctx, "Timeout not set in request err=%q", err) - http.Error(w, err.Error(), http.StatusBadRequest) } else { - clog.Errorf(ctx, "Could not verify job creds err=%q", err) - http.Error(w, err.Error(), http.StatusForbidden) + http.Error(w, err.Error(), http.StatusBadRequest) } - return } - - sender := ethcommon.HexToAddress(jobReq.Sender) - jobPrice, err := orch.JobPriceInfo(sender, jobReq.Capability) - if err != nil { - clog.Errorf(ctx, "could not get price err=%v", err.Error()) - http.Error(w, fmt.Sprintf("Could not get price err=%v", err.Error()), http.StatusBadRequest) - return - } - clog.V(common.DEBUG).Infof(ctx, "job price=%v units=%v", jobPrice.PricePerUnit, jobPrice.PixelsPerUnit) taskId := core.RandomManifestID() - jobId := jobReq.Capability - ctx = clog.AddVal(ctx, "job_id", jobReq.ID) + ctx = clog.AddVal(ctx, "job_id", orchJob.Req.ID) ctx = clog.AddVal(ctx, "worker_task_id", string(taskId)) - ctx = clog.AddVal(ctx, "capability", jobReq.Capability) - ctx = clog.AddVal(ctx, "sender", jobReq.Sender) - - //no payment included, confirm if balance remains - jobPriceRat := big.NewRat(jobPrice.PricePerUnit, jobPrice.PixelsPerUnit) - var payment net.Payment - // if price is 0, no payment required - if jobPriceRat.Cmp(big.NewRat(0, 1)) > 0 { - // get payment information - payment, err = getPayment(r.Header.Get(jobPaymentHeaderHdr)) - if err != nil { - clog.Errorf(r.Context(), "Could not parse payment: %v", err) - http.Error(w, err.Error(), http.StatusPaymentRequired) - return - } - - if payment.TicketParams == nil { - - //if price is not 0, confirm balance - if 
jobPriceRat.Cmp(big.NewRat(0, 1)) > 0 { - minBal := jobPriceRat.Mul(jobPriceRat, big.NewRat(60, 1)) //minimum 1 minute balance - orchBal := getPaymentBalance(orch, sender, jobId) - - if orchBal.Cmp(minBal) < 0 { - clog.Errorf(ctx, "Insufficient balance for request") - http.Error(w, "Insufficient balance", http.StatusPaymentRequired) - orch.FreeExternalCapabilityCapacity(jobReq.Capability) - return - } - } - } else { - if err := orch.ProcessPayment(ctx, payment, core.ManifestID(jobId)); err != nil { - clog.Errorf(ctx, "error processing payment err=%q", err) - http.Error(w, err.Error(), http.StatusBadRequest) - orch.FreeExternalCapabilityCapacity(jobReq.Capability) - return - } - } - - clog.Infof(ctx, "balance after payment is %v", getPaymentBalance(orch, sender, jobId).FloatString(0)) - } - + ctx = clog.AddVal(ctx, "capability", orchJob.Req.Capability) + ctx = clog.AddVal(ctx, "sender", orchJob.Req.Sender) clog.V(common.SHORT).Infof(ctx, "Received job, sending for processing") // Read the original body @@ -595,7 +626,7 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R workerResourceRoute = workerResourceRoute[len(prefix):] } - workerRoute := jobReq.CapabilityUrl + workerRoute := orchJob.Req.CapabilityUrl if workerResourceRoute != "" { workerRoute = workerRoute + "/" + workerResourceRoute } @@ -610,18 +641,18 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R req.Header.Add("Content-Type", r.Header.Get("Content-Type")) start := time.Now() - resp, err := sendReqWithTimeout(req, time.Duration(jobReq.Timeout)*time.Second) + resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) if err != nil { clog.Errorf(ctx, "job not able to be processed err=%v ", err.Error()) //if the request failed with connection error, remove the capability //exclude deadline exceeded or context canceled errors does not indicate a fatal error all the time if err != context.DeadlineExceeded && !strings.Contains(err.Error(), "context canceled") { - clog.Errorf(ctx, "removing capability %v due to error %v", jobReq.Capability, err.Error()) - h.orchestrator.RemoveExternalCapability(jobReq.Capability) + clog.Errorf(ctx, "removing capability %v due to error %v", orchJob.Req.Capability, err.Error()) + h.orchestrator.RemoveExternalCapability(orchJob.Req.Capability) } - chargeForCompute(start, jobPrice, orch, sender, jobId) - w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, sender, jobId).FloatString(0)) + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) http.Error(w, fmt.Sprintf("job not able to be processed, removing capability err=%v", err.Error()), http.StatusInternalServerError) return } @@ -631,7 +662,7 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R //release capacity for another request // if requester closes the connection need to release capacity - defer orch.FreeExternalCapabilityCapacity(jobReq.Capability) + defer orch.FreeExternalCapabilityCapacity(orchJob.Req.Capability) if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { //non streaming response @@ -641,8 +672,8 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R if err != nil { clog.Errorf(ctx, "Unable to read response err=%v", err) - chargeForCompute(start, jobPrice, orch, sender, jobId) - 
w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, sender, jobId).FloatString(0)) + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -651,16 +682,16 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R if resp.StatusCode > 399 { clog.Errorf(ctx, "error processing request err=%v ", string(data)) - chargeForCompute(start, jobPrice, orch, sender, jobId) - w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, sender, jobId).FloatString(0)) + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) //return error response from the worker http.Error(w, string(data), resp.StatusCode) return } - chargeForCompute(start, jobPrice, orch, sender, jobId) - w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, sender, jobId).FloatString(0)) - clog.V(common.SHORT).Infof(ctx, "Job processed successfully took=%v balance=%v", time.Since(start), getPaymentBalance(orch, sender, jobId).FloatString(0)) + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) + clog.V(common.SHORT).Infof(ctx, "Job processed successfully took=%v balance=%v", time.Since(start), getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) w.Write(data) //request completed and returned a response @@ -673,22 +704,22 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") //send payment balance back so client can determine if payment is needed - addPaymentBalanceHeader(w, orch, sender, jobId) + addPaymentBalanceHeader(w, orch, orchJob.Sender, orchJob.Req.Capability) // Flush to ensure data is sent immediately flusher, ok := w.(http.Flusher) if !ok { clog.Errorf(ctx, "streaming not supported") - chargeForCompute(start, jobPrice, orch, sender, jobId) - w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, sender, jobId).FloatString(0)) + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) http.Error(w, "Streaming not supported", http.StatusInternalServerError) return } // Read from upstream and forward to client respChan := make(chan string, 100) - respCtx, _ := context.WithTimeout(ctx, time.Duration(jobReq.Timeout)*time.Second) + respCtx, _ := context.WithTimeout(ctx, time.Duration(orchJob.Req.Timeout)*time.Second) go func() { defer resp.Body.Close() @@ -697,7 +728,7 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R for scanner.Scan() { select { case <-respCtx.Done(): - orchBal := orch.Balance(sender, core.ManifestID(jobId)) + orchBal := orch.Balance(orchJob.Sender, core.ManifestID(orchJob.Req.Capability)) if orchBal == nil { orchBal = big.NewRat(0, 1) } @@ -707,7 +738,7 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R default: line := scanner.Text() if strings.Contains(line, "[DONE]") { - orchBal := 
orch.Balance(sender, core.ManifestID(jobId))
+					orchBal := orch.Balance(orchJob.Sender, core.ManifestID(orchJob.Req.Capability))
 					if orchBal == nil {
 						orchBal = big.NewRat(0, 1)
 					}
@@ -729,9 +760,10 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R
 			case <-pmtWatcher.C:
 				//check balance and end response if out of funds
 				//skips if price is 0
+				jobPriceRat := big.NewRat(orchJob.JobPrice.PricePerUnit, orchJob.JobPrice.PixelsPerUnit)
 				if jobPriceRat.Cmp(big.NewRat(0, 1)) > 0 {
-					h.orchestrator.DebitFees(sender, core.ManifestID(jobId), jobPrice, 5)
-					senderBalance := getPaymentBalance(orch, sender, jobId)
+					h.orchestrator.DebitFees(orchJob.Sender, core.ManifestID(orchJob.Req.Capability), orchJob.JobPrice, 5)
+					senderBalance := getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability)
 					if senderBalance != nil {
 						if senderBalance.Cmp(big.NewRat(0, 1)) < 0 {
 							w.Write([]byte("event: insufficient balance\n"))
@@ -751,35 +783,153 @@ func processJob(ctx context.Context, h *lphttp, w http.ResponseWriter, r *http.R
 	}
 
 	//capacity released with defer stmt above
-	clog.V(common.SHORT).Infof(ctx, "Job processed successfully took=%v balance=%v", time.Since(start), getPaymentBalance(orch, sender, jobId).FloatString(0))
+	clog.V(common.SHORT).Infof(ctx, "Job processed successfully took=%v balance=%v", time.Since(start), getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0))
+	}
+}
+
+// setupOrchJob prepares the orchestrator job by extracting and validating the job request from the HTTP headers.
+// Payment is applied if applicable.
+func (h *lphttp) setupOrchJob(ctx context.Context, r *http.Request, reserveCapacity bool) (*orchJob, error) {
+	job := r.Header.Get(jobRequestHdr)
+	orch := h.orchestrator
+	jobReq, err := verifyJobCreds(ctx, orch, job, reserveCapacity)
+	if err != nil {
+		if err == errZeroCapacity && reserveCapacity {
+			return nil, errNoCapabilityCapacity
+		} else if err == errNoTimeoutSet {
+			return nil, errNoTimeoutSet
+		} else {
+			clog.Errorf(ctx, "job failed verification: %v", err)
+			return nil, errNoJobCreds
+		}
+	}
+
+	sender := ethcommon.HexToAddress(jobReq.Sender)
+
+	jobPrice, err := orch.JobPriceInfo(sender, jobReq.Capability)
+	if err != nil {
+		return nil, errors.New("Could not get job price")
+	}
+	clog.V(common.DEBUG).Infof(ctx, "job price=%v units=%v", jobPrice.PricePerUnit, jobPrice.PixelsPerUnit)
+
+	//no payment included, confirm if balance remains
+	jobPriceRat := big.NewRat(jobPrice.PricePerUnit, jobPrice.PixelsPerUnit)
+	orchBal := big.NewRat(0, 1)
+	// if price is 0, no payment required
+	if jobPriceRat.Cmp(big.NewRat(0, 1)) > 0 {
+		minBal := new(big.Rat).Mul(jobPriceRat, big.NewRat(60, 1)) //minimum 1 minute balance
+		//process payment if included
+		bal, pmtErr := processPayment(ctx, orch, sender, jobReq.Capability, r.Header.Get(jobPaymentHeaderHdr))
+		if pmtErr != nil {
+			//log if there are payment errors but continue, balance will run out and clean up
+			clog.Infof(ctx, "job payment error: %v", pmtErr)
+			bal = getPaymentBalance(orch, sender, jobReq.Capability)
+		}
+		if bal != nil {
+			orchBal = bal
+		}
+
+		if orchBal.Cmp(minBal) < 0 {
+			orch.FreeExternalCapabilityCapacity(jobReq.Capability)
+			return nil, errInsufficientBalance
+		}
+	}
+
+	var jobDetails JobRequestDetails
+	err = json.Unmarshal([]byte(jobReq.Request), &jobDetails)
+	if err != nil {
+		return nil, fmt.Errorf("Unable to unmarshal job request details err=%v", err)
+	}
+
+	clog.Infof(ctx, "job request verified id=%v sender=%v capability=%v timeout=%v price=%v balance=%v", jobReq.ID, jobReq.Sender, jobReq.Capability, jobReq.Timeout, jobPriceRat.FloatString(0), orchBal.FloatString(0))
+
+	return &orchJob{Req: jobReq, Sender: sender, JobPrice: jobPrice, Details: &jobDetails}, nil
 }
 
-func createPayment(ctx context.Context, jobReq *JobRequest, orchToken JobToken, node *core.LivepeerNode) (string, error) {
+// process payment and return balance
+func processPayment(ctx context.Context, orch Orchestrator, sender ethcommon.Address, capability string, paymentHdr string) (*big.Rat, error) {
+	if paymentHdr != "" {
+		payment, err := getPayment(paymentHdr)
+		if err != nil {
+			clog.Errorf(ctx, "job payment invalid: %v", err)
+			return nil, errPaymentError
+		}
+
+		if err := orch.ProcessPayment(ctx, payment, core.ManifestID(capability)); err != nil {
+			orch.FreeExternalCapabilityCapacity(capability)
+			clog.Errorf(ctx, "Error processing payment: %v", err)
+			return nil, errPaymentError
+		}
+	}
+	orchBal := getPaymentBalance(orch, sender, capability)
+
+	return orchBal, nil
+
+}
+
+func createPayment(ctx context.Context, jobReq *JobRequest, orchToken *core.JobToken, node *core.LivepeerNode) (string, error) {
+	if orchToken == nil {
+		return "", errors.New("orchestrator token is nil, cannot create payment")
+	}
+	//if no sender or ticket params, no payment
+	if node.Sender == nil {
+		return "", errors.New("no ticket sender available, cannot create payment")
+	}
+	if orchToken.TicketParams == nil {
+		return "", errors.New("no ticket params available, cannot create payment")
+	}
+	var payment *net.Payment
+	createTickets := true
+	clog.Infof(ctx, "creating payment for job request %s", jobReq.Capability)
 	sender := ethcommon.HexToAddress(jobReq.Sender)
+	orchAddr := ethcommon.BytesToAddress(orchToken.TicketParams.Recipient)
-	balance := node.Balances.Balance(orchAddr, core.ManifestID(jobReq.Capability))
 	sessionID := node.Sender.StartSession(*pmTicketParams(orchToken.TicketParams))
-	createTickets := true
+
+	//setup balances and update Gateway balance to Orchestrator balance, log differences
+	//Orchestrator tracks the balance paid and will not perform work if the balance it
+	//has is not sufficient
+	orchBal := big.NewRat(orchToken.Balance, 1)
+	balance := node.Balances.Balance(orchAddr, core.ManifestID(jobReq.Capability))
 	if balance == nil {
 		//create a balance of 0
 		node.Balances.Debit(orchAddr, core.ManifestID(jobReq.Capability), big.NewRat(0, 1))
 		balance = node.Balances.Balance(orchAddr, core.ManifestID(jobReq.Capability))
+	}
+
+	diff := new(big.Rat).Sub(orchBal, balance)
+	if balance.Cmp(orchBal) != 0 {
+		clog.Infof(ctx, "Adjusting gateway balance to Orchestrator provided balance for sender=%v capability=%v balance=%v orchBal=%v diff=%v", sender.Hex(), jobReq.Capability, balance.FloatString(0), orchBal.FloatString(0), diff.FloatString(0))
+	}
+
+	if diff.Sign() > 0 {
+		node.Balances.Credit(orchAddr, core.ManifestID(jobReq.Capability), diff)
 	} else {
-		price := big.NewRat(orchToken.Price.PricePerUnit, orchToken.Price.PixelsPerUnit)
-		cost := price.Mul(price, big.NewRat(int64(jobReq.Timeout), 1))
-		if balance.Cmp(cost) > 0 {
-			createTickets = false
-			payment = &net.Payment{
-				Sender:        sender.Bytes(),
-				ExpectedPrice: orchToken.Price,
-			}
+		node.Balances.Debit(orchAddr, core.ManifestID(jobReq.Capability), new(big.Rat).Abs(diff))
+	}
+
+	price := big.NewRat(orchToken.Price.PricePerUnit, orchToken.Price.PixelsPerUnit)
+	cost := new(big.Rat).Mul(price, big.NewRat(int64(jobReq.Timeout), 1))
+	minBal := new(big.Rat).Mul(price, big.NewRat(60, 1)) //minimum 1 minute balance
+	if cost.Cmp(minBal) < 0 {
+		cost = minBal
+	}
+
+	if balance.Sign() > 0
&& orchToken.Balance == 0 { + clog.Infof(ctx, "Updating balance to 0 because orchestrator balance reset for sender=%v capability=%v balance=%v", sender.Hex(), jobReq.Capability, balance.FloatString(0)) + node.Balances.Debit(orchAddr, core.ManifestID(jobReq.Capability), balance) + balance = node.Balances.Balance(orchAddr, core.ManifestID(jobReq.Capability)) + } + + if balance.Cmp(cost) > 0 { + createTickets = false + payment = &net.Payment{ + Sender: sender.Bytes(), + ExpectedPrice: orchToken.Price, } } + clog.Infof(ctx, "current balance for sender=%v capability=%v is %v, cost=%v price=%v", sender.Hex(), jobReq.Capability, balance.FloatString(3), cost.FloatString(3), price.FloatString(3)) if !createTickets { clog.V(common.DEBUG).Infof(ctx, "No payment required, using balance=%v", balance.FloatString(3)) + return "", nil } else { //calc ticket count ticketCnt := math.Ceil(float64(jobReq.Timeout)) @@ -810,12 +960,12 @@ func createPayment(ctx context.Context, jobReq *JobRequest, orchToken JobToken, senderParams := make([]*net.TicketSenderParams, len(tickets.SenderParams)) for i := 0; i < len(tickets.SenderParams); i++ { senderParams[i] = &net.TicketSenderParams{ - SenderNonce: tickets.SenderParams[i].SenderNonce, + SenderNonce: orchToken.LastNonce + tickets.SenderParams[i].SenderNonce, Sig: tickets.SenderParams[i].Sig, } totalEV = totalEV.Add(totalEV, tickets.WinProbRat()) } - + orchToken.LastNonce = tickets.SenderParams[len(tickets.SenderParams)-1].SenderNonce + 1 payment.TicketSenderParams = senderParams ratPrice, _ := common.RatPriceInfo(payment.ExpectedPrice) @@ -844,11 +994,11 @@ func createPayment(ctx context.Context, jobReq *JobRequest, orchToken JobToken, return base64.StdEncoding.EncodeToString(data), nil } -func updateGatewayBalance(node *core.LivepeerNode, orchToken JobToken, capability string, took time.Duration) *big.Rat { +func updateGatewayBalance(node *core.LivepeerNode, orchToken core.JobToken, capability string, took time.Duration) *big.Rat { orchAddr := ethcommon.BytesToAddress(orchToken.TicketParams.Recipient) // update for usage of compute orchPrice := big.NewRat(orchToken.Price.PricePerUnit, orchToken.Price.PixelsPerUnit) - cost := orchPrice.Mul(orchPrice, big.NewRat(int64(math.Ceil(took.Seconds())), 1)) + cost := new(big.Rat).Mul(orchPrice, big.NewRat(int64(math.Ceil(took.Seconds())), 1)) node.Balances.Debit(orchAddr, core.ManifestID(capability), cost) //get the updated balance @@ -901,14 +1051,14 @@ func getPaymentBalance(orch Orchestrator, sender ethcommon.Address, jobId string return senderBalance } -func verifyTokenCreds(ctx context.Context, orch Orchestrator, tokenCreds string) (*JobSender, error) { +func verifyTokenCreds(ctx context.Context, orch Orchestrator, tokenCreds string) (*core.JobSender, error) { buf, err := base64.StdEncoding.DecodeString(tokenCreds) if err != nil { glog.Error("Unable to base64-decode ", err) return nil, errSegEncoding } - var jobSender JobSender + var jobSender core.JobSender err = json.Unmarshal(buf, &jobSender) if err != nil { clog.Errorf(ctx, "Unable to parse the header text: ", err) @@ -955,7 +1105,7 @@ func parseJobRequest(jobReq string) (*JobRequest, error) { return &jobData, nil } -func verifyJobCreds(ctx context.Context, orch Orchestrator, jobCreds string) (*JobRequest, error) { +func verifyJobCreds(ctx context.Context, orch Orchestrator, jobCreds string, reserveCapacity bool) (*JobRequest, error) { //Gateway needs JobRequest parsed and verification of required fields jobData, err := parseJobRequest(jobCreds) if err != nil { @@ 
-985,7 +1135,7 @@ func verifyJobCreds(ctx context.Context, orch Orchestrator, jobCreds string) (*J return nil, errSegSig } - if orch.ReserveExternalCapabilityCapacity(jobData.Capability) != nil { + if reserveCapacity && orch.ReserveExternalCapabilityCapacity(jobData.Capability) != nil { return nil, errZeroCapacity } @@ -1015,24 +1165,16 @@ func getOrchSearchTimeouts(ctx context.Context, searchTimeoutHdr, respTimeoutHdr return timeout, respTimeout } -func getJobOrchestrators(ctx context.Context, node *core.LivepeerNode, capability string, params JobParameters, timeout time.Duration, respTimeout time.Duration) ([]JobToken, error) { +func getJobOrchestrators(ctx context.Context, node *core.LivepeerNode, capability string, params JobParameters, timeout time.Duration, respTimeout time.Duration) ([]core.JobToken, error) { orchs := node.OrchestratorPool.GetInfos() - gateway := node.OrchestratorPool.Broadcaster() - //setup the GET request to get the Orchestrator tokens - //get the address and sig for the sender - gatewayReq, err := genOrchestratorReq(gateway, GetOrchestratorInfoParams{}) + reqSender, err := getJobSender(ctx, node) if err != nil { - clog.Errorf(ctx, "Failed to generate request for Orchestrator to verify to request job token err=%v", err) + clog.Errorf(ctx, "Failed to get job sender err=%v", err) return nil, err } - addr := ethcommon.BytesToAddress(gatewayReq.Address) - reqSender := &JobSender{ - Addr: addr.Hex(), - Sig: "0x" + hex.EncodeToString(gatewayReq.Sig), - } - getOrchJobToken := func(ctx context.Context, orchUrl *url.URL, reqSender JobSender, respTimeout time.Duration, tokenCh chan JobToken, errCh chan error) { + getOrchJobToken := func(ctx context.Context, orchUrl *url.URL, reqSender core.JobSender, respTimeout time.Duration, tokenCh chan core.JobToken, errCh chan error) { start := time.Now() tokenReq, err := http.NewRequestWithContext(ctx, "GET", orchUrl.String()+"/process/token", nil) reqSenderStr, _ := json.Marshal(reqSender) @@ -1066,7 +1208,7 @@ func getJobOrchestrators(ctx context.Context, node *core.LivepeerNode, capabilit errCh <- err return } - var jobToken JobToken + var jobToken core.JobToken err = json.Unmarshal(token, &jobToken) if err != nil { clog.Errorf(ctx, "Failed to unmarshal token from Orchestrator %v err=%v", orchUrl.String(), err) @@ -1077,11 +1219,11 @@ func getJobOrchestrators(ctx context.Context, node *core.LivepeerNode, capabilit tokenCh <- jobToken } - var jobTokens []JobToken + var jobTokens []core.JobToken timedOut := false nbResp := 0 numAvailableOrchs := node.OrchestratorPool.Size() - tokenCh := make(chan JobToken, numAvailableOrchs) + tokenCh := make(chan core.JobToken, numAvailableOrchs) errCh := make(chan error, numAvailableOrchs) tokensCtx, cancel := context.WithTimeout(clog.Clone(context.Background(), ctx), timeout) @@ -1116,3 +1258,75 @@ func getJobOrchestrators(ctx context.Context, node *core.LivepeerNode, capabilit return jobTokens, nil } + +func getJobSender(ctx context.Context, node *core.LivepeerNode) (*core.JobSender, error) { + gateway := node.OrchestratorPool.Broadcaster() + orchReq, err := genOrchestratorReq(gateway, GetOrchestratorInfoParams{}) + if err != nil { + clog.Errorf(ctx, "Failed to generate request for Orchestrator to verify to request job token err=%v", err) + return nil, err + } + addr := ethcommon.BytesToAddress(orchReq.Address) + jobSender := &core.JobSender{ + Addr: addr.Hex(), + Sig: "0x" + hex.EncodeToString(orchReq.Sig), + } + + return jobSender, nil +} +func getToken(ctx context.Context, respTimeout 
time.Duration, orchUrl, capability, sender, senderSig string) (*core.JobToken, error) { + start := time.Now() + tokenReq, err := http.NewRequestWithContext(ctx, "GET", orchUrl+"/process/token", nil) + if err != nil { + clog.Errorf(ctx, "Failed to create request for Orchestrator to verify job token request err=%v", err) + return nil, err + } + jobSender := core.JobSender{Addr: sender, Sig: senderSig} + + reqSenderStr, _ := json.Marshal(jobSender) + tokenReq.Header.Set(jobEthAddressHdr, base64.StdEncoding.EncodeToString(reqSenderStr)) + tokenReq.Header.Set(jobCapabilityHdr, capability) + + var resp *http.Response + var token []byte + var jobToken core.JobToken + var attempt int + var backoff time.Duration = 100 * time.Millisecond + deadline := time.Now().Add(respTimeout) + + for attempt = 0; attempt < 3; attempt++ { + resp, err = sendJobReqWithTimeout(tokenReq, respTimeout) + if err != nil { + clog.Errorf(ctx, "failed to get token from Orchestrator (attempt %d) err=%v", attempt+1, err) + } else if resp.StatusCode != http.StatusOK { + clog.Errorf(ctx, "Failed to get token from Orchestrator %v status=%v (attempt %d)", orchUrl, resp.StatusCode, attempt+1) + resp.Body.Close() + } else { + defer resp.Body.Close() + latency := time.Since(start) + clog.V(common.DEBUG).Infof(ctx, "Received job token from uri=%v, latency=%v", orchUrl, latency) + token, err = io.ReadAll(resp.Body) + if err != nil { + clog.Errorf(ctx, "Failed to read token from Orchestrator %v err=%v", orchUrl, err) + } else { + err = json.Unmarshal(token, &jobToken) + if err != nil { + clog.Errorf(ctx, "Failed to unmarshal token from Orchestrator %v err=%v", orchUrl, err) + } else { + return &jobToken, nil + } + } + } + // If not last attempt and time remains, backoff + if time.Now().Add(backoff).Before(deadline) && attempt < 2 { + time.Sleep(backoff) + backoff *= 2 + } else { + break + } + } + // All attempts failed + if err != nil { + return nil, err + } + return nil, fmt.Errorf("failed to get token from Orchestrator after %d attempts", attempt) +} diff --git a/server/job_rpc_test.go b/server/job_rpc_test.go index 97b22e799d..2cbcaa3a5c 100644 --- a/server/job_rpc_test.go +++ b/server/job_rpc_test.go @@ -13,6 +13,7 @@ import ( "net/http/httptest" "net/url" "slices" + "sync" "testing" "time" @@ -54,6 +55,7 @@ type mockJobOrchestrator struct { reserveCapacity func(string) error getUrlForCapability func(string) string balance func(ethcommon.Address, core.ManifestID) *big.Rat + processPayment func(context.Context, net.Payment, core.ManifestID) error debitFees func(ethcommon.Address, core.ManifestID, *net.PriceInfo, int64) freeCapacity func(string) error jobPriceInfo func(ethcommon.Address, string) (*net.PriceInfo, error) @@ -114,6 +116,9 @@ func (r *mockJobOrchestrator) StreamIDs(jobID string) ([]core.StreamID, error) { } func (r *mockJobOrchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID core.ManifestID) error { + if r.processPayment != nil { + return r.processPayment(ctx, payment, manifestID) + } return nil } @@ -134,6 +139,9 @@ func (r *mockJobOrchestrator) SufficientBalance(addr ethcommon.Address, manifest } func (r *mockJobOrchestrator) DebitFees(addr ethcommon.Address, manifestID core.ManifestID, price *net.PriceInfo, pixels int64) { + if r.debitFees != nil { + r.debitFees(addr, manifestID, price, pixels) + } } func (r *mockJobOrchestrator) Balance(addr ethcommon.Address, manifestID core.ManifestID) *big.Rat { @@ -336,13 +344,14 @@ func (s *stubJobOrchestratorPool) SizeWith(scorePred common.ScorePred) int { return count
} func (s *stubJobOrchestratorPool) Broadcaster() common.Broadcaster { - return core.NewBroadcaster(s.node) + return stubBroadcaster2() } func mockJobLivepeerNode() *core.LivepeerNode { node, _ := core.NewLivepeerNode(nil, "/tmp/thisdirisnotactuallyusedinthistest", nil) node.NodeType = core.OrchestratorNode node.OrchSecret = "verbigsecret" + node.LiveMu = &sync.RWMutex{} return node } @@ -578,7 +587,7 @@ func TestGetJobToken_InvalidEthAddressHeader(t *testing.T) { } // Create a valid JobSender structure - js := &JobSender{ + js := &core.JobSender{ Addr: "0x0000000000000000000000000000000000000000", Sig: "0x000000000000000000000000000000000000000000000000000000000000000000", } @@ -607,7 +616,7 @@ func TestGetJobToken_MissingCapabilityHeader(t *testing.T) { } // Create a valid JobSender structure - js := &JobSender{ + js := &core.JobSender{ Addr: "0x0000000000000000000000000000000000000000", Sig: "0x000000000000000000000000000000000000000000000000000000000000000000", } @@ -649,7 +658,7 @@ func TestGetJobToken_NoCapacity(t *testing.T) { // Create a valid JobSender structure gateway := stubBroadcaster2() sig, _ := gateway.Sign([]byte(hexutil.Encode(gateway.Address().Bytes()))) - js := &JobSender{ + js := &core.JobSender{ Addr: hexutil.Encode(gateway.Address().Bytes()), Sig: hexutil.Encode(sig), } @@ -692,7 +701,7 @@ func TestGetJobToken_JobPriceInfoError(t *testing.T) { // Create a valid JobSender structure gateway := stubBroadcaster2() sig, _ := gateway.Sign([]byte(hexutil.Encode(gateway.Address().Bytes()))) - js := &JobSender{ + js := &core.JobSender{ Addr: hexutil.Encode(gateway.Address().Bytes()), Sig: hexutil.Encode(sig), } @@ -736,7 +745,7 @@ func TestGetJobToken_InsufficientReserve(t *testing.T) { // Create a valid JobSender structure gateway := stubBroadcaster2() sig, _ := gateway.Sign([]byte(hexutil.Encode(gateway.Address().Bytes()))) - js := &JobSender{ + js := &core.JobSender{ Addr: hexutil.Encode(gateway.Address().Bytes()), Sig: hexutil.Encode(sig), } @@ -787,7 +796,7 @@ func TestGetJobToken_TicketParamsError(t *testing.T) { // Create a valid JobSender structure gateway := stubBroadcaster2() sig, _ := gateway.Sign([]byte(hexutil.Encode(gateway.Address().Bytes()))) - js := &JobSender{ + js := &core.JobSender{ Addr: hexutil.Encode(gateway.Address().Bytes()), Sig: hexutil.Encode(sig), } @@ -851,7 +860,7 @@ func TestGetJobToken_Success(t *testing.T) { // Create a valid JobSender structure gateway := stubBroadcaster2() sig, _ := gateway.Sign([]byte(hexutil.Encode(gateway.Address().Bytes()))) - js := &JobSender{ + js := &core.JobSender{ Addr: hexutil.Encode(gateway.Address().Bytes()), Sig: hexutil.Encode(sig), } @@ -868,7 +877,7 @@ func TestGetJobToken_Success(t *testing.T) { resp := w.Result() assert.Equal(t, http.StatusOK, resp.StatusCode) - var token JobToken + var token core.JobToken body, _ := io.ReadAll(resp.Body) json.Unmarshal(body, &token) @@ -916,18 +925,18 @@ func TestCreatePayment(t *testing.T) { mockSender.On("StartSession", mock.Anything).Return("foo").Times(4) node.Sender = &mockSender - node.Balances = core.NewAddressBalances(10) + node.Balances = core.NewAddressBalances(1 * time.Second) defer node.Balances.StopCleanup() jobReq := JobRequest{ Capability: "test-payment-cap", } - sender := JobSender{ + sender := core.JobSender{ Addr: "0x1111111111111111111111111111111111111111", Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", } - orchTocken := JobToken{ + orchTocken := core.JobToken{ TicketParams: &net.TicketParams{ Recipient: 
ethcommon.HexToAddress("0x1111111111111111111111111111111111111111").Bytes(), FaceValue: big.NewInt(1000).Bytes(), @@ -949,7 +958,7 @@ func TestCreatePayment(t *testing.T) { //payment with one ticket jobReq.Timeout = 1 mockSender.On("CreateTicketBatch", "foo", jobReq.Timeout).Return(mockTicketBatch(jobReq.Timeout), nil).Once() - payment, err := createPayment(ctx, &jobReq, orchTocken, node) + payment, err := createPayment(ctx, &jobReq, &orchTocken, node) assert.Nil(t, err) pmPayment, err := base64.StdEncoding.DecodeString(payment) assert.Nil(t, err) @@ -960,7 +969,7 @@ func TestCreatePayment(t *testing.T) { //test 2 tickets jobReq.Timeout = 2 mockSender.On("CreateTicketBatch", "foo", jobReq.Timeout).Return(mockTicketBatch(jobReq.Timeout), nil).Once() - payment, err = createPayment(ctx, &jobReq, orchTocken, node) + payment, err = createPayment(ctx, &jobReq, &orchTocken, node) assert.Nil(t, err) pmPayment, err = base64.StdEncoding.DecodeString(payment) assert.Nil(t, err) @@ -971,7 +980,7 @@ func TestCreatePayment(t *testing.T) { //test 600 tickets jobReq.Timeout = 600 mockSender.On("CreateTicketBatch", "foo", jobReq.Timeout).Return(mockTicketBatch(jobReq.Timeout), nil).Once() - payment, err = createPayment(ctx, &jobReq, orchTocken, node) + payment, err = createPayment(ctx, &jobReq, &orchTocken, node) assert.Nil(t, err) pmPayment, err = base64.StdEncoding.DecodeString(payment) assert.Nil(t, err) @@ -980,6 +989,51 @@ func TestCreatePayment(t *testing.T) { assert.Equal(t, 600, len(pmTickets.TicketSenderParams)) } +func createTestPayment(capability string) (string, error) { + ctx := context.TODO() + node, _ := core.NewLivepeerNode(nil, "/tmp/thisdirisnotactuallyusedinthistest", nil) + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(4) + mockSender.On("CreateTicketBatch", "foo", 1).Return(mockTicketBatch(1), nil).Once() + node.Sender = &mockSender + + node.Balances = core.NewAddressBalances(1 * time.Second) + defer node.Balances.StopCleanup() + + jobReq := JobRequest{ + Capability: capability, + Timeout: 1, + } + sender := core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + orchTocken := core.JobToken{ + TicketParams: &net.TicketParams{ + Recipient: ethcommon.HexToAddress("0x1111111111111111111111111111111111111111").Bytes(), + FaceValue: big.NewInt(1000).Bytes(), + WinProb: big.NewInt(1).Bytes(), + RecipientRandHash: []byte("hash"), + Seed: big.NewInt(1234).Bytes(), + ExpirationBlock: big.NewInt(100000).Bytes(), + }, + SenderAddress: &sender, + Balance: 0, + Price: &net.PriceInfo{ + PricePerUnit: 10, + PixelsPerUnit: 1, + }, + } + + pmt, err := createPayment(ctx, &jobReq, &orchTocken, node) + if err != nil { + return "", err + } + + return pmt, nil +} + func mockTicketBatch(count int) *pm.TicketBatch { senderParams := make([]*pm.TicketSenderParams, count) for i := 0; i < count; i++ { @@ -998,7 +1052,7 @@ func mockTicketBatch(count int) *pm.TicketBatch { ExpirationBlock: big.NewInt(1000), }, TicketExpirationParams: &pm.TicketExpirationParams{}, - Sender: pm.RandAddress(), + Sender: ethcommon.HexToAddress("0x1111111111111111111111111111111111111111"), SenderParams: senderParams, } } @@ -1008,33 +1062,9 @@ func TestSubmitJob_OrchestratorSelectionParams(t *testing.T) { mockServers := make([]*httptest.Server, 5) orchURLs := make([]string, 5) - // Create a handler that returns a valid job token - tokenHandler := func(w http.ResponseWriter, r 
*http.Request) { - if r.URL.Path != "/process/token" { - http.NotFound(w, r) - return - } - - token := &JobToken{ - ServiceAddr: "http://" + r.Host, // Use the server's host as the service address - SenderAddress: &JobSender{ - Addr: "0x1234567890abcdef1234567890abcdef123456", - Sig: "0x456", - }, - TicketParams: nil, - Price: &net.PriceInfo{ - PricePerUnit: 100, - PixelsPerUnit: 1, - }, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(token) - } - // Start HTTP test servers for i := 0; i < 5; i++ { - server := httptest.NewServer(http.HandlerFunc(tokenHandler)) + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) mockServers[i] = server orchURLs[i] = server.URL t.Logf("Mock server %d started at %s", i, orchURLs[i]) @@ -1141,3 +1171,157 @@ func TestSubmitJob_OrchestratorSelectionParams(t *testing.T) { } } + +func TestProcessPayment(t *testing.T) { + + ctx := context.Background() + sender := ethcommon.HexToAddress("0x1111111111111111111111111111111111111111") + + cases := []struct { + name string + capability string + expectDelta bool + }{ + {"empty header", "testcap", false}, + {"empty capability", "", false}, + {"random capability", "randomcap", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Simulate a mutable balance for the test + testBalance := big.NewRat(100, 1) + balanceCalled := 0 + paymentCalled := 0 + orch := newMockJobOrchestrator() + orch.balance = func(addr ethcommon.Address, manifestID core.ManifestID) *big.Rat { + balanceCalled++ + return new(big.Rat).Set(testBalance) + } + orch.processPayment = func(ctx context.Context, payment net.Payment, manifestID core.ManifestID) error { + paymentCalled++ + // Simulate payment by increasing balance + testBalance = testBalance.Add(testBalance, big.NewRat(50, 1)) + return nil + } + + testPmtHdr, err := createTestPayment(tc.capability) + if err != nil { + t.Fatalf("Failed to create test payment: %v", err) + } + + before := orch.Balance(sender, core.ManifestID(tc.capability)).FloatString(0) + bal, err := processPayment(ctx, orch, sender, tc.capability, testPmtHdr) + after := orch.Balance(sender, core.ManifestID(tc.capability)).FloatString(0) + t.Logf("Balance before: %s, after: %s", before, after) + assert.NoError(t, err) + assert.NotNil(t, bal) + if testPmtHdr != "" { + assert.NotEqual(t, before, after, "Balance should change if payment header is not empty") + assert.Equal(t, 1, paymentCalled, "ProcessPayment should be called once for non-empty header") + } else { + assert.Equal(t, before, after, "Balance should not change if payment header is empty") + assert.Equal(t, 0, paymentCalled, "ProcessPayment should not be called for empty header") + } + }) + } +} + +func TestSetupGatewayJob(t *testing.T) { + // Prepare a JobRequest with valid fields + jobDetails := JobRequestDetails{StreamId: "test-stream"} + jobParams := JobParameters{ + Orchestrators: JobOrchestratorsFilter{}, + EnableVideoIngress: true, + EnableVideoEgress: true, + EnableDataOutput: true, + } + jobReq := JobRequest{ + ID: "job-1", + Request: marshalToString(t, jobDetails), + Parameters: marshalToString(t, jobParams), + Capability: "test-capability", + Timeout: 10, + } + jobReqB, err := json.Marshal(jobReq) + assert.NoError(t, err) + jobReqB64 := base64.StdEncoding.EncodeToString(jobReqB) + + // Setup a minimal LivepeerServer with a stub OrchestratorPool + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node := mockJobLivepeerNode() + + 
node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(jobRequestHdr, jobReqB64) + + // Should succeed + gatewayJob, err := ls.setupGatewayJob(context.Background(), req, false) + assert.NoError(t, err) + assert.NotNil(t, gatewayJob) + assert.Equal(t, "test-capability", gatewayJob.Job.Req.Capability) + assert.Equal(t, "test-stream", gatewayJob.Job.Details.StreamId) + assert.Equal(t, 10, gatewayJob.Job.Req.Timeout) + assert.Equal(t, 1, len(gatewayJob.Orchs)) + + //test signing request + assert.Empty(t, gatewayJob.SignedJobReq) + gatewayJob.sign() + assert.NotEmpty(t, gatewayJob.SignedJobReq) + + // Should fail with invalid base64 + req.Header.Set(jobRequestHdr, "not-base64") + gatewayJob, err = ls.setupGatewayJob(context.Background(), req, false) + assert.Error(t, err) + assert.Nil(t, gatewayJob) + + // Should fail with missing orchestrators (simulate getJobOrchestrators returns empty) + req.Header.Set(jobRequestHdr, jobReqB64) + ls.LivepeerNode.OrchestratorPool = newStubOrchestratorPool(node, []string{}) + gatewayJob, err = ls.setupGatewayJob(context.Background(), req, false) + assert.Error(t, err) + assert.Nil(t, gatewayJob) +} + +// marshalToString is a helper to marshal a struct to a JSON string +func marshalToString(t *testing.T, v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshalToString failed: %v", err) + } + return string(b) +} + +func orchTokenHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/process/token" { + http.NotFound(w, r) + return + } + + token := createMockJobToken("http://" + r.Host) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(token) + +} + +func createMockJobToken(hostUrl string) *core.JobToken { + return &core.JobToken{ + ServiceAddr: hostUrl, + SenderAddress: &core.JobSender{ + Addr: "0x1234567890abcdef1234567890abcdef123456", + Sig: "0x456", + }, + TicketParams: &net.TicketParams{ + Recipient: ethcommon.HexToAddress("0x1111111111111111111111111111111111111111").Bytes(), + FaceValue: big.NewInt(1000).Bytes(), + }, + Price: &net.PriceInfo{ + PricePerUnit: 100, + PixelsPerUnit: 1, + }, + } +} diff --git a/server/job_stream.go b/server/job_stream.go new file mode 100644 index 0000000000..1f953193a6 --- /dev/null +++ b/server/job_stream.go @@ -0,0 +1,1364 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "os" + "strings" + "sync" + "time" + + ethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/golang/glog" + "github.com/livepeer/go-livepeer/clog" + "github.com/livepeer/go-livepeer/common" + "github.com/livepeer/go-livepeer/core" + "github.com/livepeer/go-livepeer/media" + "github.com/livepeer/go-livepeer/monitor" + "github.com/livepeer/go-livepeer/net" + "github.com/livepeer/go-livepeer/trickle" + "github.com/livepeer/go-tools/drivers" +) + +var getNewTokenTimeout = 3 * time.Second + +func (ls *LivepeerServer) StartStream() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + corsHeaders(w, r.Method) + w.WriteHeader(http.StatusNoContent) + return + } + + // Create fresh context instead of using r.Context() since ctx will outlive the request + ctx := r.Context() + + corsHeaders(w, r.Method) + //verify request, get orchestrators 
available and sign request + gatewayJob, err := ls.setupGatewayJob(ctx, r, false) + if err != nil { + clog.Errorf(ctx, "Error setting up job: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + //setup body size limit, will error if too large + r.Body = http.MaxBytesReader(w, r.Body, 10<<20) + streamUrls, code, err := ls.setupStream(ctx, r, gatewayJob) + if err != nil { + clog.Errorf(ctx, "Error setting up stream: %s", err) + http.Error(w, err.Error(), code) + return + } + + go ls.runStream(gatewayJob) + + go ls.monitorStream(gatewayJob.Job.Req.ID) + + if streamUrls != nil { + // Stream started successfully + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(streamUrls) + } else { + //case where we are subscribing to own streams in setupStream + w.WriteHeader(http.StatusNoContent) + } + }) +} + +func (ls *LivepeerServer) StopStream() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The request context is sufficient here; all stop processing completes before the handler returns + ctx := r.Context() + streamId := r.PathValue("streamId") + + stream, exists := ls.LivepeerNode.LivePipelines[streamId] + if !exists { + http.Error(w, "Stream not found", http.StatusNotFound) + return + } + + params, err := getStreamRequestParams(stream) + if err != nil { + clog.Errorf(ctx, "Error getting stream request params: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + stream.StopStream(nil) + delete(ls.LivepeerNode.LivePipelines, streamId) + + stopJob, err := ls.setupGatewayJob(ctx, r, true) + if err != nil { + clog.Errorf(ctx, "Error setting up stop job: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + stopJob.sign() //no changes to make, sign job + //setup sender + jobSender, err := getJobSender(ctx, ls.LivepeerNode) + if err != nil { + clog.Errorf(ctx, "Error getting job sender: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + token, err := sessionToToken(params.liveParams.sess) + if err != nil { + clog.Errorf(ctx, "Error converting session to token: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + newToken, err := getToken(ctx, getNewTokenTimeout, token.ServiceAddr, stopJob.Job.Req.Capability, jobSender.Addr, jobSender.Sig) + if err != nil { + clog.Errorf(ctx, "Error getting new job token: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + clog.Errorf(ctx, "Error reading request body: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer r.Body.Close() + + resp, code, err := ls.sendJobToOrch(ctx, r, stopJob.Job.Req, stopJob.SignedJobReq, *newToken, "/ai/stream/stop", body) + if err != nil { + clog.Errorf(ctx, "Error sending job to orchestrator: %s", err) + http.Error(w, err.Error(), code) + return + } + defer resp.Body.Close() + + w.WriteHeader(http.StatusOK) + io.Copy(w, resp.Body) + return + }) +} + +func (ls *LivepeerServer) runStream(gatewayJob *gatewayJob) { + streamID := gatewayJob.Job.Req.ID + stream, exists := ls.LivepeerNode.LivePipelines[streamID] + if !exists { + glog.Errorf("Stream %s not found", streamID) + return + } + //this context passes to all channels that will close when stream is canceled + ctx := stream.StreamCtx + ctx = clog.AddVal(ctx, "stream_id", streamID) + + params, err := getStreamRequestParams(stream) + if err != nil { + clog.Errorf(ctx, "Error getting stream request params: %s", err) + return + } + + 
//monitor for lots of fast swaps, likely something wrong with request + orchSwapper := NewOrchestratorSwapper(params) + + firstProcessed := false + for _, orch := range gatewayJob.Orchs { + clog.Infof(ctx, "Starting stream processing") + //refresh the token if not first Orch to confirm capacity and new ticket params + if firstProcessed { + newToken, err := getToken(ctx, getNewTokenTimeout, orch.ServiceAddr, gatewayJob.Job.Req.Capability, gatewayJob.Job.Req.Sender, gatewayJob.Job.Req.Sig) + if err != nil { + clog.Errorf(ctx, "Error getting token for orch=%v err=%v", orch.ServiceAddr, err) + continue + } + orch = *newToken + } + + orchSession, err := tokenToAISession(orch) + if err != nil { + clog.Errorf(ctx, "Error converting token to AISession: %v", err) + continue + } + params.liveParams.sess = &orchSession + + ctx = clog.AddVal(ctx, "orch", hexutil.Encode(orch.TicketParams.Recipient)) + ctx = clog.AddVal(ctx, "orch_url", orch.ServiceAddr) + + //set request ID to persist from Gateway to Worker + gatewayJob.Job.Req.ID = params.liveParams.streamID + err = gatewayJob.sign() + if err != nil { + clog.Errorf(ctx, "Error signing job, exiting stream processing request: %v", err) + stream.StopStream(err) + return + } + orchResp, _, err := ls.sendJobToOrch(ctx, nil, gatewayJob.Job.Req, gatewayJob.SignedJobReq, orch, "/ai/stream/start", stream.StreamRequest()) + if err != nil { + clog.Errorf(ctx, "job not able to be processed by Orchestrator %v err=%v ", orch.ServiceAddr, err.Error()) + continue + } + + GatewayStatus.StoreKey(streamID, "orchestrator", orch.ServiceAddr) + + params.liveParams.orchPublishUrl = orchResp.Header.Get("X-Publish-Url") + params.liveParams.orchSubscribeUrl = orchResp.Header.Get("X-Subscribe-Url") + params.liveParams.orchControlUrl = orchResp.Header.Get("X-Control-Url") + params.liveParams.orchEventsUrl = orchResp.Header.Get("X-Events-Url") + params.liveParams.orchDataUrl = orchResp.Header.Get("X-Data-Url") + + perOrchCtx, perOrchCancel := context.WithCancelCause(ctx) + params.liveParams.kickOrch = perOrchCancel + stream.UpdateStreamParams(params) //update params used to kickOrch (perOrchCancel) and urls + if err = startStreamProcessing(perOrchCtx, stream, params); err != nil { + clog.Errorf(ctx, "Error starting processing: %s", err) + perOrchCancel(err) + break + } + //something caused the Orch to stop performing, try to get the error and move to next Orchestrator + <-perOrchCtx.Done() + err = context.Cause(perOrchCtx) + if errors.Is(err, context.Canceled) { + // this happens if parent ctx was cancelled without a CancelCause + // or if passing `nil` as a CancelCause + err = nil + } + if !params.inputStreamExists() { + clog.Info(ctx, "No stream exists, skipping orchestrator swap") + break + } + + //if swapping too fast, stop trying since likely a bad request + if swapErr := orchSwapper.checkSwap(ctx); swapErr != nil { + if err != nil { + err = fmt.Errorf("%w: %w", swapErr, err) + } else { + err = swapErr + } + break + } + firstProcessed = true + // will swap, but first notify with the reason for the swap + if err == nil { + err = errors.New("unknown swap reason") + } + + clog.Infof(ctx, "Retrying stream with a different orchestrator err=%v", err.Error()) + + params.liveParams.sendErrorEvent(err) + + //if there is ingress input then force off + if params.liveParams.kickInput != nil { + params.liveParams.kickInput(err) + } + } + + //exhausted all Orchestrators, end stream + ls.LivepeerNode.ExternalCapabilities.RemoveStream(streamID) +} + +func (ls *LivepeerServer) 
monitorStream(streamId string) { + ctx := context.Background() + ctx = clog.AddVal(ctx, "stream_id", streamId) + + stream, exists := ls.LivepeerNode.LivePipelines[streamId] + if !exists { + clog.Errorf(ctx, "Stream %s not found", streamId) + return + } + params, err := getStreamRequestParams(stream) + if err != nil { + clog.Errorf(ctx, "Error getting stream request params: %v", err) + return + } + + ctx = clog.AddVal(ctx, "request_id", params.liveParams.requestID) + + // Create a ticker that runs every minute for payments with buffer to ensure payment is completed + dur := 50 * time.Second + pmtTicker := time.NewTicker(dur) + defer pmtTicker.Stop() + //setup sender + jobSender, err := getJobSender(ctx, ls.LivepeerNode) + if err != nil { + clog.Errorf(ctx, "Error getting job sender: %v", err) + return + } + + for { + select { + case <-stream.StreamCtx.Done(): + clog.Infof(ctx, "Stream %s stopped, ending monitoring", streamId) + ls.LivepeerNode.RemoveLivePipeline(streamId) + return + case <-pmtTicker.C: + if !params.inputStreamExists() { + clog.Infof(ctx, "Input stream does not exist for stream %s, ending monitoring", streamId) + return + } + + err := ls.sendPaymentForStream(ctx, stream, jobSender) + if err != nil { + clog.Errorf(ctx, "Error sending payment for stream %s: %v", streamId, err) + } + } + } +} + +func (ls *LivepeerServer) sendPaymentForStream(ctx context.Context, stream *core.LivePipeline, jobSender *core.JobSender) error { + params, err := getStreamRequestParams(stream) + if err != nil { + clog.Errorf(ctx, "Error getting stream request params: %v", err) + return err + } + token, err := sessionToToken(params.liveParams.sess) + if err != nil { + clog.Errorf(ctx, "Error getting token for session: %v", err) + return err + } + + // fetch new JobToken with each payment + // update the session for the LivePipeline with new token + newToken, err := getToken(ctx, getNewTokenTimeout, token.ServiceAddr, stream.Pipeline, jobSender.Addr, jobSender.Sig) + if err != nil { + clog.Errorf(ctx, "Error getting new token for %s: %v", token.ServiceAddr, err) + return err + } + newSess, err := tokenToAISession(*newToken) + if err != nil { + clog.Errorf(ctx, "Error converting token to AI session: %v", err) + return err + } + params.liveParams.sess = &newSess + stream.UpdateStreamParams(params) + + // send the payment + streamID := params.liveParams.streamID + jobDetails := JobRequestDetails{StreamId: streamID} + jobDetailsStr, err := json.Marshal(jobDetails) + if err != nil { + clog.Errorf(ctx, "Error marshalling job details: %v", err) + return err + } + req := &JobRequest{Request: string(jobDetailsStr), Parameters: "{}", Capability: stream.Pipeline, + Sender: jobSender.Addr, + Timeout: 70, + } + //sign the request + job := gatewayJob{Job: &orchJob{Req: req}, node: ls.LivepeerNode} + err = job.sign() + if err != nil { + clog.Errorf(ctx, "Error signing job, continuing monitoring: %v", err) + return err + } + + if newSess.OrchestratorInfo.PriceInfo.PricePerUnit > 0 { + pmtHdr, err := createPayment(ctx, req, newToken, ls.LivepeerNode) + if err != nil { + clog.Errorf(ctx, "Error processing stream payment for %s: %v", streamID, err) + // Continue monitoring even if payment fails + } + if pmtHdr == "" { + // This is no payment required, error logged above + return nil + } + + //send the payment, update the stream with the refreshed token + clog.Infof(ctx, "Sending stream payment for %s", streamID) + statusCode, err := ls.sendPayment(ctx, token.ServiceAddr+"/ai/stream/payment", stream.Pipeline, 
job.SignedJobReq, pmtHdr) + if err != nil { + clog.Errorf(ctx, "Error sending stream payment for %s: %v", streamID, err) + return err + } + if statusCode != http.StatusOK { + clog.Errorf(ctx, "Unexpected status code %d received for %s", statusCode, streamID) + return errors.New("unexpected status code") + } + } + + return nil +} + +type StartRequest struct { + Stream string `json:"stream_name"` + RtmpOutput string `json:"rtmp_output"` + StreamId string `json:"stream_id"` + Params string `json:"params"` +} + +type StreamUrls struct { + StreamId string `json:"stream_id"` + WhipUrl string `json:"whip_url"` + WhepUrl string `json:"whep_url"` + RtmpUrl string `json:"rtmp_url"` + RtmpOutputUrl string `json:"rtmp_output_url"` + UpdateUrl string `json:"update_url"` + StatusUrl string `json:"status_url"` + DataUrl string `json:"data_url"` +} + +func (ls *LivepeerServer) setupStream(ctx context.Context, r *http.Request, job *gatewayJob) (*StreamUrls, int, error) { + if job == nil { + return nil, http.StatusBadRequest, errors.New("invalid job") + } + + requestID := string(core.RandomManifestID()) + ctx = clog.AddVal(ctx, "request_id", requestID) + + // Setup request body to be able to preserve for retries + // Read the entire body first with 10MB limit + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + if maxErr, ok := err.(*http.MaxBytesError); ok { + clog.Warningf(ctx, "Request body too large (over 10MB)") + return nil, http.StatusRequestEntityTooLarge, fmt.Errorf("request body too large (max %d bytes)", maxErr.Limit) + } else { + clog.Errorf(ctx, "Error reading request body: %v", err) + return nil, http.StatusBadRequest, fmt.Errorf("error reading request body: %w", err) + } + } + r.Body.Close() + + // Decode the StartRequest from JSON body + var startReq StartRequest + if err := json.NewDecoder(bytes.NewBuffer(bodyBytes)).Decode(&startReq); err != nil { + return nil, http.StatusBadRequest, fmt.Errorf("invalid JSON request body: %w", err) + } + + //live-video-to-video uses path value for this + streamName := startReq.Stream + + streamRequestTime := time.Now().UnixMilli() + + ctx = clog.AddVal(ctx, "stream", streamName) + + // If auth webhook is set and returns an output URL, this will be replaced + outputURL := startReq.RtmpOutput + + // convention to avoid re-subscribing to our own streams + // in case we want to push outputs back into mediamtx - + // use an `-out` suffix for the stream name. 
+ if strings.HasSuffix(streamName, "-out") { + // skip for now; we don't want to re-publish our own outputs + return nil, 0, nil + } + + // if auth webhook returns pipeline config these will be replaced + pipeline := job.Job.Req.Capability + rawParams := startReq.Params + streamID := startReq.StreamId + + var pipelineID string + var pipelineParams map[string]interface{} + if rawParams != "" { + if err := json.Unmarshal([]byte(rawParams), &pipelineParams); err != nil { + return nil, http.StatusBadRequest, errors.New("invalid model params") + } + } + + //ensure a streamid exists and includes the streamName if provided + if streamID == "" { + streamID = string(core.RandomManifestID()) + } + if streamName != "" { + streamID = fmt.Sprintf("%s-%s", streamName, streamID) + } + // BYOC uses Livepeer native WHIP + // Currently for webrtc we need to add a path prefix due to the ingress setup + //mediaMTXStreamPrefix := r.PathValue("prefix") + //if mediaMTXStreamPrefix != "" { + // mediaMTXStreamPrefix = mediaMTXStreamPrefix + "/" + //} + mediaMtxHost := os.Getenv("LIVE_AI_PLAYBACK_HOST") + if mediaMtxHost == "" { + mediaMtxHost = "rtmp://localhost:1935" + } + mediaMTXInputURL := fmt.Sprintf("%s/%s%s", mediaMtxHost, "", streamID) + mediaMTXOutputURL := mediaMTXInputURL + "-out" + mediaMTXOutputAlias := fmt.Sprintf("%s-%s-out", mediaMTXInputURL, requestID) + + var ( + whipURL string + rtmpURL string + whepURL string + dataURL string + ) + + updateURL := fmt.Sprintf("https://%s/ai/stream/%s/%s", ls.LivepeerNode.GatewayHost, streamID, "update") + statusURL := fmt.Sprintf("https://%s/ai/stream/%s/%s", ls.LivepeerNode.GatewayHost, streamID, "status") + + if job.Job.Params.EnableVideoIngress { + whipURL = fmt.Sprintf("https://%s/ai/stream/%s/%s", ls.LivepeerNode.GatewayHost, streamID, "whip") + rtmpURL = mediaMTXInputURL + } + if job.Job.Params.EnableVideoEgress { + whepURL = generateWhepUrl(streamID, requestID) + } + if job.Job.Params.EnableDataOutput { + dataURL = fmt.Sprintf("https://%s/ai/stream/%s/%s", ls.LivepeerNode.GatewayHost, streamID, "data") + } + + //if set this will overwrite settings above + if LiveAIAuthWebhookURL != nil { + authResp, err := authenticateAIStream(LiveAIAuthWebhookURL, ls.liveAIAuthApiKey, AIAuthRequest{ + Stream: streamName, + Type: "", //sourceTypeStr + QueryParams: rawParams, + GatewayHost: ls.LivepeerNode.GatewayHost, + WhepURL: whepURL, + UpdateURL: updateURL, + StatusURL: statusURL, + }) + if err != nil { + return nil, http.StatusForbidden, fmt.Errorf("live ai auth failed: %w", err) + } + + if authResp.RTMPOutputURL != "" { + outputURL = authResp.RTMPOutputURL + } + + if authResp.Pipeline != "" { + pipeline = authResp.Pipeline + } + + if len(authResp.paramsMap) > 0 { + if _, ok := authResp.paramsMap["prompt"]; !ok && pipeline == "comfyui" { + pipelineParams = map[string]interface{}{"prompt": authResp.paramsMap} + } else { + pipelineParams = authResp.paramsMap + } + } + + if authResp.StreamID != "" { + streamID = authResp.StreamID + } + + if authResp.PipelineID != "" { + pipelineID = authResp.PipelineID + } + } + + ctx = clog.AddVal(ctx, "stream_id", streamID) + clog.Infof(ctx, "Received live video AI request stream_id=%s pipelineParams=%v", streamID, pipelineParams) + + // collect all RTMP outputs + var rtmpOutputs []string + if job.Job.Params.EnableVideoEgress { + if outputURL != "" { + rtmpOutputs = append(rtmpOutputs, outputURL) + } + if mediaMTXOutputURL != "" { + rtmpOutputs = append(rtmpOutputs, mediaMTXOutputURL, mediaMTXOutputAlias) + } + } + + clog.Info(ctx, "RTMP outputs",
"destinations", rtmpOutputs) + + // Clear any previous gateway status + GatewayStatus.Clear(streamID) + GatewayStatus.StoreKey(streamID, "whep_url", whepURL) + + monitor.SendQueueEventAsync("stream_trace", map[string]interface{}{ + "type": "gateway_receive_stream_request", + "timestamp": streamRequestTime, + "stream_id": streamID, + "pipeline_id": pipelineID, + "request_id": requestID, + "orchestrator_info": map[string]interface{}{ + "address": "", + "url": "", + }, + }) + + // Count `ai_live_attempts` after successful parameters validation + clog.V(common.VERBOSE).Infof(ctx, "AI Live video attempt") + if monitor.Enabled { + monitor.AILiveVideoAttempt(job.Job.Req.Capability) + } + + sendErrorEvent := LiveErrorEventSender(ctx, streamID, map[string]string{ + "type": "error", + "request_id": requestID, + "stream_id": streamID, + "pipeline_id": pipelineID, + "pipeline": pipeline, + }) + + //params set with ingest types: + // RTMP + // kickInput will kick the input from MediaMTX to force a reconnect + // localRTMPPrefix mediaMTXInputURL matches to get the ingest from MediaMTX + // WHIP + // kickInput will close the whip connection + // localRTMPPrefix set by ENV variable LIVE_AI_PLAYBACK_HOST + ssr := media.NewSwitchableSegmentReader() //this converts ingest to segments to send to Orchestrator + params := aiRequestParams{ + node: ls.LivepeerNode, + os: drivers.NodeStorage.NewSession(requestID), + sessManager: nil, + + liveParams: &liveRequestParams{ + segmentReader: ssr, + startTime: time.Now(), + rtmpOutputs: rtmpOutputs, + stream: streamID, //live video to video uses stream name, byoc combines to one id + paymentProcessInterval: ls.livePaymentInterval, + outSegmentTimeout: ls.outSegmentTimeout, + requestID: requestID, + streamID: streamID, + pipelineID: pipelineID, + pipeline: pipeline, + sendErrorEvent: sendErrorEvent, + manifestID: pipeline, //byoc uses one balance per capability name + }, + } + + //create a dataWriter for data channel if enabled + if job.Job.Params.EnableDataOutput { + params.liveParams.dataWriter = media.NewSegmentWriter(5) + } + + //check if stream exists + if params.inputStreamExists() { + return nil, http.StatusBadRequest, fmt.Errorf("stream already exists: %s", streamID) + } + + clog.Infof(ctx, "stream setup videoIngress=%v videoEgress=%v dataOutput=%v", job.Job.Params.EnableVideoIngress, job.Job.Params.EnableVideoEgress, job.Job.Params.EnableDataOutput) + + //save the stream setup + ls.LivepeerNode.NewLivePipeline(requestID, streamID, pipeline, params, bodyBytes) //track the pipeline for cancellation + + job.Job.Req.ID = streamID + streamUrls := StreamUrls{ + StreamId: streamID, + WhipUrl: whipURL, + WhepUrl: whepURL, + RtmpUrl: rtmpURL, + RtmpOutputUrl: strings.Join(rtmpOutputs, ","), + UpdateUrl: updateURL, + StatusUrl: statusURL, + DataUrl: dataURL, + } + + return &streamUrls, http.StatusOK, nil +} + +// mediamtx sends this request to go-livepeer when rtmp stream received +func (ls *LivepeerServer) StartStreamRTMPIngest() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(context.Background(), clog.ClientIP, remoteAddr) + + streamId := r.PathValue("streamId") + ctx = clog.AddVal(ctx, "stream_id", streamId) + + stream, ok := ls.LivepeerNode.LivePipelines[streamId] + if !ok { + respondJsonError(ctx, w, fmt.Errorf("stream not found: %s", streamId), http.StatusNotFound) + return + } + + params, err := getStreamRequestParams(stream) + if err != nil { + respondJsonError(ctx, w, 
err, http.StatusBadRequest) + return + } + + //set source ID and source type needed for mediamtx client control api + sourceID := r.FormValue("source_id") + if sourceID == "" { + http.Error(w, "missing source_id", http.StatusBadRequest) + return + } + ctx = clog.AddVal(ctx, "source_id", sourceID) + sourceType := r.FormValue("source_type") + sourceType = strings.ToLower(sourceType) //normalize the source type so rtmpConn matches to rtmpconn + if sourceType == "" { + http.Error(w, "missing source_type", http.StatusBadRequest) + return + } + + clog.Infof(ctx, "RTMP ingest from MediaMTX connected sourceID=%s sourceType=%s", sourceID, sourceType) + //note that mediaMtxHost is the ip address of media mtx + // mediamtx sends a post request in the runOnReady event setup in mediamtx.yml + // StartLiveVideo calls this remoteHost + mediaMtxHost, err := getRemoteHost(r.RemoteAddr) + if err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + mediaMTXInputURL := fmt.Sprintf("rtmp://%s/%s%s", mediaMtxHost, "", streamId) + mediaMTXClient := media.NewMediaMTXClient(mediaMtxHost, ls.mediaMTXApiPassword, sourceID, sourceType) + segmenterCtx, cancelSegmenter := context.WithCancel(clog.Clone(context.Background(), ctx)) + + // this function is called when the pipeline hits a fatal error, we kick the input connection to allow + // the client to reconnect and restart the pipeline + kickInput := func(err error) { + defer cancelSegmenter() + if err == nil { + return + } + clog.Errorf(ctx, "Live video pipeline finished with error: %s", err) + + params.liveParams.sendErrorEvent(err) + + err = mediaMTXClient.KickInputConnection(ctx) + if err != nil { + clog.Errorf(ctx, "Failed to kick input connection: %s", err) + } + } + + params.liveParams.localRTMPPrefix = mediaMTXInputURL + params.liveParams.kickInput = kickInput + stream.UpdateStreamParams(params) //add kickInput to stream params + + // Kick off the RTMP pull and segmentation + clog.Infof(ctx, "Starting RTMP ingest from MediaMTX") + go func() { + ms := media.MediaSegmenter{Workdir: ls.LivepeerNode.WorkDir, MediaMTXClient: mediaMTXClient} + //segmenter blocks until done + ms.RunSegmentation(segmenterCtx, params.liveParams.localRTMPPrefix, params.liveParams.segmentReader.Read) + + params.liveParams.sendErrorEvent(errors.New("mediamtx ingest disconnected")) + monitor.SendQueueEventAsync("stream_trace", map[string]interface{}{ + "type": "gateway_ingest_stream_closed", + "timestamp": time.Now().UnixMilli(), + "stream_id": params.liveParams.streamID, + "pipeline_id": params.liveParams.pipelineID, + "request_id": params.liveParams.requestID, + "orchestrator_info": map[string]interface{}{ + "address": "", + "url": "", + }, + }) + params.liveParams.segmentReader.Close() + + stream.StopStream(nil) + }() + + //write response + w.WriteHeader(http.StatusOK) + }) +} + +func (ls *LivepeerServer) StartStreamWhipIngest(whipServer *media.WHIPServer) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(context.Background(), clog.ClientIP, remoteAddr) + + streamId := r.PathValue("streamId") + ctx = clog.AddVal(ctx, "stream_id", streamId) + + stream, ok := ls.LivepeerNode.LivePipelines[streamId] + if !ok { + respondJsonError(ctx, w, fmt.Errorf("stream not found: %s", streamId), http.StatusNotFound) + return + } + + params, err := getStreamRequestParams(stream) + if err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + whipConn := 
media.NewWHIPConnection() + whepURL := generateWhepUrl(streamId, params.liveParams.requestID) + + // this function is called when the pipeline hits a fatal error, we kick the input connection to allow + // the client to reconnect and restart the pipeline + kickInput := func(err error) { + if err == nil { + return + } + clog.Errorf(ctx, "Live video pipeline finished with error: %s", err) + params.liveParams.sendErrorEvent(err) + whipConn.Close() + } + params.liveParams.kickInput = kickInput + stream.UpdateStreamParams(params) //add kickInput to stream params + + //wait for the WHIP connection to close and then cleanup + go func() { + statsContext, statsCancel := context.WithCancel(ctx) + defer statsCancel() + go runStats(statsContext, whipConn, streamId, stream.Pipeline, params.liveParams.requestID) + + whipConn.AwaitClose() + params.liveParams.segmentReader.Close() + params.liveParams.kickOrch(errors.New("whip ingest disconnected")) + stream.StopStream(nil) + clog.Info(ctx, "Live cleaned up") + }() + + if whipServer == nil { + respondJsonError(ctx, w, fmt.Errorf("whip server not configured"), http.StatusInternalServerError) + whipConn.Close() + return + } + + conn := whipServer.CreateWHIP(ctx, params.liveParams.segmentReader, whepURL, w, r) + whipConn.SetWHIPConnection(conn) // might be nil if theres an error and thats okay + }) +} + +func startStreamProcessing(ctx context.Context, stream *core.LivePipeline, params aiRequestParams) error { + //required channels + control, err := common.AppendHostname(params.liveParams.orchControlUrl, params.liveParams.sess.BroadcastSession.Transcoder()) + if err != nil { + return fmt.Errorf("invalid control URL: %w", err) + } + events, err := common.AppendHostname(params.liveParams.orchEventsUrl, params.liveParams.sess.BroadcastSession.Transcoder()) + if err != nil { + return fmt.Errorf("invalid events URL: %w", err) + } + + startControlPublish(ctx, control, params) + startEventsSubscribe(ctx, events, params, params.liveParams.sess) + + //Optional channels + if params.liveParams.orchPublishUrl != "" { + clog.Infof(ctx, "Starting video ingress publisher") + pub, err := common.AppendHostname(params.liveParams.orchPublishUrl, params.liveParams.sess.BroadcastSession.Transcoder()) + if err != nil { + return fmt.Errorf("invalid publish URL: %w", err) + } + startTricklePublish(ctx, pub, params, params.liveParams.sess) + } + + if params.liveParams.orchSubscribeUrl != "" { + clog.Infof(ctx, "Starting video egress subscriber") + sub, err := common.AppendHostname(params.liveParams.orchSubscribeUrl, params.liveParams.sess.BroadcastSession.Transcoder()) + if err != nil { + return fmt.Errorf("invalid subscribe URL: %w", err) + } + startTrickleSubscribe(ctx, sub, params, params.liveParams.sess) + } + + if params.liveParams.orchDataUrl != "" { + clog.Infof(ctx, "Starting data channel subscriber") + data, err := common.AppendHostname(params.liveParams.orchDataUrl, params.liveParams.sess.BroadcastSession.Transcoder()) + if err != nil { + return fmt.Errorf("invalid data URL: %w", err) + } + params.liveParams.manifestID = stream.Pipeline + + startDataSubscribe(ctx, data, params, params.liveParams.sess) + } + + return nil +} + +func (ls *LivepeerServer) GetStreamData() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + streamId := r.PathValue("streamId") + if streamId == "" { + http.Error(w, "stream name is required", http.StatusBadRequest) + return + } + + ctx := r.Context() + ctx = clog.AddVal(ctx, "stream", streamId) + + // Get the 
live pipeline for this stream + stream, exists := ls.LivepeerNode.LivePipelines[streamId] + if !exists { + http.Error(w, "Stream not found", http.StatusNotFound) + return + } + params, err := getStreamRequestParams(stream) + if err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + // Get the data reading buffer + if params.liveParams.dataWriter == nil { + http.Error(w, "Stream data not available", http.StatusServiceUnavailable) + return + } + dataReader := params.liveParams.dataWriter.MakeReader(media.SegmentReaderConfig{}) + + // Set up SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + clog.Infof(ctx, "Starting SSE data stream for stream=%s", streamId) + + // Listen for broadcast signals from ring buffer writes + // dataReader.Read() blocks on rb.cond.Wait() until startDataSubscribe broadcasts + for { + select { + case <-ctx.Done(): + clog.Info(ctx, "SSE data stream client disconnected") + return + default: + reader, err := dataReader.Next() + if err != nil { + if err == io.EOF { + // Stream ended + fmt.Fprintf(w, "event: end\ndata: {\"type\":\"stream_ended\"}\n\n") + flusher.Flush() + return + } + clog.Errorf(ctx, "Error reading from ring buffer: %v", err) + return + } + start := time.Now() + data, err := io.ReadAll(reader) + if err != nil { + clog.Errorf(ctx, "Error reading data segment: %v", err) + return + } + clog.V(6).Infof(ctx, "SSE data read took %v", time.Since(start)) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + } + }) +} + +func (ls *LivepeerServer) UpdateStream() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + // Get stream from path param + streamId := r.PathValue("streamId") + if streamId == "" { + http.Error(w, "Missing stream name", http.StatusBadRequest) + return + } + stream, ok := ls.LivepeerNode.LivePipelines[streamId] + if !ok { + // Stream not found + http.Error(w, "Stream not found", http.StatusNotFound) + return + } + + reader := http.MaxBytesReader(w, r.Body, 10*1024*1024) // 10 MB + defer reader.Close() + data, err := io.ReadAll(reader) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + params := string(data) + stream.Params = data + controlPub := stream.ControlPub + + if controlPub == nil { + clog.Info(ctx, "No orchestrator available, caching params", "stream", streamId, "params", params) + return + } + + clog.V(6).Infof(ctx, "Sending Live Video Update Control API stream=%s, params=%s", streamId, params) + if err := controlPub.Write(strings.NewReader(params)); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + corsHeaders(w, r.Method) + }) +} + +func (ls *LivepeerServer) GetStreamStatus() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + corsHeaders(w, r.Method) + + streamId := r.PathValue("streamId") + if streamId == "" { + http.Error(w, "stream id is required", http.StatusBadRequest) + return + } + + ctx := r.Context() + ctx = clog.AddVal(ctx, "stream", streamId) + + // Get status for specific stream + status, exists := StreamStatusStore.Get(streamId) + gatewayStatus, gatewayExists := GatewayStatus.Get(streamId) + if 
!exists && !gatewayExists { + http.Error(w, "Stream not found", http.StatusNotFound) + return + } + if gatewayExists { + if status == nil { + status = make(map[string]any) + } + status["gateway_status"] = gatewayStatus + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(status); err != nil { + clog.Errorf(ctx, "Failed to encode stream status err=%v", err) + http.Error(w, "Failed to encode status", http.StatusInternalServerError) + return + } + }) +} + +// StartStream handles the POST /stream/start endpoint for the Orchestrator +func (h *lphttp) StartStream(w http.ResponseWriter, r *http.Request) { + orch := h.orchestrator + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + + orchJob, err := h.setupOrchJob(ctx, r, false) + if err != nil { + code := http.StatusBadRequest + if err == errInsufficientBalance { + code = http.StatusPaymentRequired + } + respondWithError(w, err.Error(), code) + return + } + ctx = clog.AddVal(ctx, "stream_id", orchJob.Req.ID) + + workerRoute := orchJob.Req.CapabilityUrl + "/stream/start" + + // Read the original body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading request body", http.StatusBadRequest) + return + } + r.Body.Close() + + var jobParams JobParameters + err = json.Unmarshal([]byte(orchJob.Req.Parameters), &jobParams) + if err != nil { + clog.Errorf(ctx, "unable to parse parameters err=%v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + clog.Infof(ctx, "Processing stream start request videoIngress=%v videoEgress=%v dataOutput=%v", jobParams.EnableVideoIngress, jobParams.EnableVideoEgress, jobParams.EnableDataOutput) + // Start trickle server for live-video + var ( + mid = orchJob.Req.ID // Request ID is used for the manifest ID + pubUrl = h.orchestrator.ServiceURI().JoinPath(TrickleHTTPPath, mid).String() + subUrl = pubUrl + "-out" + controlUrl = pubUrl + "-control" + eventsUrl = pubUrl + "-events" + dataUrl = pubUrl + "-data" + pubCh *trickle.TrickleLocalPublisher + subCh *trickle.TrickleLocalPublisher + controlPubCh *trickle.TrickleLocalPublisher + eventsCh *trickle.TrickleLocalPublisher + dataCh *trickle.TrickleLocalPublisher + ) + + reqBodyForRunner := make(map[string]interface{}) + reqBodyForRunner["gateway_request_id"] = mid + //required channels + controlPubCh = trickle.NewLocalPublisher(h.trickleSrv, mid+"-control", "application/json") + controlPubCh.CreateChannel() + controlUrl = overwriteHost(h.node.LiveAITrickleHostForRunner, controlUrl) + reqBodyForRunner["control_url"] = controlUrl + w.Header().Set("X-Control-Url", controlUrl) + + eventsCh = trickle.NewLocalPublisher(h.trickleSrv, mid+"-events", "application/json") + eventsCh.CreateChannel() + eventsUrl = overwriteHost(h.node.LiveAITrickleHostForRunner, eventsUrl) + reqBodyForRunner["events_url"] = eventsUrl + w.Header().Set("X-Events-Url", eventsUrl) + + //Optional channels + if jobParams.EnableVideoIngress { + pubCh = trickle.NewLocalPublisher(h.trickleSrv, mid, "video/MP2T") + pubCh.CreateChannel() + pubUrl = overwriteHost(h.node.LiveAITrickleHostForRunner, pubUrl) + reqBodyForRunner["subscribe_url"] = pubUrl //runner needs to subscribe to input + w.Header().Set("X-Publish-Url", pubUrl) //gateway will connect to pubUrl to send ingress video + } + + if jobParams.EnableVideoEgress { + subCh = trickle.NewLocalPublisher(h.trickleSrv, mid+"-out", "video/MP2T") + subCh.CreateChannel() + subUrl = overwriteHost(h.node.LiveAITrickleHostForRunner, 
subUrl) + reqBodyForRunner["publish_url"] = subUrl //runner needs to send results -out + w.Header().Set("X-Subscribe-Url", subUrl) //gateway will connect to subUrl to receive results + } + + if jobParams.EnableDataOutput { + dataCh = trickle.NewLocalPublisher(h.trickleSrv, mid+"-data", "application/jsonl") + dataCh.CreateChannel() + dataUrl = overwriteHost(h.node.LiveAITrickleHostForRunner, dataUrl) + reqBodyForRunner["data_url"] = dataUrl + w.Header().Set("X-Data-Url", dataUrl) + } + + reqBodyForRunner["request"] = string(body) + reqBodyBytes, err := json.Marshal(reqBodyForRunner) + if err != nil { + clog.Errorf(ctx, "Failed to marshal request body err=%v", err) + http.Error(w, "Failed to marshal request body", http.StatusInternalServerError) + return + } + + req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, bytes.NewBuffer(reqBodyBytes)) + // set the headers + req.Header.Add("Content-Length", r.Header.Get("Content-Length")) + req.Header.Add("Content-Type", r.Header.Get("Content-Type")) + + start := time.Now() + resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) + if err != nil { + clog.Errorf(ctx, "Error sending request to worker %v: %v", workerRoute, err) + respondWithError(w, "Error sending request to worker", http.StatusInternalServerError) + return + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + clog.Errorf(ctx, "Error reading response body: %v", err) + respondWithError(w, "Error reading response body", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + //error response from worker but assume can retry and pass along error response and status code + if resp.StatusCode > 399 { + clog.Errorf(ctx, "error processing stream start request statusCode=%d", resp.StatusCode) + + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) + //return error response from the worker + w.WriteHeader(resp.StatusCode) + w.Write(respBody) + return + } + + chargeForCompute(start, orchJob.JobPrice, orch, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) + + clog.V(common.SHORT).Infof(ctx, "stream start processed successfully took=%v balance=%v", time.Since(start), getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability).FloatString(0)) + + //setup the stream + stream, err := h.node.ExternalCapabilities.AddStream(orchJob.Req.ID, orchJob.Req.Capability, reqBodyBytes) + if err != nil { + clog.Errorf(ctx, "Error adding stream to external capabilities: %v", err) + respondWithError(w, "Error adding stream to external capabilities", http.StatusInternalServerError) + return + } + + stream.SetChannels(pubCh, subCh, controlPubCh, eventsCh, dataCh) + + //start payment monitoring + go func() { + stream, _ := h.node.ExternalCapabilities.Streams[orchJob.Req.ID] + ctx := context.Background() + ctx = clog.AddVal(ctx, "stream_id", orchJob.Req.ID) + ctx = clog.AddVal(ctx, "capability", orchJob.Req.Capability) + + pmtCheckDur := 23 * time.Second //run slightly faster than gateway so can return updated balance + pmtTicker := time.NewTicker(pmtCheckDur) + defer pmtTicker.Stop() + shouldStopStreamNextRound := false + for { + select { + case <-stream.StreamCtx.Done(): + h.orchestrator.FreeExternalCapabilityCapacity(orchJob.Req.Capability) + clog.Infof(ctx, "Stream ended, 
stopping payment monitoring and released capacity") + return + case <-pmtTicker.C: + // Check payment status + + jobPriceRat := big.NewRat(orchJob.JobPrice.PricePerUnit, orchJob.JobPrice.PixelsPerUnit) + if jobPriceRat.Cmp(big.NewRat(0, 1)) > 0 { + h.orchestrator.DebitFees(orchJob.Sender, core.ManifestID(orchJob.Req.Capability), orchJob.JobPrice, int64(pmtCheckDur.Seconds())) + senderBalance := getPaymentBalance(orch, orchJob.Sender, orchJob.Req.Capability) + if senderBalance != nil { + if senderBalance.Cmp(big.NewRat(0, 1)) < 0 { + if !shouldStopStreamNextRound { + //warn once + clog.Warningf(ctx, "Insufficient balance for stream capability, will stop stream next round if not replenished sender=%s capability=%s balance=%s", orchJob.Sender, orchJob.Req.Capability, senderBalance.FloatString(0)) + shouldStopStreamNextRound = true + continue + } + + clog.Infof(ctx, "Insufficient balance, stopping stream %s for sender %s", orchJob.Req.ID, orchJob.Sender) + _, exists := h.node.ExternalCapabilities.Streams[orchJob.Req.ID] + if exists { + h.node.ExternalCapabilities.RemoveStream(orchJob.Req.ID) + } + + return + } + + clog.V(8).Infof(ctx, "Payment balance for stream capability is good balance=%v", senderBalance.FloatString(0)) + } + } + + //check if stream still exists + // if not, send stop to worker and exit monitoring + stream, exists := h.node.ExternalCapabilities.Streams[orchJob.Req.ID] + if !exists { + req, err := http.NewRequestWithContext(ctx, "POST", orchJob.Req.CapabilityUrl+"/stream/stop", nil) + if err != nil { + clog.Errorf(ctx, "Error creating stop request for worker %v: %v", orchJob.Req.CapabilityUrl, err) + return + } + _, err = sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) + if err != nil { + //the HTTP handler has already returned, so only log the error here + clog.Errorf(ctx, "Error sending request to worker %v: %v", orchJob.Req.CapabilityUrl, err) + return + } + //end monitoring of stream + return + } + + //check if control channel is still open, end if not + if !stream.IsActive() { + // Stop the stream and free capacity + h.node.ExternalCapabilities.RemoveStream(orchJob.Req.ID) + return + } + } + } + }() + + //send back the trickle urls set in header + w.WriteHeader(http.StatusOK) + return +} + +func (h *lphttp) StopStream(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + orchJob, err := h.setupOrchJob(ctx, r, false) + if err != nil { + respondWithError(w, fmt.Sprintf("Failed to stop stream, request not valid err=%v", err), http.StatusBadRequest) + return + } + + var jobDetails JobRequestDetails + err = json.Unmarshal([]byte(orchJob.Req.Request), &jobDetails) + if err != nil { + respondWithError(w, fmt.Sprintf("Failed to stop stream, request not valid, failed to parse stream id err=%v", err), http.StatusBadRequest) + return + } + clog.Infof(ctx, "Stopping stream %s", jobDetails.StreamId) + + // Read the original body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading request body", http.StatusBadRequest) + return + } + r.Body.Close() + + workerRoute := orchJob.Req.CapabilityUrl + "/stream/stop" + req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, bytes.NewBuffer(body)) + if err != nil { + clog.Errorf(ctx, "failed to create /stop/stream request to worker err=%v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) + if err != nil { + clog.Errorf(ctx, "Error sending request to worker %v: %v", workerRoute, err) + respondWithError(w, "Error sending request to worker",
+		return
+	}
+	defer resp.Body.Close()
+
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		clog.Errorf(ctx, "Error reading response body: %v", err)
+		respondWithError(w, "Error reading response body", http.StatusInternalServerError)
+		return
+	}
+
+	if resp.StatusCode > 399 {
+		clog.Errorf(ctx, "error processing stream stop request statusCode=%d", resp.StatusCode)
+	}
+
+	// Stop the stream and free capacity
+	h.node.ExternalCapabilities.RemoveStream(jobDetails.StreamId)
+
+	w.WriteHeader(resp.StatusCode)
+	w.Write(respBody)
+}
+
+func (h *lphttp) ProcessStreamPayment(w http.ResponseWriter, r *http.Request) {
+	orch := h.orchestrator
+	ctx := r.Context()
+
+	//this will validate the request and process the payment
+	orchJob, err := h.setupOrchJob(ctx, r, false)
+	if err != nil {
+		respondWithError(w, fmt.Sprintf("Failed to process payment, request not valid err=%v", err), http.StatusBadRequest)
+		return
+	}
+	ctx = clog.AddVal(ctx, "stream_id", orchJob.Details.StreamId)
+	ctx = clog.AddVal(ctx, "capability", orchJob.Req.Capability)
+	ctx = clog.AddVal(ctx, "sender", orchJob.Req.Sender)
+
+	senderAddr := ethcommon.HexToAddress(orchJob.Req.Sender)
+
+	capBal := orch.Balance(senderAddr, core.ManifestID(orchJob.Req.Capability))
+	if capBal != nil {
+		//ensure the reported balance can be represented with int64 numerator and denominator
+		capBal, err = common.PriceToInt64(capBal)
+		if err != nil {
+			clog.Errorf(ctx, "could not convert balance to int64 sender=%v capability=%v err=%v", senderAddr.Hex(), orchJob.Req.Capability, err.Error())
+			capBal = big.NewRat(0, 1)
+		}
+	} else {
+		capBal = big.NewRat(0, 1)
+	}
+
+	w.Header().Set(jobPaymentBalanceHdr, capBal.FloatString(0))
+	w.WriteHeader(http.StatusOK)
+}
+
+func tokenToAISession(token core.JobToken) (AISession, error) {
+	var session BroadcastSession
+
+	// Initialize the lock to avoid nil pointer dereference in methods
+	// like (*BroadcastSession).Transcoder() which acquire RLock()
+	session.lock = &sync.RWMutex{}
+
+	//default to zero price if it's nil; the Orchestrator will reject the stream if it charges a price above zero
+	if token.Price == nil {
+		token.Price = &net.PriceInfo{}
+	}
+
+	orchInfo := net.OrchestratorInfo{Transcoder: token.ServiceAddr, TicketParams: token.TicketParams, PriceInfo: token.Price}
+	if token.SenderAddress != nil {
+		orchInfo.Address = ethcommon.Hex2Bytes(token.SenderAddress.Addr)
+	}
+	session.OrchestratorInfo = &orchInfo
+
+	return AISession{BroadcastSession: &session}, nil
+}
+
+func sessionToToken(session *AISession) (core.JobToken, error) {
+	var token core.JobToken
+
+	token.ServiceAddr = session.OrchestratorInfo.Transcoder
+	token.TicketParams = session.OrchestratorInfo.TicketParams
+	token.Price = session.OrchestratorInfo.PriceInfo
+	return token, nil
+}
+
+func getStreamRequestParams(stream *core.LivePipeline) (aiRequestParams, error) {
+	if stream == nil {
+		return aiRequestParams{}, fmt.Errorf("stream is nil")
+	}
+
+	streamParams := stream.StreamParams()
+	params, ok := streamParams.(aiRequestParams)
+	if !ok {
+		return aiRequestParams{}, fmt.Errorf("failed to cast stream params to aiRequestParams")
+	}
+	return params, nil
+}
diff --git a/server/job_stream_test.go b/server/job_stream_test.go
new file mode 100644
index 0000000000..ef9600826c
--- /dev/null
+++ b/server/job_stream_test.go
@@ -0,0 +1,1597 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+ + "github.com/livepeer/go-livepeer/core" + "github.com/livepeer/go-livepeer/media" + "github.com/livepeer/go-livepeer/pm" + "github.com/livepeer/go-livepeer/trickle" + "github.com/livepeer/go-tools/drivers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var stubOrchServerUrl string + +// testOrch wraps mockOrchestrator to override a few methods needed by lphttp in tests +type testStreamOrch struct { + *mockOrchestrator + svc *url.URL + capURL string +} + +func (o *testStreamOrch) ServiceURI() *url.URL { return o.svc } +func (o *testStreamOrch) GetUrlForCapability(capability string) string { return o.capURL } + +// streamingResponseWriter implements http.ResponseWriter for streaming responses +type streamingResponseWriter struct { + pipe *io.PipeWriter + headers http.Header + status int +} + +func (w *streamingResponseWriter) Header() http.Header { + return w.headers +} + +func (w *streamingResponseWriter) Write(data []byte) (int, error) { + return w.pipe.Write(data) +} + +func (w *streamingResponseWriter) WriteHeader(statusCode int) { + w.status = statusCode +} + +// Helper: base64-encoded JobRequest with JobParameters (Enable all true, test-capability name) +func base64TestJobRequest(timeout int, enableVideoIngress, enableVideoEgress, enableDataOutput bool) string { + params := JobParameters{ + EnableVideoIngress: enableVideoIngress, + EnableVideoEgress: enableVideoEgress, + EnableDataOutput: enableDataOutput, + } + paramsStr, _ := json.Marshal(params) + + jr := JobRequest{ + Capability: "test-capability", + Parameters: string(paramsStr), + Request: "{}", + Timeout: timeout, + } + + b, _ := json.Marshal(jr) + + return base64.StdEncoding.EncodeToString(b) +} + +func orchAIStreamStartHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/ai/stream/start" { + http.NotFound(w, r) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Publish-Url", fmt.Sprintf("%s%s%s", stubOrchServerUrl, TrickleHTTPPath, "test-stream")) + w.Header().Set("X-Subscribe-Url", fmt.Sprintf("%s%s%s", stubOrchServerUrl, TrickleHTTPPath, "test-stream-out")) + w.Header().Set("X-Control-Url", fmt.Sprintf("%s%s%s", stubOrchServerUrl, TrickleHTTPPath, "test-stream-control")) + w.Header().Set("X-Events-Url", fmt.Sprintf("%s%s%s", stubOrchServerUrl, TrickleHTTPPath, "test-stream-events")) + w.Header().Set("X-Data-Url", fmt.Sprintf("%s%s%s", stubOrchServerUrl, TrickleHTTPPath, "test-stream-data")) + w.WriteHeader(http.StatusOK) +} + +func orchCapabilityUrlHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func TestStartStream_MaxBodyLimit(t *testing.T) { + // Setup server with minimal dependencies + node := mockJobLivepeerNode() + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + + // Set up mock sender to prevent nil pointer dereference + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo") + mockSender.On("CreateTicketBatch", mock.Anything, mock.Anything).Return(mockTicketBatch(10), nil) + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + ls := &LivepeerServer{LivepeerNode: node} + + // Prepare a valid job request header + jobDetails := JobRequestDetails{StreamId: "test-stream"} + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + jobReq 
:= JobRequest{ + ID: "job-1", + Request: marshalToString(t, jobDetails), + Parameters: marshalToString(t, jobParams), + Capability: "test-capability", + Timeout: 10, + } + jobReqB, err := json.Marshal(jobReq) + assert.NoError(t, err) + jobReqB64 := base64.StdEncoding.EncodeToString(jobReqB) + + // Create a body over 10MB + bigBody := bytes.Repeat([]byte("a"), 10<<20+1) // 10MB + 1 byte + req := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader(bigBody)) + req.Header.Set(jobRequestHdr, jobReqB64) + + w := httptest.NewRecorder() + handler := ls.StartStream() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) +} + +func TestStreamStart_SetupStream(t *testing.T) { + node := mockJobLivepeerNode() + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + + // Set up mock sender to prevent nil pointer dereference + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo") + mockSender.On("CreateTicketBatch", mock.Anything, mock.Anything).Return(mockTicketBatch(10), nil) + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + // Prepare a valid gatewayJob + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + paramsStr := marshalToString(t, jobParams) + jobReq := &JobRequest{ + Capability: "test-capability", + Parameters: paramsStr, + Timeout: 10, + } + orchJob := &orchJob{Req: jobReq, Params: &jobParams} + gatewayJob := &gatewayJob{Job: orchJob} + + // Prepare a valid StartRequest body + startReq := StartRequest{ + Stream: "teststream", + RtmpOutput: "rtmp://output", + StreamId: "streamid", + Params: "{}", + } + body, _ := json.Marshal(startReq) + req := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + urls, code, err := ls.setupStream(context.Background(), req, gatewayJob) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, code) + assert.NotNil(t, urls) + assert.Equal(t, "teststream-streamid", urls.StreamId) + //confirm all urls populated + assert.NotEmpty(t, urls.WhipUrl) + assert.NotEmpty(t, urls.RtmpUrl) + assert.NotEmpty(t, urls.WhepUrl) + assert.NotEmpty(t, urls.RtmpOutputUrl) + assert.Contains(t, urls.RtmpOutputUrl, "rtmp://output") + assert.NotEmpty(t, urls.DataUrl) + assert.NotEmpty(t, urls.StatusUrl) + assert.NotEmpty(t, urls.UpdateUrl) + + //confirm LivePipeline created + stream, ok := ls.LivepeerNode.LivePipelines[urls.StreamId] + assert.True(t, ok) + assert.NotNil(t, stream) + assert.Equal(t, urls.StreamId, stream.StreamID) + assert.Equal(t, stream.StreamRequest(), body) + params := stream.StreamParams() + _, checkParamsType := params.(aiRequestParams) + assert.True(t, checkParamsType) + + //test with no data output + jobParams = JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: false} + paramsStr = marshalToString(t, jobParams) + jobReq.Parameters = paramsStr + gatewayJob.Job.Params = &jobParams + req.Body = io.NopCloser(bytes.NewReader(body)) + urls, code, err = ls.setupStream(context.Background(), req, gatewayJob) + assert.Empty(t, urls.DataUrl) + + //test with no video ingress + jobParams = JobParameters{EnableVideoIngress: false, 
EnableVideoEgress: true, EnableDataOutput: true} + paramsStr = marshalToString(t, jobParams) + jobReq.Parameters = paramsStr + gatewayJob.Job.Params = &jobParams + req.Body = io.NopCloser(bytes.NewReader(body)) + urls, code, err = ls.setupStream(context.Background(), req, gatewayJob) + assert.Empty(t, urls.WhipUrl) + assert.Empty(t, urls.RtmpUrl) + + //test with no video egress + jobParams = JobParameters{EnableVideoIngress: true, EnableVideoEgress: false, EnableDataOutput: true} + paramsStr = marshalToString(t, jobParams) + jobReq.Parameters = paramsStr + gatewayJob.Job.Params = &jobParams + req.Body = io.NopCloser(bytes.NewReader(body)) + urls, code, err = ls.setupStream(context.Background(), req, gatewayJob) + assert.Empty(t, urls.WhepUrl) + assert.Empty(t, urls.RtmpOutputUrl) + + // Test with nil job + urls, code, err = ls.setupStream(context.Background(), req, nil) + assert.Error(t, err) + assert.Equal(t, http.StatusBadRequest, code) + assert.Nil(t, urls) + + // Test with invalid JSON body + badReq := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader([]byte("notjson"))) + badReq.Header.Set("Content-Type", "application/json") + urls, code, err = ls.setupStream(context.Background(), badReq, gatewayJob) + assert.Error(t, err) + assert.Equal(t, http.StatusBadRequest, code) + assert.Nil(t, urls) + + // Test with stream name ending in -out (should return nil, 0, nil) + outReq := StartRequest{ + Stream: "teststream-out", + RtmpOutput: "rtmp://output", + StreamId: "streamid", + Params: "{}", + } + outBody, _ := json.Marshal(outReq) + outReqHTTP := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader(outBody)) + outReqHTTP.Header.Set("Content-Type", "application/json") + urls, code, err = ls.setupStream(context.Background(), outReqHTTP, gatewayJob) + assert.NoError(t, err) + assert.Equal(t, 0, code) + assert.Nil(t, urls) +} + +func TestRunStream_RunAndCancelStream(t *testing.T) { + node := mockJobLivepeerNode() + + // Set up an lphttp-based orchestrator test server with trickle endpoints + mux := http.NewServeMux() + mockOrch := &mockOrchestrator{} + mockOrch.On("VerifySig", mock.Anything, mock.Anything, mock.Anything).Return(true) + mockOrch.On("DebitFees", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + + lp := &lphttp{orchestrator: nil, transRPC: mux, node: node} + // Configure trickle server on the mux (imitate production trickle endpoints) + lp.trickleSrv = trickle.ConfigureServer(trickle.TrickleServerConfig{ + Mux: mux, + BasePath: TrickleHTTPPath, + Autocreate: true, + }) + // Register orchestrator endpoints used by runStream path + mux.HandleFunc("/ai/stream/start", lp.StartStream) + mux.HandleFunc("/ai/stream/stop", lp.StopStream) + mux.HandleFunc("/process/token", orchTokenHandler) + + server := httptest.NewServer(lp) + defer server.Close() + stubOrchServerUrl = server.URL + + // Configure mock orchestrator behavior expected by lphttp handlers + parsedURL, _ := url.Parse(server.URL) + capabilitySrv := httptest.NewServer(http.HandlerFunc(orchCapabilityUrlHandler)) + defer capabilitySrv.Close() + + // attach our orchestrator implementation to lphttp + lp.orchestrator = &testStreamOrch{mockOrchestrator: mockOrch, svc: parsedURL, capURL: capabilitySrv.URL} + + // Prepare a gatewayJob with a dummy orchestrator token + jobReq := &JobRequest{ + ID: "test-stream", + Capability: "test-capability", + Timeout: 10, + Request: "{}", + } + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, 
EnableDataOutput: true} + paramsStr := marshalToString(t, jobParams) + jobReq.Parameters = paramsStr + + orchToken := createMockJobToken(server.URL) + orchJob := &orchJob{Req: jobReq, Params: &jobParams} + gatewayJob := &gatewayJob{Job: orchJob, Orchs: []core.JobToken{*orchToken}, node: node} + + // Setup a LivepeerServer and a mock pipeline + ls := &LivepeerServer{LivepeerNode: node} + ls.LivepeerNode.OrchestratorPool = newStubOrchestratorPool(ls.LivepeerNode, []string{server.URL}) + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(4) + mockSender.On("CreateTicketBatch", "foo", orchJob.Req.Timeout).Return(mockTicketBatch(orchJob.Req.Timeout), nil).Once() + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + //now sign job and create a sig for the sender to include + gatewayJob.sign() + sender, err := getJobSender(context.TODO(), node) + assert.NoError(t, err) + orchJob.Req.Sender = sender.Addr + orchJob.Req.Sig = sender.Sig + // Minimal aiRequestParams and liveRequestParams + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + stream: "test-stream", + streamID: "test-stream", + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + + ls.LivepeerNode.NewLivePipeline("req-1", "test-stream", "test-capability", params, nil) + + // Cancel the stream after a short delay to simulate shutdown + done := make(chan struct{}) + go func() { + time.Sleep(100 * time.Millisecond) + stream := node.LivePipelines["test-stream"] + if stream != nil { + params, _ := getStreamRequestParams(stream) + if params.liveParams.kickOrch != nil { + params.liveParams.kickOrch(errors.New("test cancel")) + } + + stream.StopStream(nil) + } + close(done) + }() + + // Should not panic and should clean up + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); ls.runStream(gatewayJob) }() + go func() { defer wg.Done(); ls.monitorStream(gatewayJob.Job.Req.ID) }() + <-done + // Wait for both goroutines to finish before asserting + wg.Wait() + // After cancel, the stream should be removed from LivePipelines + _, exists := node.LivePipelines["test-stream"] + assert.False(t, exists) +} + +// Test StartStream handler +func TestStartStreamHandler(t *testing.T) { + node := mockJobLivepeerNode() + + // Set up an lphttp-based orchestrator test server with trickle endpoints + mux := http.NewServeMux() + ls := &LivepeerServer{ + LivepeerNode: node, + } + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo") + mockSender.On("CreateTicketBatch", mock.Anything, mock.Anything).Return(mockTicketBatch(10), nil) + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(1 * time.Second) + defer node.Balances.StopCleanup() + //setup Orch server stub + mux.HandleFunc("/process/token", orchTokenHandler) + mux.HandleFunc("/ai/stream/start", orchAIStreamStartHandler) + server := httptest.NewServer(mux) + defer server.Close() + ls.LivepeerNode.OrchestratorPool = newStubOrchestratorPool(ls.LivepeerNode, []string{server.URL}) + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + // Prepare a valid StartRequest body + startReq := StartRequest{ + Stream: "teststream", + RtmpOutput: "rtmp://output", + StreamId: "streamid", + Params: "{}", + } + body, _ := json.Marshal(startReq) + req := httptest.NewRequest(http.MethodPost, "/ai/stream/start", 
bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Livepeer", base64TestJobRequest(10, true, true, true)) + + w := httptest.NewRecorder() + + handler := ls.StartStream() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + body = w.Body.Bytes() + var streamUrls StreamUrls + err := json.Unmarshal(body, &streamUrls) + assert.NoError(t, err) + stream, exits := ls.LivepeerNode.LivePipelines[streamUrls.StreamId] + assert.True(t, exits) + assert.NotNil(t, stream) + assert.Equal(t, streamUrls.StreamId, stream.StreamID) + params := stream.StreamParams() + streamParams, checkParamsType := params.(aiRequestParams) + assert.True(t, checkParamsType) + //wrap up processing + time.Sleep(100 * time.Millisecond) + streamParams.liveParams.kickOrch(errors.New("test error")) + stream.StopStream(nil) +} + +// Test StopStream handler +func TestStopStreamHandler(t *testing.T) { + t.Run("StreamNotFound", func(t *testing.T) { + // Test case 1: Stream doesn't exist - should return 404 + ls := &LivepeerServer{LivepeerNode: &core.LivepeerNode{LivePipelines: map[string]*core.LivePipeline{}}} + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/stop", nil) + req.SetPathValue("streamId", "non-existent-stream") + w := httptest.NewRecorder() + + handler := ls.StopStream() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + assert.Contains(t, w.Body.String(), "Stream not found") + }) + + t.Run("StreamExistsAndStopsSuccessfully", func(t *testing.T) { + // Test case 2: Stream exists - should stop stream and attempt to send request to orchestrator + node := mockJobLivepeerNode() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Mock orchestrator response handlers + switch r.URL.Path { + case "/process/token": + orchTokenHandler(w, r) + case "/ai/stream/stop": + // Mock successful stop response from orchestrator + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "stopped"}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(4) + mockSender.On("CreateTicketBatch", "foo", 10).Return(mockTicketBatch(10), nil).Once() + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + // Create a stream to stop + streamID := "test-stream-to-stop" + + // Create minimal AI session with properly formatted URL + token := createMockJobToken(server.URL) + + sess, err := tokenToAISession(*token) + + // Create stream parameters + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + + // Add the stream to LivePipelines + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + assert.NotNil(t, stream) + + // Verify stream exists before stopping + _, exists := ls.LivepeerNode.LivePipelines[streamID] + assert.True(t, exists, "Stream should exist before stopping") + + // Create stop request with proper job header + jobParams := JobParameters{EnableVideoIngress: true, 
EnableVideoEgress: true, EnableDataOutput: true} + jobDetails := JobRequestDetails{StreamId: streamID} + jobReq := JobRequest{ + ID: streamID, + Request: marshalToString(t, jobDetails), + Capability: "test-capability", + Parameters: marshalToString(t, jobParams), + Timeout: 10, + } + jobReqB, err := json.Marshal(jobReq) + assert.NoError(t, err) + jobReqB64 := base64.StdEncoding.EncodeToString(jobReqB) + + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/stop", strings.NewReader(`{"reason": "test stop"}`)) + req.SetPathValue("streamId", streamID) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(jobRequestHdr, jobReqB64) + + w := httptest.NewRecorder() + + handler := ls.StopStream() + handler.ServeHTTP(w, req) + + // The response might vary depending on orchestrator communication success + // The important thing is that the stream is removed regardless + assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError, http.StatusBadRequest}, w.Code, + "Should return valid HTTP status") + + // Verify stream was removed from LivePipelines (this should always happen) + _, exists = ls.LivepeerNode.LivePipelines[streamID] + assert.False(t, exists, "Stream should be removed after stopping") + }) + + t.Run("StreamExistsButOrchestratorError", func(t *testing.T) { + // Test case 3: Stream exists but orchestrator returns error + node := mockJobLivepeerNode() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/process/token": + orchTokenHandler(w, r) + case "/ai/stream/stop": + // Mock orchestrator error + http.Error(w, "Orchestrator error", http.StatusInternalServerError) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(4) + mockSender.On("CreateTicketBatch", "foo", 10).Return(mockTicketBatch(10), nil).Once() + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + streamID := "test-stream-orch-error" + + // Create minimal AI session + token := createMockJobToken(server.URL) + sess, err := tokenToAISession(*token) + assert.NoError(t, err) + + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + + // Add the stream + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + assert.NotNil(t, stream) + + // Create stop request + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + jobDetails := JobRequestDetails{StreamId: streamID} + jobReq := JobRequest{ + ID: streamID, + Request: marshalToString(t, jobDetails), + Capability: "test-capability", + Parameters: marshalToString(t, jobParams), + Timeout: 10, + } + jobReqB, err := json.Marshal(jobReq) + assert.NoError(t, err) + jobReqB64 := base64.StdEncoding.EncodeToString(jobReqB) + + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/stop", nil) + req.SetPathValue("streamId", streamID) + req.Header.Set(jobRequestHdr, jobReqB64) + + w := httptest.NewRecorder() + + handler := ls.StopStream() + 
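+		// The stop request reuses the base64-encoded JobRequest header from stream start;
+		// the orchestrator-side StopStream parses JobRequestDetails from it to resolve which StreamId to stop.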
handler.ServeHTTP(w, req) + + // Returns 200 OK because Gateway removed the stream. If the Orchestrator errors, it will return + // the error in the response body + assert.Equal(t, http.StatusOK, w.Code) + + // Stream should still be removed even if orchestrator returns error + _, exists := ls.LivepeerNode.LivePipelines[streamID] + assert.False(t, exists, "Stream should be removed even on orchestrator error") + }) +} + +// Test StartStreamRTMPIngest handler +func TestStartStreamRTMPIngestHandler(t *testing.T) { + // Setup mock MediaMTX server on port 9997 before starting the test + mockMediaMTXServer := createMockMediaMTXServer(t) + defer mockMediaMTXServer.Close() + + node := mockJobLivepeerNode() + node.WorkDir = t.TempDir() + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + + ls := &LivepeerServer{ + LivepeerNode: node, + mediaMTXApiPassword: "test-password", + } + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + // Prepare a valid gatewayJob + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + paramsStr := marshalToString(t, jobParams) + jobReq := &JobRequest{ + Capability: "test-capability", + Parameters: paramsStr, + Timeout: 10, + } + orchJob := &orchJob{Req: jobReq, Params: &jobParams} + gatewayJob := &gatewayJob{Job: orchJob} + + // Prepare a valid StartRequest body + startReq := StartRequest{ + Stream: "teststream", + RtmpOutput: "rtmp://output", + StreamId: "streamid", + Params: "{}", + } + body, _ := json.Marshal(startReq) + req := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + urls, code, err := ls.setupStream(context.Background(), req, gatewayJob) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, code) + assert.NotNil(t, urls) + assert.Equal(t, "teststream-streamid", urls.StreamId) //combination of stream name (Stream) and id (StreamId) + + stream, ok := ls.LivepeerNode.LivePipelines[urls.StreamId] + assert.True(t, ok) + assert.NotNil(t, stream) + + params, err := getStreamRequestParams(stream) + assert.NoError(t, err) + + //these should be empty/nil before rtmp ingest starts + assert.Empty(t, params.liveParams.localRTMPPrefix) + assert.Nil(t, params.liveParams.kickInput) + + rtmpReq := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/rtmp", nil) + rtmpReq.SetPathValue("streamId", "teststream-streamid") + w := httptest.NewRecorder() + + handler := ls.StartStreamRTMPIngest() + handler.ServeHTTP(w, rtmpReq) + // Missing source_id and source_type + assert.Equal(t, http.StatusBadRequest, w.Code) + + // Now provide valid form data + formData := url.Values{} + formData.Set("source_id", "testsourceid") + formData.Set("source_type", "rtmpconn") + rtmpReq = httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/rtmp", strings.NewReader(formData.Encode())) + rtmpReq.SetPathValue("streamId", "teststream-streamid") + // Use localhost as the remote addr to simulate MediaMTX + rtmpReq.RemoteAddr = "127.0.0.1:1935" + + rtmpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w = httptest.NewRecorder() + handler.ServeHTTP(w, rtmpReq) + assert.Equal(t, http.StatusOK, w.Code) + + // Verify that the stream parameters were updated correctly + newParams, _ := getStreamRequestParams(stream) + assert.NotNil(t, newParams.liveParams.kickInput) + assert.NotEmpty(t, 
newParams.liveParams.localRTMPPrefix) + + // Stop the stream to cleanup + newParams.liveParams.kickInput(errors.New("test error")) + stream.StopStream(nil) +} + +// Test StartStreamWhipIngest handler +func TestStartStreamWhipIngestHandler(t *testing.T) { + node := mockJobLivepeerNode() + node.WorkDir = t.TempDir() + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + // Prepare a valid gatewayJob + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + paramsStr := marshalToString(t, jobParams) + jobReq := &JobRequest{ + Capability: "test-capability", + Parameters: paramsStr, + Timeout: 10, + } + orchJob := &orchJob{Req: jobReq, Params: &jobParams} + gatewayJob := &gatewayJob{Job: orchJob} + + // Prepare a valid StartRequest body for /ai/stream/start + startReq := StartRequest{ + Stream: "teststream", + RtmpOutput: "rtmp://output", + StreamId: "streamid", + Params: "{}", + } + body, _ := json.Marshal(startReq) + req := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + urls, code, err := ls.setupStream(context.Background(), req, gatewayJob) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, code) + assert.NotNil(t, urls) + assert.Equal(t, "teststream-streamid", urls.StreamId) //combination of stream name (Stream) and id (StreamId) + + stream, ok := ls.LivepeerNode.LivePipelines[urls.StreamId] + assert.True(t, ok) + assert.NotNil(t, stream) + + params, err := getStreamRequestParams(stream) + assert.NoError(t, err) + + //these should be empty/nil before whip ingest starts + assert.Empty(t, params.liveParams.localRTMPPrefix) + assert.Nil(t, params.liveParams.kickInput) + + // whipServer is required, using nil will test setup up to initializing the WHIP connection + whipServer := media.NewWHIPServer() + handler := ls.StartStreamWhipIngest(whipServer) + + // SDP offer for WHIP with H.264 video and Opus audio + sdpOffer := `v=0 +o=- 123456789 2 IN IP4 127.0.0.1 +s=- +t=0 0 +a=group:BUNDLE 0 1 +a=msid-semantic: WMS stream +m=video 9 UDP/TLS/RTP/SAVPF 96 +c=IN IP4 0.0.0.0 +a=rtcp:9 IN IP4 0.0.0.0 +a=ice-ufrag:abcd +a=ice-pwd:abcdefghijklmnopqrstuvwxyz123456 +a=fingerprint:sha-256 00:11:22:33:44:55:66:77:88:99:AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99:AA:BB:CC:DD:EE:FF +a=setup:actpass +a=mid:0 +a=extmap:1 urn:ietf:params:rtp-hdrext:sdes:mid +a=extmap:2 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id +a=extmap:3 urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id +a=sendonly +a=msid:stream video +a=rtcp-mux +a=rtpmap:96 H264/90000 +a=rtcp-fb:96 goog-remb +a=rtcp-fb:96 transport-cc +a=rtcp-fb:96 ccm fir +a=rtcp-fb:96 nack +a=rtcp-fb:96 nack pli +a=fmtp:96 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f +m=audio 9 UDP/TLS/RTP/SAVPF 111 +c=IN IP4 0.0.0.0 +a=rtcp:9 IN IP4 0.0.0.0 +a=ice-ufrag:abcd +a=ice-pwd:abcdefghijklmnopqrstuvwxyz123456 +a=fingerprint:sha-256 00:11:22:33:44:55:66:77:88:99:AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99:AA:BB:CC:DD:EE:FF +a=setup:actpass +a=mid:1 +a=extmap:1 urn:ietf:params:rtp-hdrext:sdes:mid +a=sendonly +a=msid:stream audio +a=rtcp-mux +a=rtpmap:111 opus/48000/2 +a=rtcp-fb:111 transport-cc +a=fmtp:111 minptime=10;useinbandfec=1 +` + + whipReq := httptest.NewRequest(http.MethodPost, 
"/ai/stream/{streamId}/whip", strings.NewReader(sdpOffer)) + whipReq.SetPathValue("streamId", "teststream-streamid") + whipReq.Header.Set("Content-Type", "application/sdp") + + w := httptest.NewRecorder() + handler.ServeHTTP(w, whipReq) + assert.Equal(t, http.StatusCreated, w.Code) + + newParams, err := getStreamRequestParams(stream) + assert.NoError(t, err) + assert.NotNil(t, newParams.liveParams.kickInput) + + //stop the WHIP connection + time.Sleep(2 * time.Millisecond) //wait for setup + //add kickOrch because we are not calling runStream which would have added it + newParams.liveParams.kickOrch = func(error) {} + stream.UpdateStreamParams(newParams) + newParams.liveParams.kickInput(errors.New("test complete")) +} + +// Test GetStreamData handler +func TestGetStreamDataHandler(t *testing.T) { + + t.Run("StreamData_MissingStreamId", func(t *testing.T) { + // Test with missing stream ID - should return 400 + ls := &LivepeerServer{} + handler := ls.UpdateStream() + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "Missing stream name") + }) + + t.Run("StreamData_DataOutputWorking", func(t *testing.T) { + node := mockJobLivepeerNode() + node.WorkDir = t.TempDir() + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + // Prepare a valid gatewayJob + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + paramsStr := marshalToString(t, jobParams) + jobReq := &JobRequest{ + Capability: "test-capability", + Parameters: paramsStr, + Timeout: 10, + } + orchJob := &orchJob{Req: jobReq, Params: &jobParams} + gatewayJob := &gatewayJob{Job: orchJob} + + // Prepare a valid StartRequest body for /ai/stream/start + startReq := StartRequest{ + Stream: "teststream", + RtmpOutput: "rtmp://output", + StreamId: "streamid", + Params: "{}", + } + body, _ := json.Marshal(startReq) + req := httptest.NewRequest(http.MethodPost, "/ai/stream/start", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + urls, code, err := ls.setupStream(context.Background(), req, gatewayJob) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, code) + assert.NotNil(t, urls) + assert.Equal(t, "teststream-streamid", urls.StreamId) //combination of stream name (Stream) and id (StreamId) + + stream, ok := ls.LivepeerNode.LivePipelines[urls.StreamId] + assert.True(t, ok) + assert.NotNil(t, stream) + + params, err := getStreamRequestParams(stream) + assert.NoError(t, err) + assert.NotNil(t, params.liveParams) + + // Write some test data first + writer, err := params.liveParams.dataWriter.Next() + assert.NoError(t, err) + writer.Write([]byte("initial-data")) + writer.Close() + + handler := ls.GetStreamData() + dataReq := httptest.NewRequest(http.MethodGet, "/ai/stream/{streamId}/data", nil) + dataReq.SetPathValue("streamId", "teststream-streamid") + + // Create a context with timeout to prevent infinite blocking + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + dataReq = dataReq.WithContext(ctx) + + // Start writing more segments in a goroutine + go func() { + time.Sleep(10 * time.Millisecond) // Give handler time to start + + 
// Write additional segments + for i := 0; i < 2; i++ { + writer, err := params.liveParams.dataWriter.Next() + if err != nil { + break + } + writer.Write([]byte(fmt.Sprintf("test-data-%d", i))) + writer.Close() + time.Sleep(5 * time.Millisecond) + } + + // Close the writer to signal EOF + time.Sleep(10 * time.Millisecond) + params.liveParams.dataWriter.Close() + }() + + w := httptest.NewRecorder() + handler.ServeHTTP(w, dataReq) + + // Check response + responseBody := w.Body.String() + + // Verify we received some SSE data + assert.Contains(t, responseBody, "data: ", "Should have received SSE data") + + // Check for our test data + if strings.Contains(responseBody, "data: ") { + lines := strings.Split(responseBody, "\n") + dataFound := false + for _, line := range lines { + if strings.HasPrefix(line, "data: ") && strings.Contains(line, "data") { + dataFound = true + break + } + } + assert.True(t, dataFound, "Should have found data in SSE response") + } + }) +} + +// Test UpdateStream handler +func TestUpdateStreamHandler(t *testing.T) { + t.Run("UpdateStream_MissingStreamId", func(t *testing.T) { + // Test with missing stream ID - should return 400 + ls := &LivepeerServer{} + handler := ls.UpdateStream() + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "Missing stream name") + }) + + t.Run("Basic_StreamNotFound", func(t *testing.T) { + // Test with non-existent stream - should return 404 + node := mockJobLivepeerNode() + ls := &LivepeerServer{LivepeerNode: node} + + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", + strings.NewReader(`{"param1": "value1", "param2": "value2"}`)) + req.SetPathValue("streamId", "non-existent-stream") + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + handler := ls.UpdateStream() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + assert.Contains(t, w.Body.String(), "Stream not found") + }) + + t.Run("UpdateStream_WithOrchestratorControlChannel", func(t *testing.T) { + // Setup test infrastructure with mock orchestrator and trickle server + node := mockJobLivepeerNode() + + // Set up trickle-enabled orchestrator server + mux := http.NewServeMux() + mockOrch := &mockOrchestrator{} + + lp := &lphttp{orchestrator: nil, transRPC: mux, node: node} + lp.trickleSrv = trickle.ConfigureServer(trickle.TrickleServerConfig{ + Mux: mux, + BasePath: TrickleHTTPPath, + Autocreate: true, + }) + + // Register other required endpoints + mux.HandleFunc("/process/token", orchTokenHandler) + mux.HandleFunc("/ai/stream/start", orchAIStreamStartHandler) + + server := httptest.NewServer(lp) + defer server.Close() + + // Configure mock orchestrator + parsedURL, _ := url.Parse(server.URL) + capabilitySrv := httptest.NewServer(http.HandlerFunc(orchCapabilityUrlHandler)) + defer capabilitySrv.Close() + + lp.orchestrator = &testStreamOrch{mockOrchestrator: mockOrch, svc: parsedURL, capURL: capabilitySrv.URL} + + // Setup LivepeerServer + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + // Create a stream with control publisher + streamID := "test-stream" + controlURL := fmt.Sprintf("%s%stest-stream-control", server.URL, TrickleHTTPPath) + controlPub, err := trickle.NewTricklePublisher(controlURL) + 
assert.NoError(t, err) + + // Create minimal AI session + token := createMockJobToken(server.URL) + sess, err := tokenToAISession(*token) + assert.NoError(t, err) + + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + + // Create and setup stream + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + stream.ControlPub = controlPub + + // Test update with valid stream and control publisher + updateData := `{"param1": "updated_value1", "param2": "updated_value2"}` + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", + strings.NewReader(updateData)) + req.SetPathValue("streamId", streamID) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + handler := ls.UpdateStream() + handler.ServeHTTP(w, req) + + // Should succeed + assert.Equal(t, http.StatusOK, w.Code) + + // Verify stream params were cached + assert.Equal(t, []byte(updateData), stream.Params) + + // Clean up + stream.StopStream(nil) + controlPub.Close() + }) + + t.Run("UpdateStream_WithoutControlChannel", func(t *testing.T) { + // Test stream update when no orchestrator control channel is available + node := mockJobLivepeerNode() + ls := &LivepeerServer{LivepeerNode: node} + + streamID := "test-stream-no-control" + + // Create minimal AI session + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + token := createMockJobToken(server.URL) + sess, err := tokenToAISession(*token) + assert.NoError(t, err) + + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + + // Create stream WITHOUT control publisher + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + stream.ControlPub = nil // Explicitly set to nil + + updateData := `{"param1": "cached_value"}` + req := httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", + strings.NewReader(updateData)) + req.SetPathValue("streamId", streamID) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + handler := ls.UpdateStream() + handler.ServeHTTP(w, req) + + // Should still succeed (params cached locally) + assert.Equal(t, http.StatusOK, w.Code) + + // Verify stream params were cached even without control channel + assert.Equal(t, []byte(updateData), stream.Params) + + // Clean up + stream.StopStream(nil) + }) + + t.Run("UpdateStream_ErrorHandling", func(t *testing.T) { + // Test various error conditions + node := mockJobLivepeerNode() + ls := &LivepeerServer{LivepeerNode: node} + + // Test 1: Wrong HTTP method + req := httptest.NewRequest(http.MethodGet, "/ai/stream/{streamId}/update", nil) + req.SetPathValue("streamId", "test-stream") + w := httptest.NewRecorder() + ls.UpdateStream().ServeHTTP(w, req) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) + + // Test 2: Request too large + streamID := "test-stream-large" + token := createMockJobToken("http://example.com") + sess, _ := tokenToAISession(*token) + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + 
segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + // Create a body larger than 10MB + largeData := bytes.Repeat([]byte("a"), 10*1024*1024+1) + req = httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", + bytes.NewReader(largeData)) + req.SetPathValue("streamId", streamID) + w = httptest.NewRecorder() + + ls.UpdateStream().ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "http: request body too large") + + stream.StopStream(nil) + + // Test 3: Control publisher write failure (simulate network error) + streamID2 := "test-stream-net-error" + params2 := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-2", + sess: &sess, + stream: streamID2, + streamID: streamID2, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream2 := node.NewLivePipeline("req-2", streamID2, "test-capability", params2, nil) + + // Use a control publisher pointing to non-existent URL + badControlPub, err := trickle.NewTricklePublisher("http://localhost:1/nonexistent") + if err == nil { + stream2.ControlPub = badControlPub + + req = httptest.NewRequest(http.MethodPost, "/ai/stream/{streamId}/update", + strings.NewReader(`{"param": "value"}`)) + req.SetPathValue("streamId", streamID2) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + + ls.UpdateStream().ServeHTTP(w, req) + + // Should return 500 due to control publisher write failure + assert.Equal(t, http.StatusInternalServerError, w.Code) + + // But params should still be cached + assert.Equal(t, []byte(`{"param": "value"}`), stream2.Params) + + badControlPub.Close() + } + + stream2.StopStream(nil) + }) +} + +// Test GetStreamStatus handler +func TestGetStreamStatusHandler(t *testing.T) { + ls := &LivepeerServer{} + handler := ls.GetStreamStatus() + // stream does not exist + req := httptest.NewRequest(http.MethodGet, "/ai/stream/{streamId}/status", nil) + req.SetPathValue("streamId", "any-stream") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) + + // stream exists + node := mockJobLivepeerNode() + ls.LivepeerNode = node + node.NewLivePipeline("req-1", "any-stream", "test-capability", aiRequestParams{}, nil) + GatewayStatus.StoreKey("any-stream", "test", "test") + req = httptest.NewRequest(http.MethodGet, "/ai/stream/{streamId}/status", nil) + req.SetPathValue("streamId", "any-stream") + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) +} + +// Test sendPaymentForStream +func TestSendPaymentForStream(t *testing.T) { + t.Run("Success_ValidPayment", func(t *testing.T) { + // Setup + node := mockJobLivepeerNode() + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(2) + mockSender.On("CreateTicketBatch", "foo", 60).Return(mockTicketBatch(60), nil).Once() + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + // Create mock orchestrator server that handles token requests and payments + paymentReceived := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/process/token": + orchTokenHandler(w, r) + case "/ai/stream/payment": + paymentReceived = true + w.WriteHeader(http.StatusOK) + 
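+				// A real orchestrator also returns the updated balance in the jobPaymentBalanceHdr
+				// header (see ProcessStreamPayment); this stub only needs to acknowledge the payment.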
w.Write([]byte(`{"status": "payment_processed"}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + // Create a mock stream with AI session + streamID := "test-payment-stream" + token := createMockJobToken(server.URL) + sess, err := tokenToAISession(*token) + assert.NoError(t, err) + + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + // Create a job sender + jobSender := &core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + // Test sendPaymentForStream + ctx := context.Background() + err = ls.sendPaymentForStream(ctx, stream, jobSender) + + // Should succeed + assert.NoError(t, err) + + // Verify payment was sent to orchestrator + assert.True(t, paymentReceived, "Payment should have been sent to orchestrator") + + // Clean up + stream.StopStream(nil) + }) + + t.Run("Error_GetTokenFailed", func(t *testing.T) { + // Setup node without orchestrator pool + node := mockJobLivepeerNode() + // Set up mock sender to prevent nil pointer dereference + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo") + mockSender.On("CreateTicketBatch", mock.Anything, mock.Anything).Return(mockTicketBatch(10), nil) + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + ls := &LivepeerServer{LivepeerNode: node} + + // Create a stream with invalid session + streamID := "test-invalid-token" + invalidToken := createMockJobToken("http://nonexistent-server.com") + sess, _ := tokenToAISession(*invalidToken) + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + jobSender := &core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + // Should fail to get new token + err := ls.sendPaymentForStream(context.Background(), stream, jobSender) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nonexistent-server.com") + + stream.StopStream(nil) + }) + + t.Run("Error_PaymentCreationFailed", func(t *testing.T) { + // Test with node that has no sender (payment creation will fail) + node := mockJobLivepeerNode() + // node.Sender is nil by default + + server := httptest.NewServer(http.HandlerFunc(orchTokenHandler)) + defer server.Close() + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + + streamID := "test-payment-creation-fail" + token := createMockJobToken(server.URL) + sess, _ := tokenToAISession(*token) + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + 
segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + jobSender := &core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + // Should continue even if payment creation fails (no payment required) + err := ls.sendPaymentForStream(context.Background(), stream, jobSender) + assert.NoError(t, err) // Should not error, just logs and continues + + stream.StopStream(nil) + }) + + t.Run("Error_OrchestratorPaymentFailed", func(t *testing.T) { + // Setup node with sender to create payments + node := mockJobLivepeerNode() + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(2) + mockSender.On("CreateTicketBatch", "foo", 60).Return(mockTicketBatch(60), nil).Once() + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + // Create mock orchestrator that returns error for payments + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/process/token": + orchTokenHandler(w, r) + case "/ai/stream/payment": + http.Error(w, "Payment processing failed", http.StatusInternalServerError) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + streamID := "test-payment-error" + token := createMockJobToken(server.URL) + sess, _ := tokenToAISession(*token) + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + jobSender := &core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + // Should fail with payment error + err := ls.sendPaymentForStream(context.Background(), stream, jobSender) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unexpected status code") + + stream.StopStream(nil) + }) + + t.Run("Error_TokenToSessionConversionNoPrice", func(t *testing.T) { + // Test where tokenToAISession fails + node := mockJobLivepeerNode() + + // Set up mock sender to prevent nil pointer dereference + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo") + mockSender.On("CreateTicketBatch", mock.Anything, mock.Anything).Return(mockTicketBatch(10), nil) + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + // Create a server that returns invalid token response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/process/token" { + // Return malformed token that will cause tokenToAISession to fail + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"invalid": "token_structure"}`)) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + + 
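+		// tokenToAISession defaults a nil Price to an empty PriceInfo, so even a malformed
+		// token response without price data still yields a usable session and no error below.
+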
// Create stream with valid initial session + streamID := "test-token-no-price" + token := createMockJobToken(server.URL) + sess, _ := tokenToAISession(*token) + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &sess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + jobSender := &core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + // Should fail during token to session conversion + err := ls.sendPaymentForStream(context.Background(), stream, jobSender) + assert.NoError(t, err) + + stream.StopStream(nil) + }) + + t.Run("Success_StreamParamsUpdated", func(t *testing.T) { + // Test that stream params are updated with new session after token refresh + node := mockJobLivepeerNode() + mockSender := pm.MockSender{} + mockSender.On("StartSession", mock.Anything).Return("foo").Times(2) + mockSender.On("CreateTicketBatch", "foo", 60).Return(mockTicketBatch(60), nil).Once() + node.Sender = &mockSender + node.Balances = core.NewAddressBalances(10) + defer node.Balances.StopCleanup() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/process/token": + orchTokenHandler(w, r) + case "/ai/stream/payment": + w.WriteHeader(http.StatusOK) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + node.OrchestratorPool = newStubOrchestratorPool(node, []string{server.URL}) + ls := &LivepeerServer{LivepeerNode: node} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + + streamID := "test-params-update" + originalToken := createMockJobToken(server.URL) + originalSess, _ := tokenToAISession(*originalToken) + originalSessionAddr := originalSess.Address() + + params := aiRequestParams{ + liveParams: &liveRequestParams{ + requestID: "req-1", + sess: &originalSess, + stream: streamID, + streamID: streamID, + sendErrorEvent: func(err error) {}, + segmentReader: media.NewSwitchableSegmentReader(), + }, + node: node, + } + stream := node.NewLivePipeline("req-1", streamID, "test-capability", params, nil) + + jobSender := &core.JobSender{ + Addr: "0x1111111111111111111111111111111111111111", + Sig: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + } + + // Send payment + err := ls.sendPaymentForStream(context.Background(), stream, jobSender) + assert.NoError(t, err) + + // Verify that stream params were updated with new session + updatedParams, err := getStreamRequestParams(stream) + assert.NoError(t, err) + + // The session should be updated (new token fetched) + updatedSessionAddr := updatedParams.liveParams.sess.Address() + // In a real scenario, this might be different, but our mock returns the same token + // The important thing is that UpdateStreamParams was called + assert.NotNil(t, updatedParams.liveParams.sess) + assert.Equal(t, originalSessionAddr, updatedSessionAddr) // Same because mock returns same token + + stream.StopStream(nil) + }) +} + +func TestTokenSessionConversion(t *testing.T) { + token := createMockJobToken("http://example.com") + sess, err := tokenToAISession(*token) + assert.True(t, err != nil || sess != (AISession{})) + assert.NotNil(t, sess.OrchestratorInfo) + assert.NotNil(t, sess.OrchestratorInfo.TicketParams) + + assert.NotEmpty(t, sess.Address()) 
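+	// Round-trip check: sessionToToken should hand back the ServiceAddr that
+	// tokenToAISession copied into OrchestratorInfo.Transcoder.
+	roundTripped, rtErr := sessionToToken(&sess)
+	assert.NoError(t, rtErr)
+	assert.Equal(t, token.ServiceAddr, roundTripped.ServiceAddr)
+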
+	assert.NotEmpty(t, sess.Transcoder())
+
+	_, err = sessionToToken(&sess)
+	assert.NoError(t, err)
+}
+
+func TestGetStreamRequestParams(t *testing.T) {
+	_, err := getStreamRequestParams(nil)
+	assert.Error(t, err)
+}
+
+// createMockMediaMTXServer creates a simple mock MediaMTX server that returns 200 OK to all requests
+func createMockMediaMTXServer(t *testing.T) *httptest.Server {
+	mux := http.NewServeMux()
+
+	// Simple handler that returns 200 OK to any request
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("Mock MediaMTX: %s %s", r.Method, r.URL.Path)
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte("OK"))
+	})
+
+	// Create a listener on port 9997 specifically
+	listener, err := net.Listen("tcp", ":9997")
+	if err != nil {
+		t.Fatalf("Failed to listen on port 9997: %v", err)
+	}
+
+	server := &httptest.Server{
+		Listener: listener,
+		Config:   &http.Server{Handler: mux},
+	}
+	server.Start()
+
+	t.Cleanup(func() {
+		server.Close()
+	})
+
+	return server
+}
diff --git a/server/rpc.go b/server/rpc.go
index 5bcb1264ee..0aea47ee5f 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -254,6 +254,9 @@ func StartTranscodeServer(orch Orchestrator, bind string, mux *http.ServeMux, wo
 	lp.transRPC.HandleFunc("/process/token", lp.GetJobToken)
 	lp.transRPC.HandleFunc("/capability/register", lp.RegisterCapability)
 	lp.transRPC.HandleFunc("/capability/unregister", lp.UnregisterCapability)
+	lp.transRPC.HandleFunc("/ai/stream/start", lp.StartStream)
+	lp.transRPC.HandleFunc("/ai/stream/stop", lp.StopStream)
+	lp.transRPC.HandleFunc("/ai/stream/payment", lp.ProcessStreamPayment)
 
 	cert, key, err := getCert(orch.ServiceURI(), workDir)
 	if err != nil {