Skip to content
This repository was archived by the owner on Jun 2, 2026. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 108 additions & 232 deletions api/pkg/api/handler/tenantaccount.go

Large diffs are not rendered by default.

116 changes: 18 additions & 98 deletions api/pkg/api/handler/tenantaccount_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ func TestTenantAccountHandler_Create(t *testing.T) {
assert.Nil(t, err)
okBody3, err := json.Marshal(model.APITenantAccountCreateRequest{InfrastructureProviderID: ip.ID.String(), TenantOrg: cdb.GetStrPtr(tnOrg3)})
assert.Nil(t, err)
okBodyInferIP, err := json.Marshal(model.APITenantAccountCreateRequest{TenantOrg: cdb.GetStrPtr("test-tn-org-infer")})
assert.Nil(t, err)

cfg := common.GetTestConfig()
tempClient := &tmocks.Client{}
Expand Down Expand Up @@ -381,6 +383,15 @@ func TestTenantAccountHandler_Create(t *testing.T) {
expectedStatus: http.StatusCreated,
expectedTenantID: nil,
},
{
name: "success when infrastructureProviderId is omitted (inferred from org)",
reqOrgName: ipOrg1,
reqBody: string(okBodyInferIP),
user: ipu,
expectedErr: false,
expectedStatus: http.StatusCreated,
expectedTenantID: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -793,14 +804,6 @@ func TestTenantAccountHandler_GetByID(t *testing.T) {
expectedErr: true,
expectedStatus: http.StatusInternalServerError,
},
{
name: "error when infrastructure provider and tenant not specified",
reqOrgName: tnOrg1,
user: tnUser,
taID: ta11.ID.String(),
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when tenant account id is invalid uuid",
reqOrgName: tnOrg1,
Expand All @@ -817,15 +820,6 @@ func TestTenantAccountHandler_GetByID(t *testing.T) {
expectedErr: true,
expectedStatus: http.StatusNotFound,
},
{
name: "error when infrastructure provider not valid uuid",
reqOrgName: ipOrg1,
user: ipUser,
taID: ta11.ID.String(),
queryInfrastructureProviderID: cdb.GetStrPtr("non-uuid"),
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when infrastructure provider not found for org",
reqOrgName: ipOrg3,
Expand All @@ -835,32 +829,14 @@ func TestTenantAccountHandler_GetByID(t *testing.T) {
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when infrastructure provider in url doesnt match org",
reqOrgName: ipOrg1,
user: ipUser,
taID: ta11.ID.String(),
queryInfrastructureProviderID: cdb.GetStrPtr(uuid.New().String()),
expectedErr: true,
expectedStatus: http.StatusNotFound,
},
{
name: "error when infrastructure provider in org doesnt match infrastructure provider in tenant account",
reqOrgName: ipOrg1,
user: ipUser,
taID: ta21.ID.String(),
queryInfrastructureProviderID: cdb.GetStrPtr(ip1.ID.String()),
expectedErr: true,
expectedStatus: http.StatusNotFound,
},
{
name: "error when tenant id not valid uuid",
reqOrgName: ipOrg1,
user: ipUser,
taID: ta11.ID.String(),
queryTenantID: cdb.GetStrPtr("non-uuid"),
expectedErr: true,
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusForbidden,
},
{
name: "error when tenant not found for org",
Expand All @@ -878,7 +854,7 @@ func TestTenantAccountHandler_GetByID(t *testing.T) {
taID: ta11.ID.String(),
queryTenantID: cdb.GetStrPtr(tn1.ID.String()),
expectedErr: true,
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusForbidden,
},
{
name: "error when tenant in org doesnt match tenant in tenant account",
Expand All @@ -887,7 +863,7 @@ func TestTenantAccountHandler_GetByID(t *testing.T) {
taID: ta12.ID.String(),
queryTenantID: cdb.GetStrPtr(tn1.ID.String()),
expectedErr: true,
expectedStatus: http.StatusNotFound,
expectedStatus: http.StatusForbidden,
},
{
name: "error when tenant id is not found",
Expand Down Expand Up @@ -1087,6 +1063,8 @@ func TestTenantAccountHandler_GetAll(t *testing.T) {

st1 := testTenantAccountBuildSite(t, dbSession, ip1, "site1", ipUser)
assert.NotNil(t, st1)
st2 := testTenantAccountBuildSite(t, dbSession, ip2, "site2", ipUser)
assert.NotNil(t, st2)

tns := []cdbm.Tenant{}
tas := []cdbm.TenantAccount{}
Expand All @@ -1098,6 +1076,8 @@ func TestTenantAccountHandler_GetAll(t *testing.T) {

allocation := testTenantAccountBuildAllocation(t, dbSession, st1, tn, "Test Allocation", ipUser)
assert.NotNil(t, allocation)
allocation2 := testTenantAccountBuildAllocation(t, dbSession, st2, tn, "Test Allocation 2", ipUser)
assert.NotNil(t, allocation2)

ta1 := testTenantAccountBuildTenantAccount(t, dbSession, fmt.Sprintf("test-tenant-account-%02d", i), ip1, tn, tn.Org, cdbm.TenantAccountStatusInvited, ipUser.ID, contactUser1.ID)
assert.NotNil(t, ta1)
Expand Down Expand Up @@ -1165,26 +1145,6 @@ func TestTenantAccountHandler_GetAll(t *testing.T) {
expectedErr: true,
expectedStatus: http.StatusInternalServerError,
},
{
name: "error when infrastructure provider and tenant not specified",
reqOrgName: tnOrgs[0],
user: tnUser,
queryInfrastructureProviderID: nil,
queryTenantID: nil,
queryIncludeRelations1: nil,
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when infrastructure provider not valid uuid",
reqOrgName: ipOrg1,
user: ipUser,
queryInfrastructureProviderID: cdb.GetStrPtr("non-uuid"),
queryTenantID: nil,
queryIncludeRelations1: nil,
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when infrastructure provider not found for org",
reqOrgName: ipOrg3,
Expand All @@ -1195,46 +1155,6 @@ func TestTenantAccountHandler_GetAll(t *testing.T) {
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when infrastructure provider in url doesnt match org",
reqOrgName: ipOrg1,
user: ipUser,
queryInfrastructureProviderID: cdb.GetStrPtr(uuid.New().String()),
queryTenantID: nil,
queryIncludeRelations1: nil,
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when tenant id not valid uuid",
reqOrgName: tnOrgs[0],
user: tnUser,
queryInfrastructureProviderID: nil,
queryTenantID: cdb.GetStrPtr("non-uuid"),
queryIncludeRelations1: nil,
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when tenant not found for org",
reqOrgName: tnOrgs[0],
user: tnUser,
queryInfrastructureProviderID: nil,
queryTenantID: cdb.GetStrPtr(tn15.ID.String()),
queryIncludeRelations1: nil,
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "error when tenant id in url doesnt match org",
reqOrgName: tnOrgs[0],
user: tnUser,
queryInfrastructureProviderID: nil,
queryTenantID: cdb.GetStrPtr(uuid.New().String()),
queryIncludeRelations1: nil,
expectedErr: true,
expectedStatus: http.StatusBadRequest,
},
{
name: "success when infrastructure provider id is specified",
reqOrgName: ipOrg1,
Expand Down
3 changes: 1 addition & 2 deletions api/pkg/api/model/tenantaccount.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ type APITenantAccountCreateRequest struct {
func (tacr APITenantAccountCreateRequest) Validate() error {
return validation.ValidateStruct(&tacr,
validation.Field(&tacr.InfrastructureProviderID,
validation.Required.Error(validationErrorValueRequired),
validationis.UUID.Error(validationErrorInvalidUUID)),
validation.When(tacr.InfrastructureProviderID != "", validationis.UUID.Error(validationErrorInvalidUUID))),
validation.Field(&tacr.TenantID,
validation.When(tacr.TenantOrg == nil, validation.Required.Error(validationErrorTenantIDOrOrgRequired)),
validationis.UUID.Error(validationErrorInvalidUUID)),
Expand Down
5 changes: 2 additions & 3 deletions api/pkg/api/model/tenantaccount_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ func TestAPITenantAccountCreateRequest_Validate(t *testing.T) {
errStr string
}{
{
desc: "errors when infrastructureProviderID is not provided",
desc: "ok when infrastructureProviderID is omitted (inferred from org by handler)",
obj: APITenantAccountCreateRequest{TenantID: cdb.GetStrPtr(uuid.New().String())},
expectErr: true,
errStr: "infrastructureProviderId: " + validationErrorValueRequired + ".",
expectErr: false,
},
{
desc: "errors when infrastructureProviderID is invalid",
Expand Down
7 changes: 1 addition & 6 deletions cli/tui/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -2311,17 +2311,12 @@ func cmdTenantAccountList(s *Session, _ []string) error {
}

func cmdTenantAccountCreate(s *Session, _ []string) error {
infrastructureProviderID, err := PromptText("Infrastructure provider ID", true)
if err != nil {
return err
}
tenantOrg, err := PromptText("Tenant org", true)
if err != nil {
return err
}
body := map[string]interface{}{
"infrastructureProviderId": strings.TrimSpace(infrastructureProviderID),
"tenantOrg": strings.TrimSpace(tenantOrg),
"tenantOrg": strings.TrimSpace(tenantOrg),
}
LogCmd(s, "tenant-account", "create", "--tenant-org", strings.TrimSpace(tenantOrg))
bodyJSON, _ := json.Marshal(body)
Expand Down
10 changes: 10 additions & 0 deletions db/pkg/db/model/tenantaccount.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ type TenantAccountFilterInput struct {
Statuses []string
TenantIDs []uuid.UUID
TenantOrgs []string
TenantAccountIDs []uuid.UUID
SearchQuery *string
}

Expand Down Expand Up @@ -343,6 +344,15 @@ func (tasd TenantAccountSQLDAO) setQueryWithFilter(filter TenantAccountFilterInp
tasd.tracerSpan.SetAttribute(tnaDAOSpan, "status", filter.Statuses)
}

if filter.TenantAccountIDs != nil {
if len(filter.TenantAccountIDs) == 1 {
query = query.Where("ta.id = ?", filter.TenantAccountIDs[0])
} else {
query = query.Where("ta.id IN (?)", bun.In(filter.TenantAccountIDs))
}
tasd.tracerSpan.SetAttribute(tnaDAOSpan, "id", filter.TenantAccountIDs)
}

searchQuery, searchTokens, ok := db.NormalizeSearchQuery(filter.SearchQuery)
if ok {
query = query.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
Expand Down
Loading