@@ -15,16 +15,24 @@ import (
1515func (h * ConnectHandler ) CreateCheckout (ctx context.Context , request * connect.Request [frontierv1beta1.CreateCheckoutRequest ]) (* connect.Response [frontierv1beta1.CreateCheckoutResponse ], error ) {
1616 errorLogger := NewErrorLogger ()
1717
18+ // Always infer billing_id from org_id (ignore billing_id from request for security)
19+ billingID , err := h .GetBillingAccountFromOrgID (ctx , request .Msg .GetOrgId ())
20+ if err != nil {
21+ errorLogger .LogServiceError (ctx , request , "CreateCheckout.GetBillingAccountFromOrgID" , err ,
22+ zap .String ("org_id" , request .Msg .GetOrgId ()))
23+ return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
24+ }
25+
1826 // check if setup requested
1927 if request .Msg .GetSetupBody () != nil && request .Msg .GetSetupBody ().GetPaymentMethod () {
2028 newCheckout , err := h .checkoutService .CreateSessionForPaymentMethod (ctx , checkout.Checkout {
21- CustomerID : request . Msg . GetBillingId () ,
29+ CustomerID : billingID ,
2230 SuccessUrl : request .Msg .GetSuccessUrl (),
2331 CancelUrl : request .Msg .GetCancelUrl (),
2432 })
2533 if err != nil {
2634 errorLogger .LogServiceError (ctx , request , "CreateCheckout.CreateSessionForPaymentMethod" , err ,
27- zap .String ("billing_id" , request . Msg . GetBillingId () ))
35+ zap .String ("billing_id" , billingID ))
2836 return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
2937 }
3038
@@ -36,7 +44,7 @@ func (h *ConnectHandler) CreateCheckout(ctx context.Context, request *connect.Re
3644 // check if customer portal requested
3745 if request .Msg .GetSetupBody () != nil && request .Msg .GetSetupBody ().GetCustomerPortal () {
3846 newCheckout , err := h .checkoutService .CreateSessionForCustomerPortal (ctx , checkout.Checkout {
39- CustomerID : request . Msg . GetBillingId () ,
47+ CustomerID : billingID ,
4048 SuccessUrl : request .Msg .GetSuccessUrl (),
4149 CancelUrl : request .Msg .GetCancelUrl (),
4250 })
@@ -45,7 +53,7 @@ func (h *ConnectHandler) CreateCheckout(ctx context.Context, request *connect.Re
4553 return nil , connect .NewError (connect .CodeFailedPrecondition , ErrPortalChangesKycCompleted )
4654 }
4755 errorLogger .LogServiceError (ctx , request , "CreateCheckout.CreateSessionForCustomerPortal" , err ,
48- zap .String ("billing_id" , request . Msg . GetBillingId () ))
56+ zap .String ("billing_id" , billingID ))
4957 return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
5058 }
5159
@@ -74,7 +82,7 @@ func (h *ConnectHandler) CreateCheckout(ctx context.Context, request *connect.Re
7482 quantity = request .Msg .GetProductBody ().GetQuantity ()
7583 }
7684 newCheckout , err := h .checkoutService .Create (ctx , checkout.Checkout {
77- CustomerID : request . Msg . GetBillingId () ,
85+ CustomerID : billingID ,
7886 SuccessUrl : request .Msg .GetSuccessUrl (),
7987 CancelUrl : request .Msg .GetCancelUrl (),
8088 PlanID : planID ,
@@ -88,7 +96,7 @@ func (h *ConnectHandler) CreateCheckout(ctx context.Context, request *connect.Re
8896 return nil , connect .NewError (connect .CodeInvalidArgument , ErrPerSeatLimitReached )
8997 }
9098 errorLogger .LogServiceError (ctx , request , "CreateCheckout.Create" , err ,
91- zap .String ("billing_id" , request . Msg . GetBillingId () ),
99+ zap .String ("billing_id" , billingID ),
92100 zap .String ("plan_id" , planID ),
93101 zap .String ("product_id" , featureID ),
94102 zap .Int64 ("quantity" , quantity ),
@@ -105,6 +113,14 @@ func (h *ConnectHandler) CreateCheckout(ctx context.Context, request *connect.Re
105113func (h * ConnectHandler ) DelegatedCheckout (ctx context.Context , request * connect.Request [frontierv1beta1.DelegatedCheckoutRequest ]) (* connect.Response [frontierv1beta1.DelegatedCheckoutResponse ], error ) {
106114 errorLogger := NewErrorLogger ()
107115
116+ // Always infer billing_id from org_id (ignore billing_id from request for security)
117+ billingID , err := h .GetBillingAccountFromOrgID (ctx , request .Msg .GetOrgId ())
118+ if err != nil {
119+ errorLogger .LogServiceError (ctx , request , "DelegatedCheckout.GetBillingAccountFromOrgID" , err ,
120+ zap .String ("org_id" , request .Msg .GetOrgId ()))
121+ return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
122+ }
123+
108124 var planID string
109125 var skipTrial bool
110126 var cancelAfterTrail bool
@@ -122,7 +138,7 @@ func (h *ConnectHandler) DelegatedCheckout(ctx context.Context, request *connect
122138 productQuantity = request .Msg .GetProductBody ().GetQuantity ()
123139 }
124140 subs , prod , err := h .checkoutService .Apply (ctx , checkout.Checkout {
125- CustomerID : request . Msg . GetBillingId () ,
141+ CustomerID : billingID ,
126142 PlanID : planID ,
127143 ProductID : productID ,
128144 Quantity : productQuantity ,
@@ -132,7 +148,7 @@ func (h *ConnectHandler) DelegatedCheckout(ctx context.Context, request *connect
132148 })
133149 if err != nil {
134150 errorLogger .LogServiceError (ctx , request , "DelegatedCheckout.Apply" , err ,
135- zap .String ("billing_id" , request . Msg . GetBillingId () ),
151+ zap .String ("billing_id" , billingID ),
136152 zap .String ("plan_id" , planID ),
137153 zap .String ("product_id" , productID ),
138154 zap .Int64 ("product_quantity" , productQuantity ),
@@ -170,13 +186,21 @@ func (h *ConnectHandler) ListCheckouts(ctx context.Context, request *connect.Req
170186 return nil , connect .NewError (connect .CodeInvalidArgument , ErrBadRequest )
171187 }
172188
189+ // Always infer billing_id from org_id (ignore billing_id from request for security)
190+ billingID , err := h .GetBillingAccountFromOrgID (ctx , request .Msg .GetOrgId ())
191+ if err != nil {
192+ errorLogger .LogServiceError (ctx , request , "ListCheckouts.GetBillingAccountFromOrgID" , err ,
193+ zap .String ("org_id" , request .Msg .GetOrgId ()))
194+ return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
195+ }
196+
173197 var checkouts []* frontierv1beta1.CheckoutSession
174198 checkoutList , err := h .checkoutService .List (ctx , checkout.Filter {
175- CustomerID : request . Msg . GetBillingId () ,
199+ CustomerID : billingID ,
176200 })
177201 if err != nil {
178202 errorLogger .LogServiceError (ctx , request , "ListCheckouts.List" , err ,
179- zap .String ("billing_id" , request . Msg . GetBillingId () ),
203+ zap .String ("billing_id" , billingID ),
180204 zap .String ("org_id" , request .Msg .GetOrgId ()))
181205 return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
182206 }
@@ -192,15 +216,14 @@ func (h *ConnectHandler) ListCheckouts(ctx context.Context, request *connect.Req
192216func (h * ConnectHandler ) GetCheckout (ctx context.Context , request * connect.Request [frontierv1beta1.GetCheckoutRequest ]) (* connect.Response [frontierv1beta1.GetCheckoutResponse ], error ) {
193217 errorLogger := NewErrorLogger ()
194218
195- if request .Msg .GetOrgId () == "" || request . Msg . GetId () == "" {
219+ if request .Msg .GetId () == "" {
196220 return nil , connect .NewError (connect .CodeInvalidArgument , ErrBadRequest )
197221 }
198222
199223 ch , err := h .checkoutService .GetByID (ctx , request .Msg .GetId ())
200224 if err != nil {
201225 errorLogger .LogServiceError (ctx , request , "GetCheckout.GetByID" , err ,
202- zap .String ("checkout_id" , request .Msg .GetId ()),
203- zap .String ("org_id" , request .Msg .GetOrgId ()))
226+ zap .String ("checkout_id" , request .Msg .GetId ()))
204227 return nil , connect .NewError (connect .CodeInternal , ErrInternalServerError )
205228 }
206229
0 commit comments