Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 91 additions & 84 deletions apisprout.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,48 @@ var handler = func(rr *RefreshableRouter) http.Handler {
})
}

//
func loadSwaggerFromUri(uri string) (data []byte, err error) {
if strings.HasPrefix(uri, "http") {
req, httpErr := http.NewRequest("GET", uri, nil)
if httpErr != nil {
err = httpErr
return
}
if customHeader := viper.GetString("header"); customHeader != "" {
header := strings.Split(customHeader, ":")
if len(header) != 2 {
err = errors.New("Header format is invalid")
} else {
req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1]))
}
}
if err != nil {
return
}

client := &http.Client{}
resp, httpErr := client.Do(req)
if httpErr != nil {
err = httpErr
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("Server at %s reported %d status code", uri, resp.StatusCode)
return
}
data, err = ioutil.ReadAll(resp.Body)
if err != nil {
return
}
} else {
data, err = ioutil.ReadFile(uri)
}

return data, err
}

// server loads an OpenAPI file and runs a mock server using the paths and
// examples defined in the file.
func server(cmd *cobra.Command, args []string) {
Expand All @@ -611,83 +653,58 @@ func server(cmd *cobra.Command, args []string) {

// Load either from an HTTP URL or from a local file depending on the passed
// in value.
if strings.HasPrefix(uri, "http") {
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
log.Fatal(err)
}
if customHeader := viper.GetString("header"); customHeader != "" {
header := strings.Split(customHeader, ":")
if len(header) != 2 {
log.Fatal("Header format is invalid.")
}
req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1]))
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
log.Fatal(err)
}
data, err = loadSwaggerFromUri(uri)
if err != nil {
log.Fatal(err)
}

data, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
log.Fatal(err)
if viper.GetBool("watch") {
if strings.HasPrefix(uri, "http") {
log.Fatal(errors.New("Watching a URL is not supported."))
}

if viper.GetBool("watch") {
log.Fatal("Watching a URL is not supported.")
}
} else {
data, err = ioutil.ReadFile(uri)
// Set up a new filesystem watcher and reload the router every time
// the file has changed on disk.
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal(err)
}

if viper.GetBool("watch") {
// Set up a new filesystem watcher and reload the router every time
// the file has changed on disk.
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal(err)
}
defer watcher.Close()

go func() {
// Since waiting for events or errors is blocking, we do this in a
// goroutine. It loops forever here but will exit when the process
// is finished, e.g. when you `ctrl+c` to exit.
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write == fsnotify.Write {
fmt.Printf("🌙 Reloading %s\n", uri)
data, err = ioutil.ReadFile(uri)
if err != nil {
log.Fatal(err)
}

defer watcher.Close()

go func() {
// Since waiting for events or errors is blocking, we do this in a
// goroutine. It loops forever here but will exit when the process
// is finished, e.g. when you `ctrl+c` to exit.
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write == fsnotify.Write {
fmt.Printf("🌙 Reloading %s\n", uri)
data, err = loadSwaggerFromUri(uri)
if err != nil {
log.Printf("ERROR: %s", err)
} else {
if s, r, err := load(uri, data); err == nil {
swagger = s
rr.Set(r)
} else {
log.Printf("ERROR: Unable to load OpenAPI document: %s", err)
}
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
fmt.Println("error:", err)
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
log.Printf("ERROR: %s", err)
}
}()
}
}()

watcher.Add(uri)
}
watcher.Add(uri)
}

swagger, router, err := load(uri, data)
Expand All @@ -697,35 +714,25 @@ func server(cmd *cobra.Command, args []string) {

rr.Set(router)

if strings.HasPrefix(uri, "http") {
http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) {
resp, err := http.Get(uri)
if err != nil {
log.Printf("ERROR: %v", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error while reloading"))
return
}

data, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
log.Printf("ERROR: %v", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error while parsing"))
return
}

http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) {
log.Printf("🌙 Reloading %s\n", uri)
data, err = loadSwaggerFromUri(uri)
if err == nil {
if s, r, err := load(uri, data); err == nil {
swagger = s
rr.Set(r)
}

}
if err == nil {
log.Printf("Reloaded from %s", uri)
w.WriteHeader(200)
w.Write([]byte("reloaded"))
log.Printf("Reloaded from %s", uri)
})
}
} else {
log.Printf("ERROR: %s", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error while reloading"))
}
})

// Add a health check route which returns 200
http.HandleFunc("/__health", func(w http.ResponseWriter, r *http.Request) {
Expand Down