diff --git a/openapi3/schema.go b/openapi3/schema.go index 1be5d8385..a7e845f79 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "math" + "math/big" "regexp" "strconv" "strings" @@ -164,16 +165,22 @@ func NewFloat64Schema() *Schema { } } +func NewIntegerSchema() *Schema { + return &Schema{ + Type: "integer", + } +} + func NewInt32Schema() *Schema { return &Schema{ - Type: "number", + Type: "integer", Format: "int32", } } func NewInt64Schema() *Schema { return &Schema{ - Type: "number", + Type: "integer", Format: "int64", } } @@ -432,14 +439,31 @@ func (schema *Schema) validate(c context.Context, stack []*Schema) error { schemaType := schema.Type switch schemaType { case "": - case "integer", "long", "float", "double": + case "number": + if format := schema.Format; len(format) > 0 { + switch format { + case "float", "double": + default: + return fmt.Errorf("Unsupported 'format' value '%v", format) + } + } + case "integer": + if format := schema.Format; len(format) > 0 { + switch format { + case "int32", "int64": + default: + return fmt.Errorf("Unsupported 'format' value '%v", format) + } + } case "string": - case "byte": - case "binary": + if format := schema.Format; len(format) > 0 { + switch format { + case "byte", "binary", "date", "dateTime", "password": + default: + return fmt.Errorf("Unsupported 'format' value '%v", format) + } + } case "boolean": - case "date": - case "dateTime": - case "password": case "array": if schema.Items == nil { return errors.New("When schema type is 'array', schema 'items' must be non-null") @@ -689,7 +713,7 @@ func (schema *Schema) visitJSONNumber(value float64, fast bool) error { if math.IsInf(value, 0) { return ErrSchemaInputInf } - if err := schema.validateTypeListAllows("number", fast); err != nil { + if err := schema.validateJSONNumber(value, fast); err != nil { return err } @@ -1084,6 +1108,14 @@ func (schema *Schema) visitJSONObject(value map[string]interface{}, fast bool) e return nil } +func (schema *Schema) validateJSONNumber(value float64, fast bool) error { + bigFloat := big.NewFloat(value) + if bigFloat.IsInt() && schema.TypesContains("integer") { + return nil + } + return schema.validateTypeListAllows("number", fast) +} + func (schema *Schema) validateTypeListAllows(value string, fast bool) error { schemaType := schema.Type if schemaType != "" { diff --git a/openapi3/schema_test.go b/openapi3/schema_test.go index bf6291532..fb26de498 100644 --- a/openapi3/schema_test.go +++ b/openapi3/schema_test.go @@ -136,6 +136,34 @@ var schemaExamples = []schemaExample{ }, }, + { + Title: "INTEGER", + Schema: openapi3.NewInt64Schema(). + WithMin(2). + WithMax(5), + Serialization: map[string]interface{}{ + "type": "integer", + "format": "int64", + "minimum": 2, + "maximum": 5, + }, + AllValid: []interface{}{ + 2, + 5, + }, + AllInvalid: []interface{}{ + nil, + false, + true, + 1, + 6, + 3.5, + "", + []interface{}{}, + map[string]interface{}{}, + }, + }, + { Title: "STRING", Schema: openapi3.NewStringSchema(). @@ -510,3 +538,79 @@ var schemaExamples = []schemaExample{ }, }, } + +type schemaTypeExample struct { + Title string + Schema *openapi3.Schema + AllValid []string + AllInvalid []string +} + +func TestTypes(t *testing.T) { + for _, example := range typeExamples { + t.Run(example.Title, testType(t, example)) + } +} + +func testType(t *testing.T, example schemaTypeExample) func(*testing.T) { + return func(t *testing.T) { + schema := example.Schema + for _, typ := range example.AllValid { + err := validateType(t, schema, typ) + require.NoError(t, err) + } + for _, typ := range example.AllInvalid { + err := validateType(t, schema, typ) + require.Error(t, err) + } + } +} + +func validateType(t *testing.T, schema *openapi3.Schema, typ string) error { + schema.WithFormat(typ) + return schema.Validate(nil) +} + +var typeExamples = []schemaTypeExample{ + { + Title: "STRING", + Schema: openapi3.NewStringSchema(), + AllValid: []string{ + "", + "byte", + "binary", + "date", + "dateTime", + "password", + }, + AllInvalid: []string{ + "unsupported", + }, + }, + + { + Title: "NUMBER", + Schema: openapi3.NewFloat64Schema(), + AllValid: []string{ + "", + "float", + "double", + }, + AllInvalid: []string{ + "unsupported", + }, + }, + + { + Title: "INTEGER", + Schema: openapi3.NewIntegerSchema(), + AllValid: []string{ + "", + "int32", + "int64", + }, + AllInvalid: []string{ + "unsupported", + }, + }, +} diff --git a/openapi3gen/openapi3gen.go b/openapi3gen/openapi3gen.go index 1a982257f..4a80405ba 100644 --- a/openapi3gen/openapi3gen.go +++ b/openapi3gen/openapi3gen.go @@ -116,7 +116,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - schema.Type = "number" + schema.Type = "integer" schema.Format = "int64" case reflect.Float32, reflect.Float64: diff --git a/openapi3gen/openapi3gen_test.go b/openapi3gen/openapi3gen_test.go index 83c38433d..d94cfce9f 100644 --- a/openapi3gen/openapi3gen_test.go +++ b/openapi3gen/openapi3gen_test.go @@ -61,11 +61,11 @@ const expectedSimple = ` "type": "boolean" }, "int": { - "type": "number", + "type": "integer", "format": "int64" }, "int64": { - "type": "number", + "type": "integer", "format": "int64" }, "float64": {