Skip to content
38 changes: 9 additions & 29 deletions internal/commands/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,18 +559,8 @@ func resultCodeBashing(codeBashingWrapper wrappers.CodeBashingWrapper) *cobra.Co
),
RunE: runGetCodeBashingCommand(codeBashingWrapper),
}
resultCmd.PersistentFlags().String(commonParams.LanguageFlag, "", "Language of the vulnerability")
err := resultCmd.MarkPersistentFlagRequired(commonParams.LanguageFlag)
if err != nil {
log.Fatal(err)
}
resultCmd.PersistentFlags().String(commonParams.VulnerabilityTypeFlag, "", "Vulnerability type")
err = resultCmd.MarkPersistentFlagRequired(commonParams.VulnerabilityTypeFlag)
if err != nil {
log.Fatal(err)
}
resultCmd.PersistentFlags().String(commonParams.CweIDFlag, "", "CWE ID for the vulnerability")
err = resultCmd.MarkPersistentFlagRequired(commonParams.CweIDFlag)
resultCmd.PersistentFlags().String(commonParams.QueryIDFlag, "", "QueryId of vulnerability")
err := resultCmd.MarkPersistentFlagRequired(commonParams.QueryIDFlag)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -1095,35 +1085,25 @@ func runGetCodeBashingCommand(
codeBashingWrapper wrappers.CodeBashingWrapper,
) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
language, _ := cmd.Flags().GetString(commonParams.LanguageFlag)
cwe, _ := cmd.Flags().GetString(commonParams.CweIDFlag)
vulType, _ := cmd.Flags().GetString(commonParams.VulnerabilityTypeFlag)
params, err := codeBashingWrapper.BuildCodeBashingParams(
[]wrappers.CodeBashingParamsCollection{
{
CweID: "CWE-" + cwe,
Language: language,
CxQueryName: strings.ReplaceAll(vulType, " ", "_"),
},
},
)
if err != nil {
return err
}
// Fetch the cached token or a new one to obtain the codebashing URL incoded in the jwt token

queryID, _ := cmd.Flags().GetString(commonParams.QueryIDFlag)
codeBashingURL, err := codeBashingWrapper.GetCodeBashingURL(codeBashingKey)
if err != nil {
return err
}
// Make the request to the api to obtain the codebashing link and send the codebashing url to enrich the path
CodeBashingModel, webError, err := codeBashingWrapper.GetCodeBashingLinks(params, codeBashingURL)
CodeBashingModel, webError, err := codeBashingWrapper.GetCodeBashingLinks(queryID, codeBashingURL)
if err != nil {
return err
}
if webError != nil {
return errors.New(webError.Message)
}
err = printByFormat(cmd, *CodeBashingModel)
model := *CodeBashingModel
if len(model) > 0 && model[0].Path != "" {
logger.Printf("CodeBashing lesson available at: %s", model[0].Path)
}
if err != nil {
return errors.Wrapf(err, "%s", failedListingCodeBashing)
}
Expand Down
79 changes: 27 additions & 52 deletions internal/commands/result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ import (
const fileName = "cx_result"

const (
resultsCommand = "results"
codeBashingCommand = "codebashing"
vulnerabilityValue = "Reflected XSS All Clients"
languageValue = "PHP"
cweValue = "79"
jsonValue = "json"
tableValue = "table"
listValue = "list"
secretDetectionLine = "| Secret Detection 0 1 1 0 0 Completed |"
resultsCommand = "results"
codeBashingCommand = "codebashing"
queryIDValue = "8481125285487743346"
queryIDWrongValueValue = "11666704984804998184"
jsonValue = "json"
tableValue = "table"
listValue = "list"
secretDetectionLine = "| Secret Detection 0 1 1 0 0 Completed |"
)

func flag(f string) string {
Expand Down Expand Up @@ -553,53 +552,37 @@ func TestRunGetResultsByScanIdWithEmptyOutputPath(t *testing.T) {
_ = execCmdNotNilAssertion(t, "results", "show", "--scan-id", "MOCK", "--output-path", "")
}

func TestRunGetCodeBashingWithoutLanguage(t *testing.T) {
func TestRunGetCodeBashingWithEmptyQueryId(t *testing.T) {
err := execCmdNotNilAssertion(
t,
resultsCommand,
codeBashingCommand,
flag(params.CweIDFlag),
cweValue,
flag(params.VulnerabilityTypeFlag),
vulnerabilityValue)
assert.Equal(t, err.Error(), "required flag(s) \"language\" not set", "Wrong expected error message")
flag(params.QueryIDFlag),
"")
assert.Equal(t, err.Error(), "Cannot GET /lessons/mapping/", "Wrong expected error message")
}

func TestRunGetCodeBashingWithoutVulnerabilityType(t *testing.T) {
err := execCmdNotNilAssertion(
t,
func TestRunGetCodeBashingWithEmptyQueryIdThatDoesNotHaveAnLesson(t *testing.T) {
cmd := createASTTestCommand()
buffer, err := executeRedirectedOsStdoutTestCommand(cmd,
resultsCommand,
codeBashingCommand,
flag(params.CweIDFlag),
cweValue,
flag(params.LanguageFlag),
languageValue)
assert.Equal(t, err.Error(), "required flag(s) \"vulnerability-type\" not set", "Wrong expected error message")
}
flag(params.QueryIDFlag),
queryIDWrongValueValue)

func TestRunGetCodeBashingWithoutCweId(t *testing.T) {
err := execCmdNotNilAssertion(
t,
resultsCommand,
codeBashingCommand,
flag(params.VulnerabilityTypeFlag),
vulnerabilityValue,
flag(params.LanguageFlag),
languageValue)
assert.Equal(t, err.Error(), "required flag(s) \"cwe-id\" not set", "Wrong expected error message")
assert.NilError(t, err, "Command should not return an error")
output := buffer.String()
assert.Assert(t, strings.Contains(output, "/app/home"), "Expected response to contain /app/home path")
assert.Assert(t, !strings.Contains(output, "Cannot GET /lessons/mapping/"), "Response should not contain error message")
}

func TestRunGetCodeBashingWithFormatJson(t *testing.T) {
execCmdNilAssertion(
t,
resultsCommand,
codeBashingCommand,
flag(params.VulnerabilityTypeFlag),
vulnerabilityValue,
flag(params.LanguageFlag),
languageValue,
flag(params.CweIDFlag),
cweValue,
flag(params.QueryIDFlag),
queryIDValue,
flag(params.FormatFlag),
jsonValue)
}
Expand All @@ -609,12 +592,8 @@ func TestRunGetCodeBashingWithFormatTable(t *testing.T) {
t,
resultsCommand,
codeBashingCommand,
flag(params.VulnerabilityTypeFlag),
vulnerabilityValue,
flag(params.LanguageFlag),
languageValue,
flag(params.CweIDFlag),
cweValue,
flag(params.QueryIDFlag),
queryIDValue,
flag(params.FormatFlag),
tableValue)
}
Expand All @@ -624,12 +603,8 @@ func TestRunGetCodeBashingWithFormatList(t *testing.T) {
t,
resultsCommand,
codeBashingCommand,
flag(params.VulnerabilityTypeFlag),
vulnerabilityValue,
flag(params.LanguageFlag),
languageValue,
flag(params.CweIDFlag),
cweValue,
flag(params.QueryIDFlag),
queryIDValue,
flag(params.FormatFlag),
listValue)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/params/binds.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var EnvVarsBinds = []struct {
{IgnoreProxyKey, IgnoreProxyEnv, ""},
{AgentNameKey, AgentNameEnv, "ASTCLI"},
{OriginKey, OriginEnv, "CLI"},
{CodeBashingPathKey, ScansPathEnv, "api/codebashing/lessons"},
{CodeBashingPathKey, ScansPathEnv, "https://core-service.codebashing.com/lessons/mapping"},
{CustomStatesAPIPathKey, CustomStatesAPIPathEnv, "api/custom-states"},
{ScansPathKey, ScansPathEnv, "api/scans"},
{ProjectsPathKey, ProjectsPathEnv, "api/projects"},
Expand Down
71 changes: 67 additions & 4 deletions internal/wrappers/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@
jwtError = "Error retrieving %s from jwt token"
basicFormat = "Basic %s"
bearerFormat = "Bearer %s"
contentTypeHeader = "Content-Type"
formURLContentType = "application/x-www-form-urlencoded"
jsonContentType = "application/json"
defaultDialerDuration = 30 * time.Second
// Can be removed once CodeBashing team will add support for
//http://core-service.codebashing.com/lessons/mapping/{queryId}
//with --header 'Authorization: Bearer <token>' currently

Check failure on line 49 in internal/wrappers/client.go

View workflow job for this annotation

GitHub Actions / lint

commentFormatting: put a space between `//` and comment text (gocritic)
//supports only --header 'Authorization: <token>'
onlyTokenFormat = "%s"
contentTypeHeader = "Content-Type"
formURLContentType = "application/x-www-form-urlencoded"
jsonContentType = "application/json"
defaultDialerDuration = 30 * time.Second
)

var (
Expand Down Expand Up @@ -339,6 +344,18 @@
return SendHTTPRequestByFullURL(method, u, body, auth, timeout, accessToken, true)
}

// Can be removed once CodeBashing team will add support for
// http://core-service.codebashing.com/lessons/mapping/{queryId}
// with --header 'Authorization: Bearer <token>' currently
// supports only --header 'Authorization: <token>'
func SendHTTPRequestNoBaseCBURL(method, path string, body io.Reader, auth bool, timeout uint) (*http.Response, error) {
_, accessToken, err := getURLAndAccessToken(path)
if err != nil {
return nil, err
}
return SendHTTPRequestWithoutBearerTagByFullURL(method, path, body, auth, timeout, accessToken, true)
}

func SendPrivateHTTPRequest(method, path string, body io.Reader, timeout uint, auth bool) (*http.Response, error) {
u, accessToken, err := getURLAndAccessToken(path)
if err != nil {
Expand All @@ -358,6 +375,21 @@
return SendHTTPRequestByFullURLContentLength(method, fullURL, body, -1, auth, timeout, accessToken, bodyPrint)
}

// Can be removed once CodeBashing team will add support for
// http://core-service.codebashing.com/lessons/mapping/{queryId}
// with --header 'Authorization: Bearer <token>' currently
// supports only --header 'Authorization: <token>'
func SendHTTPRequestWithoutBearerTagByFullURL(
method, fullURL string,
body io.Reader,
auth bool,
timeout uint,
accessToken string,
bodyPrint bool,
) (*http.Response, error) {
return SendHTTPRequestWithoutBearerTagByFullURLContentLength(method, fullURL, body, -1, auth, timeout, accessToken, bodyPrint)
}

func SendHTTPRequestByFullURLContentLength(
method, fullURL string,
body io.Reader,
Expand Down Expand Up @@ -385,6 +417,37 @@
return request(client, req, bodyPrint)
}

// Can be removed once CodeBashing team will add support for
// http://core-service.codebashing.com/lessons/mapping/{queryId}
// with --header 'Authorization: Bearer <token>' currently
// allow only --header 'Authorization: <token>'
func SendHTTPRequestWithoutBearerTagByFullURLContentLength(
method, fullURL string,
body io.Reader,
contentLength int64,
auth bool,
timeout uint,
accessToken string,
bodyPrint bool,
) (*http.Response, error) {
req, err := http.NewRequest(method, fullURL, body)
if err != nil {
return nil, err
}
if contentLength >= 0 {
req.ContentLength = contentLength
}
client := GetClient(timeout)
setAgentNameAndOrigin(req)
if auth {
enrichWithOath2Credentials(req, accessToken, onlyTokenFormat)
}

req = addReqMonitor(req)

return request(client, req, bodyPrint)
}

func addReqMonitor(req *http.Request) *http.Request {
startTime := time.Now().UnixNano() / int64(time.Millisecond)
if viper.GetBool(commonParams.DebugFlag) || viper.GetString(commonParams.LogFileFlag) != "" || viper.GetString(commonParams.LogFileConsoleFlag) != "" {
Expand Down
Loading
Loading