diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1e9c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.idea +cert +.vscode +bin +.DS_store +.history +vendor \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..b7f21fc --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/zta-tools.iml b/.idea/zta-tools.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/zta-tools.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/attrmgr/attrmgr.go b/attrmgr/attrmgr.go new file mode 100644 index 0000000..7c692ca --- /dev/null +++ b/attrmgr/attrmgr.go @@ -0,0 +1,231 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package attrmgr + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/json" + "fmt" + + "github.com/pkg/errors" +) + +var ( + // AttrOID is the ASN.1 object identifier for an attribute extension in an + // X509 certificate + AttrOID = asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6, 7, 8, 1} + // AttrOIDString is the string version of AttrOID + AttrOIDString = "1.2.3.4.5.6.7.8.1" +) + +// Attribute is a name/value pair +type Attribute interface { + // GetName returns the name of the attribute + GetName() string + // GetValue returns the value of the attribute + GetValue() interface{} +} + +// AttributeRequest is a request for an attribute +type AttributeRequest interface { + // GetName returns the name of an attribute + GetName() string + // IsRequired returns true if the attribute is required + IsRequired() bool +} + +// New constructs an attribute manager +func New() *Mgr { return &Mgr{} } + +// Mgr is the attribute manager and is the main object for this package +type Mgr struct{} + +// ProcessAttributeRequestsForCert add attributes to an X509 certificate, given +// attribute requests and attributes. +func (mgr *Mgr) ProcessAttributeRequestsForCert(requests []AttributeRequest, attributes []Attribute, cert *x509.Certificate) error { + attrs, err := mgr.ProcessAttributeRequests(requests, attributes) + if err != nil { + return err + } + return mgr.AddAttributesToCert(attrs, cert) +} + +// ProcessAttributeRequests takes an array of attribute requests and an identity's attributes +// and returns an Attributes object containing the requested attributes. +func (mgr *Mgr) ProcessAttributeRequests(requests []AttributeRequest, attributes []Attribute) (*Attributes, error) { + attrsMap := map[string]interface{}{} + attrs := &Attributes{Attrs: attrsMap} + missingRequiredAttrs := []string{} + // For each of the attribute requests + for _, req := range requests { + // Get the attribute + name := req.GetName() + attr := getAttrByName(name, attributes) + if attr == nil { + if req.IsRequired() { + // Didn't find attribute and it was required; return error below + missingRequiredAttrs = append(missingRequiredAttrs, name) + } + // Skip attribute requests which aren't required + continue + } + attrsMap[name] = attr.GetValue() + } + if len(missingRequiredAttrs) > 0 { + return nil, errors.Errorf("The following required attributes are missing: %+v", + missingRequiredAttrs) + } + return attrs, nil +} + +// ToPkixExtension ... +func (mgr *Mgr) ToPkixExtension(attrs *Attributes) (pkix.Extension, error) { + buf, err := json.Marshal(attrs) + if err != nil { + return pkix.Extension{}, errors.Wrap(err, "Failed to marshal attributes") + } + ext := pkix.Extension{ + Id: AttrOID, + Critical: false, + Value: buf, + } + return ext, nil +} + +// AddAttributesToCertRequest ... +func (mgr *Mgr) AddAttributesToCertRequest(attrs *Attributes, cert *x509.CertificateRequest) error { + buf, err := json.Marshal(attrs) + if err != nil { + return errors.Wrap(err, "Failed to marshal attributes") + } + ext := pkix.Extension{ + Id: AttrOID, + Critical: false, + Value: buf, + } + cert.Extensions = append(cert.Extensions, ext) + return nil +} + +// AddAttributesToCert adds public attribute info to an X509 certificate. +func (mgr *Mgr) AddAttributesToCert(attrs *Attributes, cert *x509.Certificate) error { + buf, err := json.Marshal(attrs) + if err != nil { + return errors.Wrap(err, "Failed to marshal attributes") + } + ext := pkix.Extension{ + Id: AttrOID, + Critical: false, + Value: buf, + } + cert.Extensions = append(cert.Extensions, ext) + return nil +} + +// GetAttributesFromCert gets the attributes from a certificate. +func (mgr *Mgr) GetAttributesFromCert(cert *x509.Certificate) (*Attributes, error) { + // Get certificate attributes from the certificate if it exists + buf, err := getAttributesFromCert(cert) + if err != nil { + return nil, err + } + // Unmarshal into attributes object + attrs := &Attributes{} + if buf != nil { + err := json.Unmarshal(buf, attrs) + if err != nil { + return nil, errors.Wrap(err, "Failed to unmarshal attributes from certificate") + } + } + return attrs, nil +} + +// Attributes contains attribute names and values +type Attributes struct { + Attrs map[string]interface{} `json:"attrs"` +} + +// Names returns the names of the attributes +func (a *Attributes) Names() []string { + i := 0 + names := make([]string, len(a.Attrs)) + for name := range a.Attrs { + names[i] = name + i++ + } + return names +} + +// Contains returns true if the named attribute is found +func (a *Attributes) Contains(name string) bool { + _, ok := a.Attrs[name] + return ok +} + +// Value returns an attribute's value +func (a *Attributes) Value(name string) (interface{}, bool, error) { + attr, ok := a.Attrs[name] + return attr, ok, nil +} + +// True returns nil if the value of attribute 'name' is true; +// otherwise, an appropriate error is returned. +func (a *Attributes) True(name string) error { + val, ok, err := a.Value(name) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("Attribute '%s' was not found", name) + } + if val != "true" { + return fmt.Errorf("Attribute '%s' is not true", name) + } + return nil +} + +// Get the attribute info from a certificate extension, or return nil if not found +func getAttributesFromCert(cert *x509.Certificate) ([]byte, error) { + for _, ext := range cert.Extensions { + if isAttrOID(ext.Id) { + return ext.Value, nil + } + } + return nil, nil +} + +// Is the object ID equal to the attribute info object ID? +func isAttrOID(oid asn1.ObjectIdentifier) bool { + if len(oid) != len(AttrOID) { + return false + } + for idx, val := range oid { + if val != AttrOID[idx] { + return false + } + } + return true +} + +// Get an attribute from 'attrs' by its name, or nil if not found +func getAttrByName(name string, attrs []Attribute) Attribute { + for _, attr := range attrs { + if attr.GetName() == name { + return attr + } + } + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9e5baef --- /dev/null +++ b/go.mod @@ -0,0 +1,24 @@ +module github.com/ztalab/zta-tools + +go 1.19 + +require ( + github.com/LyricTian/queue v1.3.0 + github.com/go-redis/redis v6.15.9+incompatible + github.com/pkg/errors v0.9.1 + github.com/sirupsen/logrus v1.9.0 + github.com/ztdbp/ZACA v0.0.0-20230130085917-758df0add0c1 + github.com/ztdbp/ZASentinel v0.0.0-20230117034106-d7354bb47cc5 +) + +require ( + github.com/garyburd/redigo v1.6.3 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nxadm/tail v1.4.8 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.8.0 // indirect + go.uber.org/zap v1.24.0 // indirect + golang.org/x/sys v0.4.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4b4e578 --- /dev/null +++ b/go.sum @@ -0,0 +1,59 @@ +github.com/LyricTian/queue v1.3.0 h1:1xEFZlteW6iu5Qbrz7ZsiSKMKaxY1bQHsbx0jrB1pDA= +github.com/LyricTian/queue v1.3.0/go.mod h1:pbkoplz/zRToCay3pRjz75P8fQAgvkRKJdEzVUQYhXY= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= +github.com/garyburd/redigo v1.6.3 h1:HCeeRluvAgMusMomi1+6Y5dmFOdYV/JzoRrrbFlkGIc= +github.com/garyburd/redigo v1.6.3/go.mod h1:rTb6epsqigu3kYKBnaF028A7Tf/Aw5s0cqA47doKKqw= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.19.0 h1:4ieX6qQjPP/BfC3mpsAtIGGlxTWPeA3Inl/7DtXw1tw= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/ztdbp/ZACA v0.0.0-20230130085917-758df0add0c1 h1:++PKKjmPGMivF1f6wTyZmu4YkBfVhxQW3I1uH9AKrzo= +github.com/ztdbp/ZACA v0.0.0-20230130085917-758df0add0c1/go.mod h1:DM+b+eGl8VRwit889SUB6ylBgfsrI9r7EUn3axB0K5I= +github.com/ztdbp/ZASentinel v0.0.0-20230117034106-d7354bb47cc5 h1:NMoBEvuJ2/edm+GAg0dJyPhdvEy3qEjX2+s4YssyXF4= +github.com/ztdbp/ZASentinel v0.0.0-20230117034106-d7354bb47cc5/go.mod h1:HLhrcqY4QjHdVR/VN+3CYPOH6cnBFjq2lOj0SLc3/Hg= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= +go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= +go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/influxdb/client.go b/influxdb/client.go new file mode 100644 index 0000000..b1c2f92 --- /dev/null +++ b/influxdb/client.go @@ -0,0 +1,60 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package influxdb + +import ( + _ "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client" // this is important because of the bug in go mod + client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2" + "github.com/ztdbp/ZACA/pkg/logger" +) + +// UDPClient UDP Client +type UDPClient struct { + Conf *Config + BatchPointsConfig client.BatchPointsConfig + client client.Client +} + +func (p *UDPClient) newUDPV1Client() *UDPClient { + udpClient, err := client.NewUDPClient(client.UDPConfig{ + Addr: p.Conf.UDPAddress, + }) + if err != nil { + logger.Errorf("InfluxDBUDPClient err: %v", err) + } + p.client = udpClient + return p +} + +// FluxDBUDPWrite ... +func (p *UDPClient) FluxDBUDPWrite(bp client.BatchPoints) (err error) { + err = p.newUDPV1Client().client.Write(bp) + return +} + +// HTTPClient HTTP Client +type HTTPClient struct { + Client client.Client + BatchPointsConfig client.BatchPointsConfig +} + +// FluxDBHttpWrite ... +func (p *HTTPClient) FluxDBHttpWrite(bp client.BatchPoints) (err error) { + return p.Client.Write(bp) +} + +// FluxDBHttpClose ... +func (p *HTTPClient) FluxDBHttpClose() (err error) { + return p.Client.Close() +} diff --git a/influxdb/client/influxdb.go b/influxdb/client/influxdb.go new file mode 100644 index 0000000..d605624 --- /dev/null +++ b/influxdb/client/influxdb.go @@ -0,0 +1,881 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client" + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models" +) + +const ( + // DefaultHost is the default host used to connect to an InfluxDB instance + DefaultHost = "localhost" + + // DefaultPort is the default port used to connect to an InfluxDB instance + DefaultPort = 8086 + + // DefaultTimeout is the default connection timeout used to connect to an InfluxDB instance + DefaultTimeout = 0 +) + +// Query is used to send a command to the server. Both Command and Database are required. +type Query struct { + Command string + Database string + + // RetentionPolicy tells the server which retention policy to use by default. + // This option is only effective when querying a server of version 1.6.0 or later. + RetentionPolicy string + + // Chunked tells the server to send back chunked responses. This places + // less load on the server by sending back chunks of the response rather + // than waiting for the entire response all at once. + Chunked bool + + // ChunkSize sets the maximum number of rows that will be returned per + // chunk. Chunks are either divided based on their series or if they hit + // the chunk size limit. + // + // Chunked must be set to true for this option to be used. + ChunkSize int + + // NodeID sets the data node to use for the query results. This option only + // has any effect in the enterprise version of the software where there can be + // more than one data node and is primarily useful for analyzing differences in + // data. The default behavior is to automatically select the appropriate data + // nodes to retrieve all of the data. On a database where the number of data nodes + // is greater than the replication factor, it is expected that setting this option + // will only retrieve partial data. + NodeID int +} + +// ParseConnectionString will parse a string to create a valid connection URL +func ParseConnectionString(path string, ssl bool) (url.URL, error) { + var host string + var port int + + h, p, err := net.SplitHostPort(path) + if err != nil { + if path == "" { + host = DefaultHost + } else { + host = path + } + // If they didn't specify a port, always use the default port + port = DefaultPort + } else { + host = h + port, err = strconv.Atoi(p) + if err != nil { + return url.URL{}, fmt.Errorf("invalid port number %q: %s\n", path, err) + } + } + + u := url.URL{ + Scheme: "http", + Host: host, + } + if ssl { + u.Scheme = "https" + if port != 443 { + u.Host = net.JoinHostPort(host, strconv.Itoa(port)) + } + } else if port != 80 { + u.Host = net.JoinHostPort(host, strconv.Itoa(port)) + } + + return u, nil +} + +// Config is used to specify what server to connect to. +// URL: The URL of the server connecting to. +// Username/Password are optional. They will be passed via basic auth if provided. +// UserAgent: If not provided, will default "InfluxDBClient", +// Timeout: If not provided, will default to 0 (no timeout) +type Config struct { + URL url.URL + UnixSocket string + Username string + Password string + UserAgent string + Timeout time.Duration + Precision string + WriteConsistency string + UnsafeSsl bool + Proxy func(req *http.Request) (*url.URL, error) + TLS *tls.Config +} + +// NewConfig will create a config to be used in connecting to the client +func NewConfig() Config { + return Config{ + Timeout: DefaultTimeout, + } +} + +// Client is used to make calls to the server. +type Client struct { + url url.URL + unixSocket string + username string + password string + httpClient *http.Client + userAgent string + precision string +} + +const ( + // ConsistencyOne requires at least one data node acknowledged a write. + ConsistencyOne = "one" + + // ConsistencyAll requires all data nodes to acknowledge a write. + ConsistencyAll = "all" + + // ConsistencyQuorum requires a quorum of data nodes to acknowledge a write. + ConsistencyQuorum = "quorum" + + // ConsistencyAny allows for hinted hand off, potentially no write happened yet. + ConsistencyAny = "any" +) + +// NewClient will instantiate and return a connected client to issue commands to the server. +func NewClient(c Config) (*Client, error) { + tlsConfig := new(tls.Config) + if c.TLS != nil { + tlsConfig = c.TLS.Clone() + } + tlsConfig.InsecureSkipVerify = c.UnsafeSsl + + tr := &http.Transport{ + Proxy: c.Proxy, + TLSClientConfig: tlsConfig, + } + + if c.UnixSocket != "" { + // No need for compression in local communications. + tr.DisableCompression = true + + tr.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", c.UnixSocket) + } + } + + client := Client{ + url: c.URL, + unixSocket: c.UnixSocket, + username: c.Username, + password: c.Password, + httpClient: &http.Client{Timeout: c.Timeout, Transport: tr}, + userAgent: c.UserAgent, + precision: c.Precision, + } + if client.userAgent == "" { + client.userAgent = "InfluxDBClient" + } + return &client, nil +} + +// SetAuth will update the username and passwords +func (c *Client) SetAuth(u, p string) { + c.username = u + c.password = p +} + +// SetPrecision will update the precision +func (c *Client) SetPrecision(precision string) { + c.precision = precision +} + +// Query sends a command to the server and returns the Response +func (c *Client) Query(q Query) (*Response, error) { + return c.QueryContext(context.Background(), q) +} + +// QueryContext sends a command to the server and returns the Response +// It uses a context that can be cancelled by the command line client +func (c *Client) QueryContext(ctx context.Context, q Query) (*Response, error) { + u := c.url + u.Path = path.Join(u.Path, "query") + + values := u.Query() + values.Set("q", q.Command) + values.Set("db", q.Database) + if q.RetentionPolicy != "" { + values.Set("rp", q.RetentionPolicy) + } + if q.Chunked { + values.Set("chunked", "true") + if q.ChunkSize > 0 { + values.Set("chunk_size", strconv.Itoa(q.ChunkSize)) + } + } + if q.NodeID > 0 { + values.Set("node_id", strconv.Itoa(q.NodeID)) + } + if c.precision != "" { + values.Set("epoch", c.precision) + } + u.RawQuery = values.Encode() + + req, err := http.NewRequest("POST", u.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + req = req.WithContext(ctx) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response Response + if q.Chunked { + cr := NewChunkedResponse(resp.Body) + for { + r, err := cr.NextResponse() + if err != nil { + // If we got an error while decoding the response, send that back. + return nil, err + } + + if r == nil { + break + } + + response.Results = append(response.Results, r.Results...) + if r.Err != nil { + response.Err = r.Err + break + } + } + } else { + dec := json.NewDecoder(resp.Body) + dec.UseNumber() + if err := dec.Decode(&response); err != nil { + // Ignore EOF errors if we got an invalid status code. + if !(err == io.EOF && resp.StatusCode != http.StatusOK) { + return nil, err + } + } + } + + // If we don't have an error in our json response, and didn't get StatusOK, + // then send back an error. + if resp.StatusCode != http.StatusOK && response.Error() == nil { + return &response, fmt.Errorf("received status code %d from server", resp.StatusCode) + } + return &response, nil +} + +// Write takes BatchPoints and allows for writing of multiple points with defaults +// If successful, error is nil and Response is nil +// If an error occurs, Response may contain additional information if populated. +func (c *Client) Write(bp BatchPoints) (*Response, error) { + u := c.url + u.Path = path.Join(u.Path, "write") + + var b bytes.Buffer + for _, p := range bp.Points { + err := checkPointTypes(p) + if err != nil { + return nil, err + } + if p.Raw != "" { + if _, err := b.WriteString(p.Raw); err != nil { + return nil, err + } + } else { + for k, v := range bp.Tags { + if p.Tags == nil { + p.Tags = make(map[string]string, len(bp.Tags)) + } + p.Tags[k] = v + } + + if _, err := b.WriteString(p.MarshalString()); err != nil { + return nil, err + } + } + + if err := b.WriteByte('\n'); err != nil { + return nil, err + } + } + + req, err := http.NewRequest("POST", u.String(), &b) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + precision := bp.Precision + if precision == "" { + precision = c.precision + } + + params := req.URL.Query() + params.Set("db", bp.Database) + params.Set("rp", bp.RetentionPolicy) + params.Set("precision", precision) + params.Set("consistency", bp.WriteConsistency) + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response Response + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + var err = fmt.Errorf(string(body)) + response.Err = err + return &response, err + } + + return nil, nil +} + +// WriteLineProtocol takes a string with line returns to delimit each write +// If successful, error is nil and Response is nil +// If an error occurs, Response may contain additional information if populated. +func (c *Client) WriteLineProtocol(data, database, retentionPolicy, precision, writeConsistency string) (*Response, error) { + u := c.url + u.Path = path.Join(u.Path, "write") + + r := strings.NewReader(data) + + req, err := http.NewRequest("POST", u.String(), r) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + params := req.URL.Query() + params.Set("db", database) + params.Set("rp", retentionPolicy) + params.Set("precision", precision) + params.Set("consistency", writeConsistency) + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var response Response + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + err := fmt.Errorf(string(body)) + response.Err = err + return &response, err + } + + return nil, nil +} + +// Ping will check to see if the server is up +// Ping returns how long the request took, the version of the server it connected to, and an error if one occurred. +func (c *Client) Ping() (time.Duration, string, error) { + now := time.Now() + + u := c.url + u.Path = path.Join(u.Path, "ping") + + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return 0, "", err + } + req.Header.Set("User-Agent", c.userAgent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + version := resp.Header.Get("X-Influxdb-Version") + return time.Since(now), version, nil +} + +// Structs + +// Message represents a user message. +type Message struct { + Level string `json:"level,omitempty"` + Text string `json:"text,omitempty"` +} + +// Result represents a resultset returned from a single statement. +type Result struct { + Series []models.Row + Messages []*Message + Err error +} + +// MarshalJSON encodes the result into JSON. +func (r *Result) MarshalJSON() ([]byte, error) { + // Define a struct that outputs "error" as a string. + var o struct { + Series []models.Row `json:"series,omitempty"` + Messages []*Message `json:"messages,omitempty"` + Err string `json:"error,omitempty"` + } + + // Copy fields to output struct. + o.Series = r.Series + o.Messages = r.Messages + if r.Err != nil { + o.Err = r.Err.Error() + } + + return json.Marshal(&o) +} + +// UnmarshalJSON decodes the data into the Result struct +func (r *Result) UnmarshalJSON(b []byte) error { + var o struct { + Series []models.Row `json:"series,omitempty"` + Messages []*Message `json:"messages,omitempty"` + Err string `json:"error,omitempty"` + } + + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + err := dec.Decode(&o) + if err != nil { + return err + } + r.Series = o.Series + r.Messages = o.Messages + if o.Err != "" { + r.Err = errors.New(o.Err) + } + return nil +} + +// Response represents a list of statement results. +type Response struct { + Results []Result + Err error +} + +// MarshalJSON encodes the response into JSON. +func (r *Response) MarshalJSON() ([]byte, error) { + // Define a struct that outputs "error" as a string. + var o struct { + Results []Result `json:"results,omitempty"` + Err string `json:"error,omitempty"` + } + + // Copy fields to output struct. + o.Results = r.Results + if r.Err != nil { + o.Err = r.Err.Error() + } + + return json.Marshal(&o) +} + +// UnmarshalJSON decodes the data into the Response struct +func (r *Response) UnmarshalJSON(b []byte) error { + var o struct { + Results []Result `json:"results,omitempty"` + Err string `json:"error,omitempty"` + } + + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + err := dec.Decode(&o) + if err != nil { + return err + } + r.Results = o.Results + if o.Err != "" { + r.Err = errors.New(o.Err) + } + return nil +} + +// Error returns the first error from any statement. +// Returns nil if no errors occurred on any statements. +func (r *Response) Error() error { + if r.Err != nil { + return r.Err + } + for _, result := range r.Results { + if result.Err != nil { + return result.Err + } + } + return nil +} + +// duplexReader reads responses and writes it to another writer while +// satisfying the reader interface. +type duplexReader struct { + r io.Reader + w io.Writer +} + +func (r *duplexReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil { + r.w.Write(p[:n]) + } + return n, err +} + +// ChunkedResponse represents a response from the server that +// uses chunking to stream the output. +type ChunkedResponse struct { + dec *json.Decoder + duplex *duplexReader + buf bytes.Buffer +} + +// NewChunkedResponse reads a stream and produces responses from the stream. +func NewChunkedResponse(r io.Reader) *ChunkedResponse { + resp := &ChunkedResponse{} + resp.duplex = &duplexReader{r: r, w: &resp.buf} + resp.dec = json.NewDecoder(resp.duplex) + resp.dec.UseNumber() + return resp +} + +// NextResponse reads the next line of the stream and returns a response. +func (r *ChunkedResponse) NextResponse() (*Response, error) { + var response Response + if err := r.dec.Decode(&response); err != nil { + if err == io.EOF { + return nil, nil + } + // A decoding error happened. This probably means the server crashed + // and sent a last-ditch error message to us. Ensure we have read the + // entirety of the connection to get any remaining error text. + io.Copy(ioutil.Discard, r.duplex) + return nil, errors.New(strings.TrimSpace(r.buf.String())) + } + r.buf.Reset() + return &response, nil +} + +// Point defines the fields that will be written to the database +// Measurement, Time, and Fields are required +// Precision can be specified if the time is in epoch format (integer). +// Valid values for Precision are n, u, ms, s, m, and h +type Point struct { + Measurement string + Tags map[string]string + Time time.Time + Fields map[string]interface{} + Precision string + Raw string +} + +// MarshalJSON will format the time in RFC3339Nano +// Precision is also ignored as it is only used for writing, not reading +// Or another way to say it is we always send back in nanosecond precision +func (p *Point) MarshalJSON() ([]byte, error) { + point := struct { + Measurement string `json:"measurement,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Time string `json:"time,omitempty"` + Fields map[string]interface{} `json:"fields,omitempty"` + Precision string `json:"precision,omitempty"` + }{ + Measurement: p.Measurement, + Tags: p.Tags, + Fields: p.Fields, + Precision: p.Precision, + } + // Let it omit empty if it's really zero + if !p.Time.IsZero() { + point.Time = p.Time.UTC().Format(time.RFC3339Nano) + } + return json.Marshal(&point) +} + +// MarshalString renders string representation of a Point with specified +// precision. The default precision is nanoseconds. +func (p *Point) MarshalString() string { + pt, err := models.NewPoint(p.Measurement, models.NewTags(p.Tags), p.Fields, p.Time) + if err != nil { + return "# ERROR: " + err.Error() + " " + p.Measurement + } + if p.Precision == "" || p.Precision == "ns" || p.Precision == "n" { + return pt.String() + } + return pt.PrecisionString(p.Precision) +} + +// UnmarshalJSON decodes the data into the Point struct +func (p *Point) UnmarshalJSON(b []byte) error { + var normal struct { + Measurement string `json:"measurement"` + Tags map[string]string `json:"tags"` + Time time.Time `json:"time"` + Precision string `json:"precision"` + Fields map[string]interface{} `json:"fields"` + } + var epoch struct { + Measurement string `json:"measurement"` + Tags map[string]string `json:"tags"` + Time *int64 `json:"time"` + Precision string `json:"precision"` + Fields map[string]interface{} `json:"fields"` + } + + if err := func() error { + var err error + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + if err = dec.Decode(&epoch); err != nil { + return err + } + // Convert from epoch to time.Time, but only if Time + // was actually set. + var ts time.Time + if epoch.Time != nil { + ts, err = EpochToTime(*epoch.Time, epoch.Precision) + if err != nil { + return err + } + } + p.Measurement = epoch.Measurement + p.Tags = epoch.Tags + p.Time = ts + p.Precision = epoch.Precision + p.Fields = normalizeFields(epoch.Fields) + return nil + }(); err == nil { + return nil + } + + dec := json.NewDecoder(bytes.NewBuffer(b)) + dec.UseNumber() + if err := dec.Decode(&normal); err != nil { + return err + } + normal.Time = SetPrecision(normal.Time, normal.Precision) + p.Measurement = normal.Measurement + p.Tags = normal.Tags + p.Time = normal.Time + p.Precision = normal.Precision + p.Fields = normalizeFields(normal.Fields) + + return nil +} + +// Remove any notion of json.Number +func normalizeFields(fields map[string]interface{}) map[string]interface{} { + newFields := map[string]interface{}{} + + for k, v := range fields { + switch v := v.(type) { + case json.Number: + jv, e := v.Float64() + if e != nil { + panic(fmt.Sprintf("unable to convert json.Number to float64: %s", e)) + } + newFields[k] = jv + default: + newFields[k] = v + } + } + return newFields +} + +// BatchPoints is used to send batched data in a single write. +// Database and Points are required +// If no retention policy is specified, it will use the databases default retention policy. +// If tags are specified, they will be "merged" with all points. If a point already has that tag, it will be ignored. +// If time is specified, it will be applied to any point with an empty time. +// Precision can be specified if the time is in epoch format (integer). +// Valid values for Precision are n, u, ms, s, m, and h +type BatchPoints struct { + Points []Point `json:"points,omitempty"` + Database string `json:"database,omitempty"` + RetentionPolicy string `json:"retentionPolicy,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Time time.Time `json:"time,omitempty"` + Precision string `json:"precision,omitempty"` + WriteConsistency string `json:"-"` +} + +// UnmarshalJSON decodes the data into the BatchPoints struct +func (bp *BatchPoints) UnmarshalJSON(b []byte) error { + var normal struct { + Points []Point `json:"points"` + Database string `json:"database"` + RetentionPolicy string `json:"retentionPolicy"` + Tags map[string]string `json:"tags"` + Time time.Time `json:"time"` + Precision string `json:"precision"` + } + var epoch struct { + Points []Point `json:"points"` + Database string `json:"database"` + RetentionPolicy string `json:"retentionPolicy"` + Tags map[string]string `json:"tags"` + Time *int64 `json:"time"` + Precision string `json:"precision"` + } + + if err := func() error { + var err error + if err = json.Unmarshal(b, &epoch); err != nil { + return err + } + // Convert from epoch to time.Time + var ts time.Time + if epoch.Time != nil { + ts, err = EpochToTime(*epoch.Time, epoch.Precision) + if err != nil { + return err + } + } + bp.Points = epoch.Points + bp.Database = epoch.Database + bp.RetentionPolicy = epoch.RetentionPolicy + bp.Tags = epoch.Tags + bp.Time = ts + bp.Precision = epoch.Precision + return nil + }(); err == nil { + return nil + } + + if err := json.Unmarshal(b, &normal); err != nil { + return err + } + normal.Time = SetPrecision(normal.Time, normal.Precision) + bp.Points = normal.Points + bp.Database = normal.Database + bp.RetentionPolicy = normal.RetentionPolicy + bp.Tags = normal.Tags + bp.Time = normal.Time + bp.Precision = normal.Precision + + return nil +} + +// utility functions + +// Addr provides the current url as a string of the server the client is connected to. +func (c *Client) Addr() string { + if c.unixSocket != "" { + return c.unixSocket + } + return c.url.String() +} + +// checkPointTypes ensures no unsupported types are submitted to influxdb, returning error if they are found. +func checkPointTypes(p Point) error { + for _, v := range p.Fields { + switch v.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool, string, nil: + return nil + default: + return fmt.Errorf("unsupported point type: %T", v) + } + } + return nil +} + +// helper functions + +// EpochToTime takes a unix epoch time and uses precision to return back a time.Time +func EpochToTime(epoch int64, precision string) (time.Time, error) { + if precision == "" { + precision = "s" + } + var t time.Time + switch precision { + case "h": + t = time.Unix(0, epoch*int64(time.Hour)) + case "m": + t = time.Unix(0, epoch*int64(time.Minute)) + case "s": + t = time.Unix(0, epoch*int64(time.Second)) + case "ms": + t = time.Unix(0, epoch*int64(time.Millisecond)) + case "u": + t = time.Unix(0, epoch*int64(time.Microsecond)) + case "n": + t = time.Unix(0, epoch) + default: + return time.Time{}, fmt.Errorf("Unknown precision %q", precision) + } + return t, nil +} + +// SetPrecision will round a time to the specified precision +func SetPrecision(t time.Time, precision string) time.Time { + switch precision { + case "n": + case "u": + return t.Round(time.Microsecond) + case "ms": + return t.Round(time.Millisecond) + case "s": + return t.Round(time.Second) + case "m": + return t.Round(time.Minute) + case "h": + return t.Round(time.Hour) + } + return t +} diff --git a/influxdb/client/models/inline_fnv.go b/influxdb/client/models/inline_fnv.go new file mode 100644 index 0000000..4c89b4b --- /dev/null +++ b/influxdb/client/models/inline_fnv.go @@ -0,0 +1,45 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models" + +// from stdlib hash/fnv/fnv.go +const ( + prime64 = 1099511628211 + offset64 = 14695981039346656037 +) + +// InlineFNV64a is an alloc-free port of the standard library's fnv64a. +// See https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function. +type InlineFNV64a uint64 + +// NewInlineFNV64a returns a new instance of InlineFNV64a. +func NewInlineFNV64a() InlineFNV64a { + return offset64 +} + +// Write adds data to the running hash. +func (s *InlineFNV64a) Write(data []byte) (int, error) { + hash := uint64(*s) + for _, c := range data { + hash ^= uint64(c) + hash *= prime64 + } + *s = InlineFNV64a(hash) + return len(data), nil +} + +// Sum64 returns the uint64 of the current resulting hash. +func (s *InlineFNV64a) Sum64() uint64 { + return uint64(*s) +} diff --git a/influxdb/client/models/inline_strconv_parse.go b/influxdb/client/models/inline_strconv_parse.go new file mode 100644 index 0000000..9f763b0 --- /dev/null +++ b/influxdb/client/models/inline_strconv_parse.go @@ -0,0 +1,57 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models" + +import ( + "reflect" + "strconv" + "unsafe" +) + +// parseIntBytes is a zero-alloc wrapper around strconv.ParseInt. +func parseIntBytes(b []byte, base int, bitSize int) (i int64, err error) { + s := unsafeBytesToString(b) + return strconv.ParseInt(s, base, bitSize) +} + +// parseUintBytes is a zero-alloc wrapper around strconv.ParseUint. +func parseUintBytes(b []byte, base int, bitSize int) (i uint64, err error) { + s := unsafeBytesToString(b) + return strconv.ParseUint(s, base, bitSize) +} + +// parseFloatBytes is a zero-alloc wrapper around strconv.ParseFloat. +func parseFloatBytes(b []byte, bitSize int) (float64, error) { + s := unsafeBytesToString(b) + return strconv.ParseFloat(s, bitSize) +} + +// parseBoolBytes is a zero-alloc wrapper around strconv.ParseBool. +func parseBoolBytes(b []byte) (bool, error) { + return strconv.ParseBool(unsafeBytesToString(b)) +} + +// unsafeBytesToString converts a []byte to a string without a heap allocation. +// +// It is unsafe, and is intended to prepare input to short-lived functions +// that require strings. +func unsafeBytesToString(in []byte) string { + src := *(*reflect.SliceHeader)(unsafe.Pointer(&in)) + dst := reflect.StringHeader{ + Data: src.Data, + Len: src.Len, + } + s := *(*string)(unsafe.Pointer(&dst)) + return s +} diff --git a/influxdb/client/models/points.go b/influxdb/client/models/points.go new file mode 100644 index 0000000..1694aa3 --- /dev/null +++ b/influxdb/client/models/points.go @@ -0,0 +1,2425 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models" + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "sort" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" + + "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/pkg/escape" +) + +type escapeSet struct { + k [1]byte + esc [2]byte +} + +var ( + measurementEscapeCodes = [...]escapeSet{ + {k: [1]byte{','}, esc: [2]byte{'\\', ','}}, + {k: [1]byte{' '}, esc: [2]byte{'\\', ' '}}, + } + + tagEscapeCodes = [...]escapeSet{ + {k: [1]byte{','}, esc: [2]byte{'\\', ','}}, + {k: [1]byte{' '}, esc: [2]byte{'\\', ' '}}, + {k: [1]byte{'='}, esc: [2]byte{'\\', '='}}, + } + + // ErrPointMustHaveAField is returned when operating on a point that does not have any fields. + ErrPointMustHaveAField = errors.New("point without fields is unsupported") + + // ErrInvalidNumber is returned when a number is expected but not provided. + ErrInvalidNumber = errors.New("invalid number") + + // ErrInvalidPoint is returned when a point cannot be parsed correctly. + ErrInvalidPoint = errors.New("point is invalid") +) + +const ( + // MaxKeyLength is the largest allowed size of the combined measurement and tag keys. + MaxKeyLength = 65535 +) + +// enableUint64Support will enable uint64 support if set to true. +var enableUint64Support = false + +// EnableUintSupport manually enables uint support for the point parser. +// This function will be removed in the future and only exists for unit tests during the +// transition. +func EnableUintSupport() { + enableUint64Support = true +} + +// Point defines the values that will be written to the database. +type Point interface { + // Name return the measurement name for the point. + Name() []byte + + // SetName updates the measurement name for the point. + SetName(string) + + // Tags returns the tag set for the point. + Tags() Tags + + // ForEachTag iterates over each tag invoking fn. If fn return false, iteration stops. + ForEachTag(fn func(k, v []byte) bool) + + // AddTag adds or replaces a tag value for a point. + AddTag(key, value string) + + // SetTags replaces the tags for the point. + SetTags(tags Tags) + + // HasTag returns true if the tag exists for the point. + HasTag(tag []byte) bool + + // Fields returns the fields for the point. + Fields() (Fields, error) + + // Time return the timestamp for the point. + Time() time.Time + + // SetTime updates the timestamp for the point. + SetTime(t time.Time) + + // UnixNano returns the timestamp of the point as nanoseconds since Unix epoch. + UnixNano() int64 + + // HashID returns a non-cryptographic checksum of the point's key. + HashID() uint64 + + // Key returns the key (measurement joined with tags) of the point. + Key() []byte + + // String returns a string representation of the point. If there is a + // timestamp associated with the point then it will be specified with the default + // precision of nanoseconds. + String() string + + // MarshalBinary returns a binary representation of the point. + MarshalBinary() ([]byte, error) + + // PrecisionString returns a string representation of the point. If there + // is a timestamp associated with the point then it will be specified in the + // given unit. + PrecisionString(precision string) string + + // RoundedString returns a string representation of the point. If there + // is a timestamp associated with the point, then it will be rounded to the + // given duration. + RoundedString(d time.Duration) string + + // Split will attempt to return multiple points with the same timestamp whose + // string representations are no longer than size. Points with a single field or + // a point without a timestamp may exceed the requested size. + Split(size int) []Point + + // Round will round the timestamp of the point to the given duration. + Round(d time.Duration) + + // StringSize returns the length of the string that would be returned by String(). + StringSize() int + + // AppendString appends the result of String() to the provided buffer and returns + // the result, potentially reducing string allocations. + AppendString(buf []byte) []byte + + // FieldIterator retuns a FieldIterator that can be used to traverse the + // fields of a point without constructing the in-memory map. + FieldIterator() FieldIterator +} + +// FieldType represents the type of a field. +type FieldType int + +const ( + // Integer indicates the field's type is integer. + Integer FieldType = iota + + // Float indicates the field's type is float. + Float + + // Boolean indicates the field's type is boolean. + Boolean + + // String indicates the field's type is string. + String + + // Empty is used to indicate that there is no field. + Empty + + // Unsigned indicates the field's type is an unsigned integer. + Unsigned +) + +// FieldIterator provides a low-allocation interface to iterate through a point's fields. +type FieldIterator interface { + // Next indicates whether there any fields remaining. + Next() bool + + // FieldKey returns the key of the current field. + FieldKey() []byte + + // Type returns the FieldType of the current field. + Type() FieldType + + // StringValue returns the string value of the current field. + StringValue() string + + // IntegerValue returns the integer value of the current field. + IntegerValue() (int64, error) + + // UnsignedValue returns the unsigned value of the current field. + UnsignedValue() (uint64, error) + + // BooleanValue returns the boolean value of the current field. + BooleanValue() (bool, error) + + // FloatValue returns the float value of the current field. + FloatValue() (float64, error) + + // Reset resets the iterator to its initial state. + Reset() +} + +// Points represents a sortable list of points by timestamp. +type Points []Point + +// Len implements sort.Interface. +func (a Points) Len() int { return len(a) } + +// Less implements sort.Interface. +func (a Points) Less(i, j int) bool { return a[i].Time().Before(a[j].Time()) } + +// Swap implements sort.Interface. +func (a Points) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +// point is the default implementation of Point. +type point struct { + time time.Time + + // text encoding of measurement and tags + // key must always be stored sorted by tags, if the original line was not sorted, + // we need to resort it + key []byte + + // text encoding of field data + fields []byte + + // text encoding of timestamp + ts []byte + + // cached version of parsed fields from data + cachedFields map[string]interface{} + + // cached version of parsed name from key + cachedName string + + // cached version of parsed tags + cachedTags Tags + + it fieldIterator +} + +// type assertions +var ( + _ Point = (*point)(nil) + _ FieldIterator = (*point)(nil) +) + +const ( + // the number of characters for the largest possible int64 (9223372036854775807) + maxInt64Digits = 19 + + // the number of characters for the smallest possible int64 (-9223372036854775808) + minInt64Digits = 20 + + // the number of characters for the largest possible uint64 (18446744073709551615) + maxUint64Digits = 20 + + // the number of characters required for the largest float64 before a range check + // would occur during parsing + maxFloat64Digits = 25 + + // the number of characters required for smallest float64 before a range check occur + // would occur during parsing + minFloat64Digits = 27 +) + +// ParsePoints returns a slice of Points from a text representation of a point +// with each point separated by newlines. If any points fail to parse, a non-nil error +// will be returned in addition to the points that parsed successfully. +func ParsePoints(buf []byte) ([]Point, error) { + return ParsePointsWithPrecision(buf, time.Now().UTC(), "n") +} + +// ParsePointsString is identical to ParsePoints but accepts a string. +func ParsePointsString(buf string) ([]Point, error) { + return ParsePoints([]byte(buf)) +} + +// ParseKey returns the measurement name and tags from a point. +// +// NOTE: to minimize heap allocations, the returned Tags will refer to subslices of buf. +// This can have the unintended effect preventing buf from being garbage collected. +func ParseKey(buf []byte) (string, Tags) { + name, tags := ParseKeyBytes(buf) + return string(name), tags +} + +func ParseKeyBytes(buf []byte) ([]byte, Tags) { + return ParseKeyBytesWithTags(buf, nil) +} + +func ParseKeyBytesWithTags(buf []byte, tags Tags) ([]byte, Tags) { + // Ignore the error because scanMeasurement returns "missing fields" which we ignore + // when just parsing a key + state, i, _ := scanMeasurement(buf, 0) + + var name []byte + if state == tagKeyState { + tags = parseTags(buf, tags) + // scanMeasurement returns the location of the comma if there are tags, strip that off + name = buf[:i-1] + } else { + name = buf[:i] + } + return unescapeMeasurement(name), tags +} + +func ParseTags(buf []byte) Tags { + return parseTags(buf, nil) +} + +func ParseName(buf []byte) []byte { + // Ignore the error because scanMeasurement returns "missing fields" which we ignore + // when just parsing a key + state, i, _ := scanMeasurement(buf, 0) + var name []byte + if state == tagKeyState { + name = buf[:i-1] + } else { + name = buf[:i] + } + + return unescapeMeasurement(name) +} + +// ParsePointsWithPrecision is similar to ParsePoints, but allows the +// caller to provide a precision for time. +// +// NOTE: to minimize heap allocations, the returned Points will refer to subslices of buf. +// This can have the unintended effect preventing buf from being garbage collected. +func ParsePointsWithPrecision(buf []byte, defaultTime time.Time, precision string) ([]Point, error) { + points := make([]Point, 0, bytes.Count(buf, []byte{'\n'})+1) + var ( + pos int + block []byte + failed []string + ) + for pos < len(buf) { + pos, block = scanLine(buf, pos) + pos++ + + if len(block) == 0 { + continue + } + + start := skipWhitespace(block, 0) + + // If line is all whitespace, just skip it + if start >= len(block) { + continue + } + + // lines which start with '#' are comments + if block[start] == '#' { + continue + } + + // strip the newline if one is present + if block[len(block)-1] == '\n' { + block = block[:len(block)-1] + } + + pt, err := parsePoint(block[start:], defaultTime, precision) + if err != nil { + failed = append(failed, fmt.Sprintf("unable to parse '%s': %v", string(block[start:]), err)) + } else { + points = append(points, pt) + } + + } + if len(failed) > 0 { + return points, fmt.Errorf("%s", strings.Join(failed, "\n")) + } + return points, nil + +} + +func parsePoint(buf []byte, defaultTime time.Time, precision string) (Point, error) { + // scan the first block which is measurement[,tag1=value1,tag2=value2...] + pos, key, err := scanKey(buf, 0) + if err != nil { + return nil, err + } + + // measurement name is required + if len(key) == 0 { + return nil, fmt.Errorf("missing measurement") + } + + if len(key) > MaxKeyLength { + return nil, fmt.Errorf("max key length exceeded: %v > %v", len(key), MaxKeyLength) + } + + // scan the second block is which is field1=value1[,field2=value2,...] + pos, fields, err := scanFields(buf, pos) + if err != nil { + return nil, err + } + + // at least one field is required + if len(fields) == 0 { + return nil, fmt.Errorf("missing fields") + } + + var maxKeyErr error + err = walkFields(fields, func(k, v []byte) bool { + if sz := seriesKeySize(key, k); sz > MaxKeyLength { + maxKeyErr = fmt.Errorf("max key length exceeded: %v > %v", sz, MaxKeyLength) + return false + } + return true + }) + + if err != nil { + return nil, err + } + + if maxKeyErr != nil { + return nil, maxKeyErr + } + + // scan the last block which is an optional integer timestamp + pos, ts, err := scanTime(buf, pos) + if err != nil { + return nil, err + } + + pt := &point{ + key: key, + fields: fields, + ts: ts, + } + + if len(ts) == 0 { + pt.time = defaultTime + pt.SetPrecision(precision) + } else { + ts, err := parseIntBytes(ts, 10, 64) + if err != nil { + return nil, err + } + pt.time, err = SafeCalcTime(ts, precision) + if err != nil { + return nil, err + } + + // Determine if there are illegal non-whitespace characters after the + // timestamp block. + for pos < len(buf) { + if buf[pos] != ' ' { + return nil, ErrInvalidPoint + } + pos++ + } + } + return pt, nil +} + +// GetPrecisionMultiplier will return a multiplier for the precision specified. +func GetPrecisionMultiplier(precision string) int64 { + d := time.Nanosecond + switch precision { + case "u": + d = time.Microsecond + case "ms": + d = time.Millisecond + case "s": + d = time.Second + case "m": + d = time.Minute + case "h": + d = time.Hour + } + return int64(d) +} + +// scanKey scans buf starting at i for the measurement and tag portion of the point. +// It returns the ending position and the byte slice of key within buf. If there +// are tags, they will be sorted if they are not already. +func scanKey(buf []byte, i int) (int, []byte, error) { + start := skipWhitespace(buf, i) + + i = start + + // Determines whether the tags are sort, assume they are + sorted := true + + // indices holds the indexes within buf of the start of each tag. For example, + // a buf of 'cpu,host=a,region=b,zone=c' would have indices slice of [4,11,20] + // which indicates that the first tag starts at buf[4], seconds at buf[11], and + // last at buf[20] + indices := make([]int, 100) + + // tracks how many commas we've seen so we know how many values are indices. + // Since indices is an arbitrarily large slice, + // we need to know how many values in the buffer are in use. + commas := 0 + + // First scan the Point's measurement. + state, i, err := scanMeasurement(buf, i) + if err != nil { + return i, buf[start:i], err + } + + // Optionally scan tags if needed. + if state == tagKeyState { + i, commas, indices, err = scanTags(buf, i, indices) + if err != nil { + return i, buf[start:i], err + } + } + + // Now we know where the key region is within buf, and the location of tags, we + // need to determine if duplicate tags exist and if the tags are sorted. This iterates + // over the list comparing each tag in the sequence with each other. + for j := 0; j < commas-1; j++ { + // get the left and right tags + _, left := scanTo(buf[indices[j]:indices[j+1]-1], 0, '=') + _, right := scanTo(buf[indices[j+1]:indices[j+2]-1], 0, '=') + + // If left is greater than right, the tags are not sorted. We do not have to + // continue because the short path no longer works. + // If the tags are equal, then there are duplicate tags, and we should abort. + // If the tags are not sorted, this pass may not find duplicate tags and we + // need to do a more exhaustive search later. + if cmp := bytes.Compare(left, right); cmp > 0 { + sorted = false + break + } else if cmp == 0 { + return i, buf[start:i], fmt.Errorf("duplicate tags") + } + } + + // If the tags are not sorted, then sort them. This sort is inline and + // uses the tag indices we created earlier. The actual buffer is not sorted, the + // indices are using the buffer for value comparison. After the indices are sorted, + // the buffer is reconstructed from the sorted indices. + if !sorted && commas > 0 { + // Get the measurement name for later + measurement := buf[start : indices[0]-1] + + // Sort the indices + indices := indices[:commas] + insertionSort(0, commas, buf, indices) + + // Create a new key using the measurement and sorted indices + b := make([]byte, len(buf[start:i])) + pos := copy(b, measurement) + for _, i := range indices { + b[pos] = ',' + pos++ + _, v := scanToSpaceOr(buf, i, ',') + pos += copy(b[pos:], v) + } + + // Check again for duplicate tags now that the tags are sorted. + for j := 0; j < commas-1; j++ { + // get the left and right tags + _, left := scanTo(buf[indices[j]:], 0, '=') + _, right := scanTo(buf[indices[j+1]:], 0, '=') + + // If the tags are equal, then there are duplicate tags, and we should abort. + // If the tags are not sorted, this pass may not find duplicate tags and we + // need to do a more exhaustive search later. + if bytes.Equal(left, right) { + return i, b, fmt.Errorf("duplicate tags") + } + } + + return i, b, nil + } + + return i, buf[start:i], nil +} + +// The following constants allow us to specify which state to move to +// next, when scanning sections of a Point. +const ( + tagKeyState = iota + tagValueState + fieldsState +) + +// scanMeasurement examines the measurement part of a Point, returning +// the next state to move to, and the current location in the buffer. +func scanMeasurement(buf []byte, i int) (int, int, error) { + // Check first byte of measurement, anything except a comma is fine. + // It can't be a space, since whitespace is stripped prior to this + // function call. + if i >= len(buf) || buf[i] == ',' { + return -1, i, fmt.Errorf("missing measurement") + } + + for { + i++ + if i >= len(buf) { + // cpu + return -1, i, fmt.Errorf("missing fields") + } + + if buf[i-1] == '\\' { + // Skip character (it's escaped). + continue + } + + // Unescaped comma; move onto scanning the tags. + if buf[i] == ',' { + return tagKeyState, i + 1, nil + } + + // Unescaped space; move onto scanning the fields. + if buf[i] == ' ' { + // cpu value=1.0 + return fieldsState, i, nil + } + } +} + +// scanTags examines all the tags in a Point, keeping track of and +// returning the updated indices slice, number of commas and location +// in buf where to start examining the Point fields. +func scanTags(buf []byte, i int, indices []int) (int, int, []int, error) { + var ( + err error + commas int + state = tagKeyState + ) + + for { + switch state { + case tagKeyState: + // Grow our indices slice if we have too many tags. + if commas >= len(indices) { + newIndics := make([]int, cap(indices)*2) + copy(newIndics, indices) + indices = newIndics + } + indices[commas] = i + commas++ + + i, err = scanTagsKey(buf, i) + state = tagValueState // tag value always follows a tag key + case tagValueState: + state, i, err = scanTagsValue(buf, i) + case fieldsState: + indices[commas] = i + 1 + return i, commas, indices, nil + } + + if err != nil { + return i, commas, indices, err + } + } +} + +// scanTagsKey scans each character in a tag key. +func scanTagsKey(buf []byte, i int) (int, error) { + // First character of the key. + if i >= len(buf) || buf[i] == ' ' || buf[i] == ',' || buf[i] == '=' { + // cpu,{'', ' ', ',', '='} + return i, fmt.Errorf("missing tag key") + } + + // Examine each character in the tag key until we hit an unescaped + // equals (the tag value), or we hit an error (i.e., unescaped + // space or comma). + for { + i++ + + // Either we reached the end of the buffer or we hit an + // unescaped comma or space. + if i >= len(buf) || + ((buf[i] == ' ' || buf[i] == ',') && buf[i-1] != '\\') { + // cpu,tag{'', ' ', ','} + return i, fmt.Errorf("missing tag value") + } + + if buf[i] == '=' && buf[i-1] != '\\' { + // cpu,tag= + return i + 1, nil + } + } +} + +// scanTagsValue scans each character in a tag value. +func scanTagsValue(buf []byte, i int) (int, int, error) { + // Tag value cannot be empty. + if i >= len(buf) || buf[i] == ',' || buf[i] == ' ' { + // cpu,tag={',', ' '} + return -1, i, fmt.Errorf("missing tag value") + } + + // Examine each character in the tag value until we hit an unescaped + // comma (move onto next tag key), an unescaped space (move onto + // fields), or we error out. + for { + i++ + if i >= len(buf) { + // cpu,tag=value + return -1, i, fmt.Errorf("missing fields") + } + + // An unescaped equals sign is an invalid tag value. + if buf[i] == '=' && buf[i-1] != '\\' { + // cpu,tag={'=', 'fo=o'} + return -1, i, fmt.Errorf("invalid tag format") + } + + if buf[i] == ',' && buf[i-1] != '\\' { + // cpu,tag=foo, + return tagKeyState, i + 1, nil + } + + // cpu,tag=foo value=1.0 + // cpu, tag=foo\= value=1.0 + if buf[i] == ' ' && buf[i-1] != '\\' { + return fieldsState, i, nil + } + } +} + +func insertionSort(l, r int, buf []byte, indices []int) { + for i := l + 1; i < r; i++ { + for j := i; j > l && less(buf, indices, j, j-1); j-- { + indices[j], indices[j-1] = indices[j-1], indices[j] + } + } +} + +func less(buf []byte, indices []int, i, j int) bool { + // This grabs the tag names for i & j, it ignores the values + _, a := scanTo(buf, indices[i], '=') + _, b := scanTo(buf, indices[j], '=') + return bytes.Compare(a, b) < 0 +} + +// scanFields scans buf, starting at i for the fields section of a point. It returns +// the ending position and the byte slice of the fields within buf. +func scanFields(buf []byte, i int) (int, []byte, error) { + start := skipWhitespace(buf, i) + i = start + quoted := false + + // tracks how many '=' we've seen + equals := 0 + + // tracks how many commas we've seen + commas := 0 + + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // escaped characters? + if buf[i] == '\\' && i+1 < len(buf) { + i += 2 + continue + } + + // If the value is quoted, scan until we get to the end quote + // Only quote values in the field value since quotes are not significant + // in the field key + if buf[i] == '"' && equals > commas { + quoted = !quoted + i++ + continue + } + + // If we see an =, ensure that there is at least on char before and after it + if buf[i] == '=' && !quoted { + equals++ + + // check for "... =123" but allow "a\ =123" + if buf[i-1] == ' ' && buf[i-2] != '\\' { + return i, buf[start:i], fmt.Errorf("missing field key") + } + + // check for "...a=123,=456" but allow "a=123,a\,=456" + if buf[i-1] == ',' && buf[i-2] != '\\' { + return i, buf[start:i], fmt.Errorf("missing field key") + } + + // check for "... value=" + if i+1 >= len(buf) { + return i, buf[start:i], fmt.Errorf("missing field value") + } + + // check for "... value=,value2=..." + if buf[i+1] == ',' || buf[i+1] == ' ' { + return i, buf[start:i], fmt.Errorf("missing field value") + } + + if isNumeric(buf[i+1]) || buf[i+1] == '-' || buf[i+1] == 'N' || buf[i+1] == 'n' { + var err error + i, err = scanNumber(buf, i+1) + if err != nil { + return i, buf[start:i], err + } + continue + } + // If next byte is not a double-quote, the value must be a boolean + if buf[i+1] != '"' { + var err error + i, _, err = scanBoolean(buf, i+1) + if err != nil { + return i, buf[start:i], err + } + continue + } + } + + if buf[i] == ',' && !quoted { + commas++ + } + + // reached end of block? + if buf[i] == ' ' && !quoted { + break + } + i++ + } + + if quoted { + return i, buf[start:i], fmt.Errorf("unbalanced quotes") + } + + // check that all field sections had key and values (e.g. prevent "a=1,b" + if equals == 0 || commas != equals-1 { + return i, buf[start:i], fmt.Errorf("invalid field format") + } + + return i, buf[start:i], nil +} + +// scanTime scans buf, starting at i for the time section of a point. It +// returns the ending position and the byte slice of the timestamp within buf +// and and error if the timestamp is not in the correct numeric format. +func scanTime(buf []byte, i int) (int, []byte, error) { + start := skipWhitespace(buf, i) + i = start + + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // Reached end of block or trailing whitespace? + if buf[i] == '\n' || buf[i] == ' ' { + break + } + + // Handle negative timestamps + if i == start && buf[i] == '-' { + i++ + continue + } + + // Timestamps should be integers, make sure they are so we don't need + // to actually parse the timestamp until needed. + if buf[i] < '0' || buf[i] > '9' { + return i, buf[start:i], fmt.Errorf("bad timestamp") + } + i++ + } + return i, buf[start:i], nil +} + +func isNumeric(b byte) bool { + return (b >= '0' && b <= '9') || b == '.' +} + +// scanNumber returns the end position within buf, start at i after +// scanning over buf for an integer, or float. It returns an +// error if a invalid number is scanned. +func scanNumber(buf []byte, i int) (int, error) { + start := i + var isInt, isUnsigned bool + + // Is negative number? + if i < len(buf) && buf[i] == '-' { + i++ + // There must be more characters now, as just '-' is illegal. + if i == len(buf) { + return i, ErrInvalidNumber + } + } + + // how many decimal points we've see + decimal := false + + // indicates the number is float in scientific notation + scientific := false + + for { + if i >= len(buf) { + break + } + + if buf[i] == ',' || buf[i] == ' ' { + break + } + + if buf[i] == 'i' && i > start && !(isInt || isUnsigned) { + isInt = true + i++ + continue + } else if buf[i] == 'u' && i > start && !(isInt || isUnsigned) { + isUnsigned = true + i++ + continue + } + + if buf[i] == '.' { + // Can't have more than 1 decimal (e.g. 1.1.1 should fail) + if decimal { + return i, ErrInvalidNumber + } + decimal = true + } + + // `e` is valid for floats but not as the first char + if i > start && (buf[i] == 'e' || buf[i] == 'E') { + scientific = true + i++ + continue + } + + // + and - are only valid at this point if they follow an e (scientific notation) + if (buf[i] == '+' || buf[i] == '-') && (buf[i-1] == 'e' || buf[i-1] == 'E') { + i++ + continue + } + + // NaN is an unsupported value + if i+2 < len(buf) && (buf[i] == 'N' || buf[i] == 'n') { + return i, ErrInvalidNumber + } + + if !isNumeric(buf[i]) { + return i, ErrInvalidNumber + } + i++ + } + + if (isInt || isUnsigned) && (decimal || scientific) { + return i, ErrInvalidNumber + } + + numericDigits := i - start + if isInt { + numericDigits-- + } + if decimal { + numericDigits-- + } + if buf[start] == '-' { + numericDigits-- + } + + if numericDigits == 0 { + return i, ErrInvalidNumber + } + + // It's more common that numbers will be within min/max range for their type but we need to prevent + // out or range numbers from being parsed successfully. This uses some simple heuristics to decide + // if we should parse the number to the actual type. It does not do it all the time because it incurs + // extra allocations and we end up converting the type again when writing points to disk. + if isInt { + // Make sure the last char is an 'i' for integers (e.g. 9i10 is not valid) + if buf[i-1] != 'i' { + return i, ErrInvalidNumber + } + // Parse the int to check bounds the number of digits could be larger than the max range + // We subtract 1 from the index to remove the `i` from our tests + if len(buf[start:i-1]) >= maxInt64Digits || len(buf[start:i-1]) >= minInt64Digits { + if _, err := parseIntBytes(buf[start:i-1], 10, 64); err != nil { + return i, fmt.Errorf("unable to parse integer %s: %s", buf[start:i-1], err) + } + } + } else if isUnsigned { + // Return an error if uint64 support has not been enabled. + if !enableUint64Support { + return i, ErrInvalidNumber + } + // Make sure the last char is a 'u' for unsigned + if buf[i-1] != 'u' { + return i, ErrInvalidNumber + } + // Make sure the first char is not a '-' for unsigned + if buf[start] == '-' { + return i, ErrInvalidNumber + } + // Parse the uint to check bounds the number of digits could be larger than the max range + // We subtract 1 from the index to remove the `u` from our tests + if len(buf[start:i-1]) >= maxUint64Digits { + if _, err := parseUintBytes(buf[start:i-1], 10, 64); err != nil { + return i, fmt.Errorf("unable to parse unsigned %s: %s", buf[start:i-1], err) + } + } + } else { + // Parse the float to check bounds if it's scientific or the number of digits could be larger than the max range + if scientific || len(buf[start:i]) >= maxFloat64Digits || len(buf[start:i]) >= minFloat64Digits { + if _, err := parseFloatBytes(buf[start:i], 10); err != nil { + return i, fmt.Errorf("invalid float") + } + } + } + + return i, nil +} + +// scanBoolean returns the end position within buf, start at i after +// scanning over buf for boolean. Valid values for a boolean are +// t, T, true, TRUE, f, F, false, FALSE. It returns an error if a invalid boolean +// is scanned. +func scanBoolean(buf []byte, i int) (int, []byte, error) { + start := i + + if i < len(buf) && (buf[i] != 't' && buf[i] != 'f' && buf[i] != 'T' && buf[i] != 'F') { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + i++ + for { + if i >= len(buf) { + break + } + + if buf[i] == ',' || buf[i] == ' ' { + break + } + i++ + } + + // Single char bool (t, T, f, F) is ok + if i-start == 1 { + return i, buf[start:i], nil + } + + // length must be 4 for true or TRUE + if (buf[start] == 't' || buf[start] == 'T') && i-start != 4 { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + // length must be 5 for false or FALSE + if (buf[start] == 'f' || buf[start] == 'F') && i-start != 5 { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + // Otherwise + valid := false + switch buf[start] { + case 't': + valid = bytes.Equal(buf[start:i], []byte("true")) + case 'f': + valid = bytes.Equal(buf[start:i], []byte("false")) + case 'T': + valid = bytes.Equal(buf[start:i], []byte("TRUE")) || bytes.Equal(buf[start:i], []byte("True")) + case 'F': + valid = bytes.Equal(buf[start:i], []byte("FALSE")) || bytes.Equal(buf[start:i], []byte("False")) + } + + if !valid { + return i, buf[start:i], fmt.Errorf("invalid boolean") + } + + return i, buf[start:i], nil + +} + +// skipWhitespace returns the end position within buf, starting at i after +// scanning over spaces in tags. +func skipWhitespace(buf []byte, i int) int { + for i < len(buf) { + if buf[i] != ' ' && buf[i] != '\t' && buf[i] != 0 { + break + } + i++ + } + return i +} + +// scanLine returns the end position in buf and the next line found within +// buf. +func scanLine(buf []byte, i int) (int, []byte) { + start := i + quoted := false + fields := false + + // tracks how many '=' and commas we've seen + // this duplicates some of the functionality in scanFields + equals := 0 + commas := 0 + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // skip past escaped characters + if buf[i] == '\\' && i+2 < len(buf) { + i += 2 + continue + } + + if buf[i] == ' ' { + fields = true + } + + // If we see a double quote, makes sure it is not escaped + if fields { + if !quoted && buf[i] == '=' { + i++ + equals++ + continue + } else if !quoted && buf[i] == ',' { + i++ + commas++ + continue + } else if buf[i] == '"' && equals > commas { + i++ + quoted = !quoted + continue + } + } + + if buf[i] == '\n' && !quoted { + break + } + + i++ + } + + return i, buf[start:i] +} + +// scanTo returns the end position in buf and the next consecutive block +// of bytes, starting from i and ending with stop byte, where stop byte +// has not been escaped. +// +// If there are leading spaces, they are skipped. +func scanTo(buf []byte, i int, stop byte) (int, []byte) { + start := i + for { + // reached the end of buf? + if i >= len(buf) { + break + } + + // Reached unescaped stop value? + if buf[i] == stop && (i == 0 || buf[i-1] != '\\') { + break + } + i++ + } + + return i, buf[start:i] +} + +// scanTo returns the end position in buf and the next consecutive block +// of bytes, starting from i and ending with stop byte. If there are leading +// spaces, they are skipped. +func scanToSpaceOr(buf []byte, i int, stop byte) (int, []byte) { + start := i + if buf[i] == stop || buf[i] == ' ' { + return i, buf[start:i] + } + + for { + i++ + if buf[i-1] == '\\' { + continue + } + + // reached the end of buf? + if i >= len(buf) { + return i, buf[start:i] + } + + // reached end of block? + if buf[i] == stop || buf[i] == ' ' { + return i, buf[start:i] + } + } +} + +func scanTagValue(buf []byte, i int) (int, []byte) { + start := i + for { + if i >= len(buf) { + break + } + + if buf[i] == ',' && buf[i-1] != '\\' { + break + } + i++ + } + if i > len(buf) { + return i, nil + } + return i, buf[start:i] +} + +func scanFieldValue(buf []byte, i int) (int, []byte) { + start := i + quoted := false + for i < len(buf) { + // Only escape char for a field value is a double-quote and backslash + if buf[i] == '\\' && i+1 < len(buf) && (buf[i+1] == '"' || buf[i+1] == '\\') { + i += 2 + continue + } + + // Quoted value? (e.g. string) + if buf[i] == '"' { + i++ + quoted = !quoted + continue + } + + if buf[i] == ',' && !quoted { + break + } + i++ + } + return i, buf[start:i] +} + +func EscapeMeasurement(in []byte) []byte { + for _, c := range measurementEscapeCodes { + if bytes.IndexByte(in, c.k[0]) != -1 { + in = bytes.Replace(in, c.k[:], c.esc[:], -1) + } + } + return in +} + +func unescapeMeasurement(in []byte) []byte { + if bytes.IndexByte(in, '\\') == -1 { + return in + } + + for i := range measurementEscapeCodes { + c := &measurementEscapeCodes[i] + if bytes.IndexByte(in, c.k[0]) != -1 { + in = bytes.Replace(in, c.esc[:], c.k[:], -1) + } + } + return in +} + +func escapeTag(in []byte) []byte { + for i := range tagEscapeCodes { + c := &tagEscapeCodes[i] + if bytes.IndexByte(in, c.k[0]) != -1 { + in = bytes.Replace(in, c.k[:], c.esc[:], -1) + } + } + return in +} + +func unescapeTag(in []byte) []byte { + if bytes.IndexByte(in, '\\') == -1 { + return in + } + + for i := range tagEscapeCodes { + c := &tagEscapeCodes[i] + if bytes.IndexByte(in, c.k[0]) != -1 { + in = bytes.Replace(in, c.esc[:], c.k[:], -1) + } + } + return in +} + +// escapeStringFieldReplacer replaces double quotes and backslashes +// with the same character preceded by a backslash. +// As of Go 1.7 this benchmarked better in allocations and CPU time +// compared to iterating through a string byte-by-byte and appending to a new byte slice, +// calling strings.Replace twice, and better than (*Regex).ReplaceAllString. +var escapeStringFieldReplacer = strings.NewReplacer(`"`, `\"`, `\`, `\\`) + +// EscapeStringField returns a copy of in with any double quotes or +// backslashes with escaped values. +func EscapeStringField(in string) string { + return escapeStringFieldReplacer.Replace(in) +} + +// unescapeStringField returns a copy of in with any escaped double-quotes +// or backslashes unescaped. +func unescapeStringField(in string) string { + if strings.IndexByte(in, '\\') == -1 { + return in + } + + var out []byte + i := 0 + for { + if i >= len(in) { + break + } + // unescape backslashes + if in[i] == '\\' && i+1 < len(in) && in[i+1] == '\\' { + out = append(out, '\\') + i += 2 + continue + } + // unescape double-quotes + if in[i] == '\\' && i+1 < len(in) && in[i+1] == '"' { + out = append(out, '"') + i += 2 + continue + } + out = append(out, in[i]) + i++ + + } + return string(out) +} + +// NewPoint returns a new point with the given measurement name, tags, fields and timestamp. If +// an unsupported field value (NaN, or +/-Inf) or out of range time is passed, this function +// returns an error. +func NewPoint(name string, tags Tags, fields Fields, t time.Time) (Point, error) { + key, err := pointKey(name, tags, fields, t) + if err != nil { + return nil, err + } + + return &point{ + key: key, + time: t, + fields: fields.MarshalBinary(), + }, nil +} + +// pointKey checks some basic requirements for valid points, and returns the +// key, along with an possible error. +func pointKey(measurement string, tags Tags, fields Fields, t time.Time) ([]byte, error) { + if len(fields) == 0 { + return nil, ErrPointMustHaveAField + } + + if !t.IsZero() { + if err := CheckTime(t); err != nil { + return nil, err + } + } + + for key, value := range fields { + switch value := value.(type) { + case float64: + // Ensure the caller validates and handles invalid field values + if math.IsInf(value, 0) { + return nil, fmt.Errorf("+/-Inf is an unsupported value for field %s", key) + } + if math.IsNaN(value) { + return nil, fmt.Errorf("NaN is an unsupported value for field %s", key) + } + case float32: + // Ensure the caller validates and handles invalid field values + if math.IsInf(float64(value), 0) { + return nil, fmt.Errorf("+/-Inf is an unsupported value for field %s", key) + } + if math.IsNaN(float64(value)) { + return nil, fmt.Errorf("NaN is an unsupported value for field %s", key) + } + } + if len(key) == 0 { + return nil, fmt.Errorf("all fields must have non-empty names") + } + } + + key := MakeKey([]byte(measurement), tags) + for field := range fields { + sz := seriesKeySize(key, []byte(field)) + if sz > MaxKeyLength { + return nil, fmt.Errorf("max key length exceeded: %v > %v", sz, MaxKeyLength) + } + } + + return key, nil +} + +func seriesKeySize(key, field []byte) int { + // 4 is the length of the tsm1.fieldKeySeparator constant. It's inlined here to avoid a circular + // dependency. + return len(key) + 4 + len(field) +} + +// NewPointFromBytes returns a new Point from a marshalled Point. +func NewPointFromBytes(b []byte) (Point, error) { + p := &point{} + if err := p.UnmarshalBinary(b); err != nil { + return nil, err + } + + // This does some basic validation to ensure there are fields and they + // can be unmarshalled as well. + iter := p.FieldIterator() + var hasField bool + for iter.Next() { + if len(iter.FieldKey()) == 0 { + continue + } + hasField = true + switch iter.Type() { + case Float: + _, err := iter.FloatValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + case Integer: + _, err := iter.IntegerValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + case Unsigned: + _, err := iter.UnsignedValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + case String: + // Skip since this won't return an error + case Boolean: + _, err := iter.BooleanValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + } + } + + if !hasField { + return nil, ErrPointMustHaveAField + } + + return p, nil +} + +// MustNewPoint returns a new point with the given measurement name, tags, fields and timestamp. If +// an unsupported field value (NaN) is passed, this function panics. +func MustNewPoint(name string, tags Tags, fields Fields, time time.Time) Point { + pt, err := NewPoint(name, tags, fields, time) + if err != nil { + panic(err.Error()) + } + return pt +} + +// Key returns the key (measurement joined with tags) of the point. +func (p *point) Key() []byte { + return p.key +} + +func (p *point) name() []byte { + _, name := scanTo(p.key, 0, ',') + return name +} + +func (p *point) Name() []byte { + return escape.Unescape(p.name()) +} + +// SetName updates the measurement name for the point. +func (p *point) SetName(name string) { + p.cachedName = "" + p.key = MakeKey([]byte(name), p.Tags()) +} + +// Time return the timestamp for the point. +func (p *point) Time() time.Time { + return p.time +} + +// SetTime updates the timestamp for the point. +func (p *point) SetTime(t time.Time) { + p.time = t +} + +// Round will round the timestamp of the point to the given duration. +func (p *point) Round(d time.Duration) { + p.time = p.time.Round(d) +} + +// Tags returns the tag set for the point. +func (p *point) Tags() Tags { + if p.cachedTags != nil { + return p.cachedTags + } + p.cachedTags = parseTags(p.key, nil) + return p.cachedTags +} + +func (p *point) ForEachTag(fn func(k, v []byte) bool) { + walkTags(p.key, fn) +} + +func (p *point) HasTag(tag []byte) bool { + if len(p.key) == 0 { + return false + } + + var exists bool + walkTags(p.key, func(key, value []byte) bool { + if bytes.Equal(tag, key) { + exists = true + return false + } + return true + }) + + return exists +} + +func walkTags(buf []byte, fn func(key, value []byte) bool) { + if len(buf) == 0 { + return + } + + pos, name := scanTo(buf, 0, ',') + + // it's an empty key, so there are no tags + if len(name) == 0 { + return + } + + hasEscape := bytes.IndexByte(buf, '\\') != -1 + i := pos + 1 + var key, value []byte + for { + if i >= len(buf) { + break + } + i, key = scanTo(buf, i, '=') + i, value = scanTagValue(buf, i+1) + + if len(value) == 0 { + continue + } + + if hasEscape { + if !fn(unescapeTag(key), unescapeTag(value)) { + return + } + } else { + if !fn(key, value) { + return + } + } + + i++ + } +} + +// walkFields walks each field key and value via fn. If fn returns false, the iteration +// is stopped. The values are the raw byte slices and not the converted types. +func walkFields(buf []byte, fn func(key, value []byte) bool) error { + var i int + var key, val []byte + for len(buf) > 0 { + i, key = scanTo(buf, 0, '=') + if i > len(buf)-2 { + return fmt.Errorf("invalid value: field-key=%s", key) + } + buf = buf[i+1:] + i, val = scanFieldValue(buf, 0) + buf = buf[i:] + if !fn(key, val) { + break + } + + // slice off comma + if len(buf) > 0 { + buf = buf[1:] + } + } + return nil +} + +// parseTags parses buf into the provided destination tags, returning destination +// Tags, which may have a different length and capacity. +func parseTags(buf []byte, dst Tags) Tags { + if len(buf) == 0 { + return nil + } + + n := bytes.Count(buf, []byte(",")) + if cap(dst) < n { + dst = make(Tags, n) + } else { + dst = dst[:n] + } + + // Ensure existing behaviour when point has no tags and nil slice passed in. + if dst == nil { + dst = Tags{} + } + + // Series keys can contain escaped commas, therefore the number of commas + // in a series key only gives an estimation of the upper bound on the number + // of tags. + var i int + walkTags(buf, func(key, value []byte) bool { + dst[i].Key, dst[i].Value = key, value + i++ + return true + }) + return dst[:i] +} + +// MakeKey creates a key for a set of tags. +func MakeKey(name []byte, tags Tags) []byte { + return AppendMakeKey(nil, name, tags) +} + +// AppendMakeKey appends the key derived from name and tags to dst and returns the extended buffer. +func AppendMakeKey(dst []byte, name []byte, tags Tags) []byte { + // unescape the name and then re-escape it to avoid double escaping. + // The key should always be stored in escaped form. + dst = append(dst, EscapeMeasurement(unescapeMeasurement(name))...) + dst = tags.AppendHashKey(dst) + return dst +} + +// SetTags replaces the tags for the point. +func (p *point) SetTags(tags Tags) { + p.key = MakeKey(p.Name(), tags) + p.cachedTags = tags +} + +// AddTag adds or replaces a tag value for a point. +func (p *point) AddTag(key, value string) { + tags := p.Tags() + tags = append(tags, Tag{Key: []byte(key), Value: []byte(value)}) + sort.Sort(tags) + p.cachedTags = tags + p.key = MakeKey(p.Name(), tags) +} + +// Fields returns the fields for the point. +func (p *point) Fields() (Fields, error) { + if p.cachedFields != nil { + return p.cachedFields, nil + } + cf, err := p.unmarshalBinary() + if err != nil { + return nil, err + } + p.cachedFields = cf + return p.cachedFields, nil +} + +// SetPrecision will round a time to the specified precision. +func (p *point) SetPrecision(precision string) { + switch precision { + case "n": + case "u": + p.SetTime(p.Time().Truncate(time.Microsecond)) + case "ms": + p.SetTime(p.Time().Truncate(time.Millisecond)) + case "s": + p.SetTime(p.Time().Truncate(time.Second)) + case "m": + p.SetTime(p.Time().Truncate(time.Minute)) + case "h": + p.SetTime(p.Time().Truncate(time.Hour)) + } +} + +// String returns the string representation of the point. +func (p *point) String() string { + if p.Time().IsZero() { + return string(p.Key()) + " " + string(p.fields) + } + return string(p.Key()) + " " + string(p.fields) + " " + strconv.FormatInt(p.UnixNano(), 10) +} + +// AppendString appends the string representation of the point to buf. +func (p *point) AppendString(buf []byte) []byte { + buf = append(buf, p.key...) + buf = append(buf, ' ') + buf = append(buf, p.fields...) + + if !p.time.IsZero() { + buf = append(buf, ' ') + buf = strconv.AppendInt(buf, p.UnixNano(), 10) + } + + return buf +} + +// StringSize returns the length of the string that would be returned by String(). +func (p *point) StringSize() int { + size := len(p.key) + len(p.fields) + 1 + + if !p.time.IsZero() { + digits := 1 // even "0" has one digit + t := p.UnixNano() + if t < 0 { + // account for negative sign, then negate + digits++ + t = -t + } + for t > 9 { // already accounted for one digit + digits++ + t /= 10 + } + size += digits + 1 // digits and a space + } + + return size +} + +// MarshalBinary returns a binary representation of the point. +func (p *point) MarshalBinary() ([]byte, error) { + if len(p.fields) == 0 { + return nil, ErrPointMustHaveAField + } + + tb, err := p.time.MarshalBinary() + if err != nil { + return nil, err + } + + b := make([]byte, 8+len(p.key)+len(p.fields)+len(tb)) + i := 0 + + binary.BigEndian.PutUint32(b[i:], uint32(len(p.key))) + i += 4 + + i += copy(b[i:], p.key) + + binary.BigEndian.PutUint32(b[i:i+4], uint32(len(p.fields))) + i += 4 + + i += copy(b[i:], p.fields) + + copy(b[i:], tb) + return b, nil +} + +// UnmarshalBinary decodes a binary representation of the point into a point struct. +func (p *point) UnmarshalBinary(b []byte) error { + var n int + + // Read key length. + if len(b) < 4 { + return io.ErrShortBuffer + } + n, b = int(binary.BigEndian.Uint32(b[:4])), b[4:] + + // Read key. + if len(b) < n { + return io.ErrShortBuffer + } + p.key, b = b[:n], b[n:] + + // Read fields length. + if len(b) < 4 { + return io.ErrShortBuffer + } + n, b = int(binary.BigEndian.Uint32(b[:4])), b[4:] + + // Read fields. + if len(b) < n { + return io.ErrShortBuffer + } + p.fields, b = b[:n], b[n:] + + // Read timestamp. + return p.time.UnmarshalBinary(b) +} + +// PrecisionString returns a string representation of the point. If there +// is a timestamp associated with the point then it will be specified in the +// given unit. +func (p *point) PrecisionString(precision string) string { + if p.Time().IsZero() { + return fmt.Sprintf("%s %s", p.Key(), string(p.fields)) + } + return fmt.Sprintf("%s %s %d", p.Key(), string(p.fields), + p.UnixNano()/GetPrecisionMultiplier(precision)) +} + +// RoundedString returns a string representation of the point. If there +// is a timestamp associated with the point, then it will be rounded to the +// given duration. +func (p *point) RoundedString(d time.Duration) string { + if p.Time().IsZero() { + return fmt.Sprintf("%s %s", p.Key(), string(p.fields)) + } + return fmt.Sprintf("%s %s %d", p.Key(), string(p.fields), + p.time.Round(d).UnixNano()) +} + +func (p *point) unmarshalBinary() (Fields, error) { + iter := p.FieldIterator() + fields := make(Fields, 8) + for iter.Next() { + if len(iter.FieldKey()) == 0 { + continue + } + switch iter.Type() { + case Float: + v, err := iter.FloatValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + case Integer: + v, err := iter.IntegerValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + case Unsigned: + v, err := iter.UnsignedValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + case String: + fields[string(iter.FieldKey())] = iter.StringValue() + case Boolean: + v, err := iter.BooleanValue() + if err != nil { + return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err) + } + fields[string(iter.FieldKey())] = v + } + } + return fields, nil +} + +// HashID returns a non-cryptographic checksum of the point's key. +func (p *point) HashID() uint64 { + h := NewInlineFNV64a() + h.Write(p.key) + sum := h.Sum64() + return sum +} + +// UnixNano returns the timestamp of the point as nanoseconds since Unix epoch. +func (p *point) UnixNano() int64 { + return p.Time().UnixNano() +} + +// Split will attempt to return multiple points with the same timestamp whose +// string representations are no longer than size. Points with a single field or +// a point without a timestamp may exceed the requested size. +func (p *point) Split(size int) []Point { + if p.time.IsZero() || p.StringSize() <= size { + return []Point{p} + } + + // key string, timestamp string, spaces + size -= len(p.key) + len(strconv.FormatInt(p.time.UnixNano(), 10)) + 2 + + var points []Point + var start, cur int + + for cur < len(p.fields) { + end, _ := scanTo(p.fields, cur, '=') + end, _ = scanFieldValue(p.fields, end+1) + + if cur > start && end-start > size { + points = append(points, &point{ + key: p.key, + time: p.time, + fields: p.fields[start : cur-1], + }) + start = cur + } + + cur = end + 1 + } + + points = append(points, &point{ + key: p.key, + time: p.time, + fields: p.fields[start:], + }) + + return points +} + +// Tag represents a single key/value tag pair. +type Tag struct { + Key []byte + Value []byte +} + +// NewTag returns a new Tag. +func NewTag(key, value []byte) Tag { + return Tag{ + Key: key, + Value: value, + } +} + +// Size returns the size of the key and value. +func (t Tag) Size() int { return len(t.Key) + len(t.Value) } + +// Clone returns a shallow copy of Tag. +// +// Tags associated with a Point created by ParsePointsWithPrecision will hold references to the byte slice that was parsed. +// Use Clone to create a Tag with new byte slices that do not refer to the argument to ParsePointsWithPrecision. +func (t Tag) Clone() Tag { + other := Tag{ + Key: make([]byte, len(t.Key)), + Value: make([]byte, len(t.Value)), + } + + copy(other.Key, t.Key) + copy(other.Value, t.Value) + + return other +} + +// String returns the string reprsentation of the tag. +func (t *Tag) String() string { + var buf bytes.Buffer + buf.WriteByte('{') + buf.WriteString(string(t.Key)) + buf.WriteByte(' ') + buf.WriteString(string(t.Value)) + buf.WriteByte('}') + return buf.String() +} + +// Tags represents a sorted list of tags. +type Tags []Tag + +// NewTags returns a new Tags from a map. +func NewTags(m map[string]string) Tags { + if len(m) == 0 { + return nil + } + a := make(Tags, 0, len(m)) + for k, v := range m { + a = append(a, NewTag([]byte(k), []byte(v))) + } + sort.Sort(a) + return a +} + +// HashKey hashes all of a tag's keys. +func (a Tags) HashKey() []byte { + return a.AppendHashKey(nil) +} + +func (a Tags) needsEscape() bool { + for i := range a { + t := &a[i] + for j := range tagEscapeCodes { + c := &tagEscapeCodes[j] + if bytes.IndexByte(t.Key, c.k[0]) != -1 || bytes.IndexByte(t.Value, c.k[0]) != -1 { + return true + } + } + } + return false +} + +// AppendHashKey appends the result of hashing all of a tag's keys and values to dst and returns the extended buffer. +func (a Tags) AppendHashKey(dst []byte) []byte { + // Empty maps marshal to empty bytes. + if len(a) == 0 { + return dst + } + + // Type invariant: Tags are sorted + + sz := 0 + var escaped Tags + if a.needsEscape() { + var tmp [20]Tag + if len(a) < len(tmp) { + escaped = tmp[:len(a)] + } else { + escaped = make(Tags, len(a)) + } + + for i := range a { + t := &a[i] + nt := &escaped[i] + nt.Key = escapeTag(t.Key) + nt.Value = escapeTag(t.Value) + sz += len(nt.Key) + len(nt.Value) + } + } else { + sz = a.Size() + escaped = a + } + + sz += len(escaped) + (len(escaped) * 2) // separators + + // Generate marshaled bytes. + if cap(dst)-len(dst) < sz { + nd := make([]byte, len(dst), len(dst)+sz) + copy(nd, dst) + dst = nd + } + buf := dst[len(dst) : len(dst)+sz] + idx := 0 + for i := range escaped { + k := &escaped[i] + if len(k.Value) == 0 { + continue + } + buf[idx] = ',' + idx++ + copy(buf[idx:], k.Key) + idx += len(k.Key) + buf[idx] = '=' + idx++ + copy(buf[idx:], k.Value) + idx += len(k.Value) + } + return dst[:len(dst)+idx] +} + +// String returns the string representation of the tags. +func (a Tags) String() string { + var buf bytes.Buffer + buf.WriteByte('[') + for i := range a { + buf.WriteString(a[i].String()) + if i < len(a)-1 { + buf.WriteByte(' ') + } + } + buf.WriteByte(']') + return buf.String() +} + +// Size returns the number of bytes needed to store all tags. Note, this is +// the number of bytes needed to store all keys and values and does not account +// for data structures or delimiters for example. +func (a Tags) Size() int { + var total int + for i := range a { + total += a[i].Size() + } + return total +} + +// Clone returns a copy of the slice where the elements are a result of calling `Clone` on the original elements +// +// Tags associated with a Point created by ParsePointsWithPrecision will hold references to the byte slice that was parsed. +// Use Clone to create Tags with new byte slices that do not refer to the argument to ParsePointsWithPrecision. +func (a Tags) Clone() Tags { + if len(a) == 0 { + return nil + } + + others := make(Tags, len(a)) + for i := range a { + others[i] = a[i].Clone() + } + + return others +} + +func (a Tags) Len() int { return len(a) } +func (a Tags) Less(i, j int) bool { return bytes.Compare(a[i].Key, a[j].Key) == -1 } +func (a Tags) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +// Equal returns true if a equals other. +func (a Tags) Equal(other Tags) bool { + if len(a) != len(other) { + return false + } + for i := range a { + if !bytes.Equal(a[i].Key, other[i].Key) || !bytes.Equal(a[i].Value, other[i].Value) { + return false + } + } + return true +} + +// CompareTags returns -1 if a < b, 1 if a > b, and 0 if a == b. +func CompareTags(a, b Tags) int { + // Compare each key & value until a mismatch. + for i := 0; i < len(a) && i < len(b); i++ { + if cmp := bytes.Compare(a[i].Key, b[i].Key); cmp != 0 { + return cmp + } + if cmp := bytes.Compare(a[i].Value, b[i].Value); cmp != 0 { + return cmp + } + } + + // If all tags are equal up to this point then return shorter tagset. + if len(a) < len(b) { + return -1 + } else if len(a) > len(b) { + return 1 + } + + // All tags are equal. + return 0 +} + +// Get returns the value for a key. +func (a Tags) Get(key []byte) []byte { + // OPTIMIZE: Use sort.Search if tagset is large. + + for _, t := range a { + if bytes.Equal(t.Key, key) { + return t.Value + } + } + return nil +} + +// GetString returns the string value for a string key. +func (a Tags) GetString(key string) string { + return string(a.Get([]byte(key))) +} + +// Set sets the value for a key. +func (a *Tags) Set(key, value []byte) { + for i, t := range *a { + if bytes.Equal(t.Key, key) { + (*a)[i].Value = value + return + } + } + *a = append(*a, Tag{Key: key, Value: value}) + sort.Sort(*a) +} + +// SetString sets the string value for a string key. +func (a *Tags) SetString(key, value string) { + a.Set([]byte(key), []byte(value)) +} + +// Map returns a map representation of the tags. +func (a Tags) Map() map[string]string { + m := make(map[string]string, len(a)) + for _, t := range a { + m[string(t.Key)] = string(t.Value) + } + return m +} + +// CopyTags returns a shallow copy of tags. +func CopyTags(a Tags) Tags { + other := make(Tags, len(a)) + copy(other, a) + return other +} + +// DeepCopyTags returns a deep copy of tags. +func DeepCopyTags(a Tags) Tags { + // Calculate size of keys/values in bytes. + var n int + for _, t := range a { + n += len(t.Key) + len(t.Value) + } + + // Build single allocation for all key/values. + buf := make([]byte, n) + + // Copy tags to new set. + other := make(Tags, len(a)) + for i, t := range a { + copy(buf, t.Key) + other[i].Key, buf = buf[:len(t.Key)], buf[len(t.Key):] + + copy(buf, t.Value) + other[i].Value, buf = buf[:len(t.Value)], buf[len(t.Value):] + } + + return other +} + +// Fields represents a mapping between a Point's field names and their +// values. +type Fields map[string]interface{} + +// FieldIterator retuns a FieldIterator that can be used to traverse the +// fields of a point without constructing the in-memory map. +func (p *point) FieldIterator() FieldIterator { + p.Reset() + return p +} + +type fieldIterator struct { + start, end int + key, keybuf []byte + valueBuf []byte + fieldType FieldType +} + +// Next indicates whether there any fields remaining. +func (p *point) Next() bool { + p.it.start = p.it.end + if p.it.start >= len(p.fields) { + return false + } + + p.it.end, p.it.key = scanTo(p.fields, p.it.start, '=') + if escape.IsEscaped(p.it.key) { + p.it.keybuf = escape.AppendUnescaped(p.it.keybuf[:0], p.it.key) + p.it.key = p.it.keybuf + } + + p.it.end, p.it.valueBuf = scanFieldValue(p.fields, p.it.end+1) + p.it.end++ + + if len(p.it.valueBuf) == 0 { + p.it.fieldType = Empty + return true + } + + c := p.it.valueBuf[0] + + if c == '"' { + p.it.fieldType = String + return true + } + + if strings.IndexByte(`0123456789-.nNiIu`, c) >= 0 { + if p.it.valueBuf[len(p.it.valueBuf)-1] == 'i' { + p.it.fieldType = Integer + p.it.valueBuf = p.it.valueBuf[:len(p.it.valueBuf)-1] + } else if p.it.valueBuf[len(p.it.valueBuf)-1] == 'u' { + p.it.fieldType = Unsigned + p.it.valueBuf = p.it.valueBuf[:len(p.it.valueBuf)-1] + } else { + p.it.fieldType = Float + } + return true + } + + // to keep the same behavior that currently exists, default to boolean + p.it.fieldType = Boolean + return true +} + +// FieldKey returns the key of the current field. +func (p *point) FieldKey() []byte { + return p.it.key +} + +// Type returns the FieldType of the current field. +func (p *point) Type() FieldType { + return p.it.fieldType +} + +// StringValue returns the string value of the current field. +func (p *point) StringValue() string { + return unescapeStringField(string(p.it.valueBuf[1 : len(p.it.valueBuf)-1])) +} + +// IntegerValue returns the integer value of the current field. +func (p *point) IntegerValue() (int64, error) { + n, err := parseIntBytes(p.it.valueBuf, 10, 64) + if err != nil { + return 0, fmt.Errorf("unable to parse integer value %q: %v", p.it.valueBuf, err) + } + return n, nil +} + +// UnsignedValue returns the unsigned value of the current field. +func (p *point) UnsignedValue() (uint64, error) { + n, err := parseUintBytes(p.it.valueBuf, 10, 64) + if err != nil { + return 0, fmt.Errorf("unable to parse unsigned value %q: %v", p.it.valueBuf, err) + } + return n, nil +} + +// BooleanValue returns the boolean value of the current field. +func (p *point) BooleanValue() (bool, error) { + b, err := parseBoolBytes(p.it.valueBuf) + if err != nil { + return false, fmt.Errorf("unable to parse bool value %q: %v", p.it.valueBuf, err) + } + return b, nil +} + +// FloatValue returns the float value of the current field. +func (p *point) FloatValue() (float64, error) { + f, err := parseFloatBytes(p.it.valueBuf, 64) + if err != nil { + return 0, fmt.Errorf("unable to parse floating point value %q: %v", p.it.valueBuf, err) + } + return f, nil +} + +// Reset resets the iterator to its initial state. +func (p *point) Reset() { + p.it.fieldType = Empty + p.it.key = nil + p.it.valueBuf = nil + p.it.start = 0 + p.it.end = 0 +} + +// MarshalBinary encodes all the fields to their proper type and returns the binary +// represenation +// NOTE: uint64 is specifically not supported due to potential overflow when we decode +// again later to an int64 +// NOTE2: uint is accepted, and may be 64 bits, and is for some reason accepted... +func (p Fields) MarshalBinary() []byte { + var b []byte + keys := make([]string, 0, len(p)) + + for k := range p { + keys = append(keys, k) + } + + // Not really necessary, can probably be removed. + sort.Strings(keys) + + for i, k := range keys { + if i > 0 { + b = append(b, ',') + } + b = appendField(b, k, p[k]) + } + + return b +} + +func appendField(b []byte, k string, v interface{}) []byte { + b = append(b, []byte(escape.String(k))...) + b = append(b, '=') + + // check popular types first + switch v := v.(type) { + case float64: + b = strconv.AppendFloat(b, v, 'f', -1, 64) + case int64: + b = strconv.AppendInt(b, v, 10) + b = append(b, 'i') + case string: + b = append(b, '"') + b = append(b, []byte(EscapeStringField(v))...) + b = append(b, '"') + case bool: + b = strconv.AppendBool(b, v) + case int32: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case int16: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case int8: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case int: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint64: + b = strconv.AppendUint(b, v, 10) + b = append(b, 'u') + case uint32: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint16: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint8: + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case uint: + // TODO: 'uint' should be converted to writing as an unsigned integer, + // but we cannot since that would break backwards compatibility. + b = strconv.AppendInt(b, int64(v), 10) + b = append(b, 'i') + case float32: + b = strconv.AppendFloat(b, float64(v), 'f', -1, 32) + case []byte: + b = append(b, v...) + case nil: + // skip + default: + // Can't determine the type, so convert to string + b = append(b, '"') + b = append(b, []byte(EscapeStringField(fmt.Sprintf("%v", v)))...) + b = append(b, '"') + + } + + return b +} + +// ValidKeyToken returns true if the token used for measurement, tag key, or tag +// value is a valid unicode string and only contains printable, non-replacement characters. +func ValidKeyToken(s string) bool { + if !utf8.ValidString(s) { + return false + } + for _, r := range s { + if !unicode.IsPrint(r) || r == unicode.ReplacementChar { + return false + } + } + return true +} + +// ValidKeyTokens returns true if the measurement name and all tags are valid. +func ValidKeyTokens(name string, tags Tags) bool { + if !ValidKeyToken(name) { + return false + } + for _, tag := range tags { + if !ValidKeyToken(string(tag.Key)) || !ValidKeyToken(string(tag.Value)) { + return false + } + } + return true +} diff --git a/influxdb/client/models/rows.go b/influxdb/client/models/rows.go new file mode 100644 index 0000000..1d10292 --- /dev/null +++ b/influxdb/client/models/rows.go @@ -0,0 +1,75 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models + +import ( + "sort" +) + +// Row represents a single row returned from the execution of a statement. +type Row struct { + Name string `json:"name,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Columns []string `json:"columns,omitempty"` + Values [][]interface{} `json:"values,omitempty"` + Partial bool `json:"partial,omitempty"` +} + +// SameSeries returns true if r contains values for the same series as o. +func (r *Row) SameSeries(o *Row) bool { + return r.tagsHash() == o.tagsHash() && r.Name == o.Name +} + +// tagsHash returns a hash of tag key/value pairs. +func (r *Row) tagsHash() uint64 { + h := NewInlineFNV64a() + keys := r.tagsKeys() + for _, k := range keys { + h.Write([]byte(k)) + h.Write([]byte(r.Tags[k])) + } + return h.Sum64() +} + +// tagKeys returns a sorted list of tag keys. +func (r *Row) tagsKeys() []string { + a := make([]string, 0, len(r.Tags)) + for k := range r.Tags { + a = append(a, k) + } + sort.Strings(a) + return a +} + +// Rows represents a collection of rows. Rows implements sort.Interface. +type Rows []*Row + +// Len implements sort.Interface. +func (p Rows) Len() int { return len(p) } + +// Less implements sort.Interface. +func (p Rows) Less(i, j int) bool { + // Sort by name first. + if p[i].Name != p[j].Name { + return p[i].Name < p[j].Name + } + + // Sort by tag set hash. Tags don't have a meaningful sort order so we + // just compute a hash and sort by that instead. This allows the tests + // to receive rows in a predictable order every time. + return p[i].tagsHash() < p[j].tagsHash() +} + +// Swap implements sort.Interface. +func (p Rows) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/influxdb/client/models/statistic.go b/influxdb/client/models/statistic.go new file mode 100644 index 0000000..4c963b1 --- /dev/null +++ b/influxdb/client/models/statistic.go @@ -0,0 +1,55 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models + +// Statistic is the representation of a statistic used by the monitoring service. +type Statistic struct { + Name string `json:"name"` + Tags map[string]string `json:"tags"` + Values map[string]interface{} `json:"values"` +} + +// NewStatistic returns an initialized Statistic. +func NewStatistic(name string) Statistic { + return Statistic{ + Name: name, + Tags: make(map[string]string), + Values: make(map[string]interface{}), + } +} + +// StatisticTags is a map that can be merged with others without causing +// mutations to either map. +type StatisticTags map[string]string + +// Merge creates a new map containing the merged contents of tags and t. +// If both tags and the receiver map contain the same key, the value in tags +// is used in the resulting map. +// +// Merge always returns a usable map. +func (t StatisticTags) Merge(tags map[string]string) map[string]string { + // Add everything in tags to the result. + out := make(map[string]string, len(tags)) + for k, v := range tags { + out[k] = v + } + + // Only add values from t that don't appear in tags. + for k, v := range t { + if _, ok := tags[k]; !ok { + out[k] = v + } + } + return out +} diff --git a/influxdb/client/models/time.go b/influxdb/client/models/time.go new file mode 100644 index 0000000..cb46414 --- /dev/null +++ b/influxdb/client/models/time.go @@ -0,0 +1,87 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models + +// Helper time methods since parsing time can easily overflow and we only support a +// specific time range. + +import ( + "fmt" + "math" + "time" +) + +const ( + // MinNanoTime is the minumum time that can be represented. + // + // 1677-09-21 00:12:43.145224194 +0000 UTC + // + // The two lowest minimum integers are used as sentinel values. The + // minimum value needs to be used as a value lower than any other value for + // comparisons and another separate value is needed to act as a sentinel + // default value that is unusable by the user, but usable internally. + // Because these two values need to be used for a special purpose, we do + // not allow users to write points at these two times. + MinNanoTime = int64(math.MinInt64) + 2 + + // MaxNanoTime is the maximum time that can be represented. + // + // 2262-04-11 23:47:16.854775806 +0000 UTC + // + // The highest time represented by a nanosecond needs to be used for an + // exclusive range in the shard group, so the maximum time needs to be one + // less than the possible maximum number of nanoseconds representable by an + // int64 so that we don't lose a point at that one time. + MaxNanoTime = int64(math.MaxInt64) - 1 +) + +var ( + minNanoTime = time.Unix(0, MinNanoTime).UTC() + maxNanoTime = time.Unix(0, MaxNanoTime).UTC() + + // ErrTimeOutOfRange gets returned when time is out of the representable range using int64 nanoseconds since the epoch. + ErrTimeOutOfRange = fmt.Errorf("time outside range %d - %d", MinNanoTime, MaxNanoTime) +) + +// SafeCalcTime safely calculates the time given. Will return error if the time is outside the +// supported range. +func SafeCalcTime(timestamp int64, precision string) (time.Time, error) { + mult := GetPrecisionMultiplier(precision) + if t, ok := safeSignedMult(timestamp, mult); ok { + tme := time.Unix(0, t).UTC() + return tme, CheckTime(tme) + } + + return time.Time{}, ErrTimeOutOfRange +} + +// CheckTime checks that a time is within the safe range. +func CheckTime(t time.Time) error { + if t.Before(minNanoTime) || t.After(maxNanoTime) { + return ErrTimeOutOfRange + } + return nil +} + +// Perform the multiplication and check to make sure it didn't overflow. +func safeSignedMult(a, b int64) (int64, bool) { + if a == 0 || b == 0 || a == 1 || b == 1 { + return a * b, true + } + if a == MinNanoTime || b == MaxNanoTime { + return 0, false + } + c := a * b + return c, c/b == a +} diff --git a/influxdb/client/models/uint_support.go b/influxdb/client/models/uint_support.go new file mode 100644 index 0000000..dcff490 --- /dev/null +++ b/influxdb/client/models/uint_support.go @@ -0,0 +1,21 @@ +//go:build uint || uint64 +// +build uint uint64 + +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package models + +func init() { + EnableUintSupport() +} diff --git a/influxdb/client/pkg/escape/bytes.go b/influxdb/client/pkg/escape/bytes.go new file mode 100644 index 0000000..5d93c50 --- /dev/null +++ b/influxdb/client/pkg/escape/bytes.go @@ -0,0 +1,127 @@ +// and InfluxDB line protocol. +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package escape // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/pkg/escape" + +import ( + "bytes" + "strings" +) + +// Codes is a map of bytes to be escaped. +var Codes = map[byte][]byte{ + ',': []byte(`\,`), + '"': []byte(`\"`), + ' ': []byte(`\ `), + '=': []byte(`\=`), +} + +// Bytes escapes characters on the input slice, as defined by Codes. +func Bytes(in []byte) []byte { + for b, esc := range Codes { + in = bytes.Replace(in, []byte{b}, esc, -1) + } + return in +} + +const escapeChars = `," =` + +// IsEscaped returns whether b has any escaped characters, +// i.e. whether b seems to have been processed by Bytes. +func IsEscaped(b []byte) bool { + for len(b) > 0 { + i := bytes.IndexByte(b, '\\') + if i < 0 { + return false + } + + if i+1 < len(b) && strings.IndexByte(escapeChars, b[i+1]) >= 0 { + return true + } + b = b[i+1:] + } + return false +} + +// AppendUnescaped appends the unescaped version of src to dst +// and returns the resulting slice. +func AppendUnescaped(dst, src []byte) []byte { + var pos int + for len(src) > 0 { + next := bytes.IndexByte(src[pos:], '\\') + if next < 0 || pos+next+1 >= len(src) { + return append(dst, src...) + } + + if pos+next+1 < len(src) && strings.IndexByte(escapeChars, src[pos+next+1]) >= 0 { + if pos+next > 0 { + dst = append(dst, src[:pos+next]...) + } + src = src[pos+next+1:] + pos = 0 + } else { + pos += next + 1 + } + } + + return dst +} + +// Unescape returns a new slice containing the unescaped version of in. +func Unescape(in []byte) []byte { + if len(in) == 0 { + return nil + } + + if bytes.IndexByte(in, '\\') == -1 { + return in + } + + i := 0 + inLen := len(in) + + // The output size will be no more than inLen. Preallocating the + // capacity of the output is faster and uses less memory than + // letting append() do its own (over)allocation. + out := make([]byte, 0, inLen) + + for { + if i >= inLen { + break + } + if in[i] == '\\' && i+1 < inLen { + switch in[i+1] { + case ',': + out = append(out, ',') + i += 2 + continue + case '"': + out = append(out, '"') + i += 2 + continue + case ' ': + out = append(out, ' ') + i += 2 + continue + case '=': + out = append(out, '=') + i += 2 + continue + } + } + out = append(out, in[i]) + i += 1 + } + return out +} diff --git a/influxdb/client/pkg/escape/strings.go b/influxdb/client/pkg/escape/strings.go new file mode 100644 index 0000000..5de0a01 --- /dev/null +++ b/influxdb/client/pkg/escape/strings.go @@ -0,0 +1,34 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package escape + +import "strings" + +var ( + escaper = strings.NewReplacer(`,`, `\,`, `"`, `\"`, ` `, `\ `, `=`, `\=`) + unescaper = strings.NewReplacer(`\,`, `,`, `\"`, `"`, `\ `, ` `, `\=`, `=`) +) + +// UnescapeString returns unescaped version of in. +func UnescapeString(in string) string { + if strings.IndexByte(in, '\\') == -1 { + return in + } + return unescaper.Replace(in) +} + +// String returns the escaped version of in. +func String(in string) string { + return escaper.Replace(in) +} diff --git a/influxdb/client/v2/client.go b/influxdb/client/v2/client.go new file mode 100644 index 0000000..8a729dc --- /dev/null +++ b/influxdb/client/v2/client.go @@ -0,0 +1,823 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2" + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "mime" + "net" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models" +) + +type ContentEncoding string + +const ( + DefaultEncoding ContentEncoding = "" + GzipEncoding ContentEncoding = "gzip" + DefaultMaxIdleConns = 30 +) + +// HTTPConfig is the config data needed to create an HTTP Client. +type HTTPConfig struct { + // Addr should be of the form "http://host:port" + // or "http://[ipv6-host%zone]:port". + Addr string + + // Username is the influxdb username, optional. + Username string + + // Password is the influxdb password, optional. + Password string + + // UserAgent is the http User Agent, defaults to "InfluxDBClient". + UserAgent string + + // Timeout for influxdb writes, defaults to no timeout. + Timeout time.Duration + + // InsecureSkipVerify gets passed to the http client, if true, it will + // skip https certificate verification. Defaults to false. + InsecureSkipVerify bool + + // TLSConfig allows the user to set their own TLS config for the HTTP + // Client. If set, this option overrides InsecureSkipVerify. + TLSConfig *tls.Config + + // Proxy configures the Proxy function on the HTTP client. + Proxy func(req *http.Request) (*url.URL, error) + + // WriteEncoding specifies the encoding of write request + WriteEncoding ContentEncoding + + // MaxIdleConns controls the maximum number of idle (keep-alive) + MaxIdleConns int + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + MaxIdleConnsPerHost int + + // IdleConnTimeout is the maximum amount of time an idle + IdleConnTimeout time.Duration +} + +// BatchPointsConfig is the config data needed to create an instance of the BatchPoints struct. +type BatchPointsConfig struct { + // Precision is the write precision of the points, defaults to "ns". + Precision string + + // Database is the database to write points to. + Database string + + // RetentionPolicy is the retention policy of the points. + RetentionPolicy string + + // Write consistency is the number of servers required to confirm write. + WriteConsistency string +} + +// Client is a client interface for writing & querying the database. +type Client interface { + // Ping checks that status of cluster, and will always return 0 time and no + // error for UDP clients. + Ping(timeout time.Duration) (time.Duration, string, error) + + // Write takes a BatchPoints object and writes all Points to InfluxDB. + Write(bp BatchPoints) error + + // Query makes an InfluxDB Query on the database. This will fail if using + // the UDP client. + Query(q Query) (*Response, error) + + // QueryAsChunk makes an InfluxDB Query on the database. This will fail if using + // the UDP client. + QueryAsChunk(q Query) (*ChunkedResponse, error) + + // Close releases any resources a Client may be using. + Close() error +} + +// NewHTTPClient returns a new Client from the provided config. +// Client is safe for concurrent use by multiple goroutines. +func NewHTTPClient(conf HTTPConfig) (Client, error) { + if conf.UserAgent == "" { + conf.UserAgent = "InfluxDBClient" + } + + u, err := url.Parse(conf.Addr) + if err != nil { + return nil, err + } else if u.Scheme != "http" && u.Scheme != "https" { + m := fmt.Sprintf("Unsupported protocol scheme: %s, your address"+ + " must start with http:// or https://", u.Scheme) + return nil, errors.New(m) + } + + switch conf.WriteEncoding { + case DefaultEncoding, GzipEncoding: + default: + return nil, fmt.Errorf("unsupported encoding %s", conf.WriteEncoding) + } + + if conf.MaxIdleConns == 0 { + conf.MaxIdleConns = DefaultMaxIdleConns + } + if conf.MaxIdleConnsPerHost == 0 { + conf.MaxIdleConnsPerHost = conf.MaxIdleConns + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: conf.InsecureSkipVerify, + }, + Proxy: conf.Proxy, + MaxIdleConns: conf.MaxIdleConns, + MaxIdleConnsPerHost: conf.MaxIdleConnsPerHost, + IdleConnTimeout: conf.IdleConnTimeout, + DialContext: (&net.Dialer{ + KeepAlive: time.Second * 60, + }).DialContext, + } + if conf.TLSConfig != nil { + tr.TLSClientConfig = conf.TLSConfig + } + return &client{ + url: *u, + username: conf.Username, + password: conf.Password, + useragent: conf.UserAgent, + httpClient: &http.Client{ + Timeout: conf.Timeout, + Transport: tr, + }, + transport: tr, + encoding: conf.WriteEncoding, + }, nil +} + +// Ping will check to see if the server is up with an optional timeout on waiting for leader. +// Ping returns how long the request took, the version of the server it connected to, and an error if one occurred. +func (c *client) Ping(timeout time.Duration) (time.Duration, string, error) { + now := time.Now() + + u := c.url + u.Path = path.Join(u.Path, "ping") + + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return 0, "", err + } + + req.Header.Set("User-Agent", c.useragent) + + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + if timeout > 0 { + params := req.URL.Query() + params.Set("wait_for_leader", fmt.Sprintf("%.0fs", timeout.Seconds())) + req.URL.RawQuery = params.Encode() + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return 0, "", err + } + + if resp.StatusCode != http.StatusNoContent { + var err = errors.New(string(body)) + return 0, "", err + } + + version := resp.Header.Get("X-Influxdb-Version") + return time.Since(now), version, nil +} + +// Close releases the client's resources. +func (c *client) Close() error { + c.transport.CloseIdleConnections() + return nil +} + +// client is safe for concurrent use as the fields are all read-only +// once the client is instantiated. +type client struct { + // N.B - if url.UserInfo is accessed in future modifications to the + // methods on client, you will need to synchronize access to url. + url url.URL + username string + password string + useragent string + httpClient *http.Client + transport *http.Transport + encoding ContentEncoding +} + +// BatchPoints is an interface into a batched grouping of points to write into +// InfluxDB together. BatchPoints is NOT thread-safe, you must create a separate +// batch for each goroutine. +type BatchPoints interface { + // AddPoint adds the given point to the Batch of points. + AddPoint(p *Point) + // AddPoints adds the given points to the Batch of points. + AddPoints(ps []*Point) + // Points lists the points in the Batch. + Points() []*Point + // ClearPoints clear all points + ClearPoints() + + //ClearPoints get the number of Point + GetPointsNum() int + + // Precision returns the currently set precision of this Batch. + Precision() string + // SetPrecision sets the precision of this batch. + SetPrecision(s string) error + + // Database returns the currently set database of this Batch. + Database() string + // SetDatabase sets the database of this Batch. + SetDatabase(s string) + + // WriteConsistency returns the currently set write consistency of this Batch. + WriteConsistency() string + // SetWriteConsistency sets the write consistency of this Batch. + SetWriteConsistency(s string) + + // RetentionPolicy returns the currently set retention policy of this Batch. + RetentionPolicy() string + // SetRetentionPolicy sets the retention policy of this Batch. + SetRetentionPolicy(s string) +} + +// NewBatchPoints returns a BatchPoints interface based on the given config. +func NewBatchPoints(conf BatchPointsConfig) (BatchPoints, error) { + if conf.Precision == "" { + conf.Precision = "ns" + } + if _, err := time.ParseDuration("1" + conf.Precision); err != nil { + return nil, err + } + bp := &batchpoints{ + database: conf.Database, + precision: conf.Precision, + retentionPolicy: conf.RetentionPolicy, + writeConsistency: conf.WriteConsistency, + } + return bp, nil +} + +type batchpoints struct { + points []*Point + database string + precision string + retentionPolicy string + writeConsistency string +} + +func (bp *batchpoints) AddPoint(p *Point) { + bp.points = append(bp.points, p) +} + +func (bp *batchpoints) AddPoints(ps []*Point) { + bp.points = append(bp.points, ps...) +} + +func (bp *batchpoints) Points() []*Point { + return bp.points +} + +func (bp *batchpoints) ClearPoints() { + bp.points = bp.points[0:0] +} + +func (bp *batchpoints) GetPointsNum() int { + return len(bp.points) +} + +func (bp *batchpoints) Precision() string { + return bp.precision +} + +func (bp *batchpoints) Database() string { + return bp.database +} + +func (bp *batchpoints) WriteConsistency() string { + return bp.writeConsistency +} + +func (bp *batchpoints) RetentionPolicy() string { + return bp.retentionPolicy +} + +func (bp *batchpoints) SetPrecision(p string) error { + if _, err := time.ParseDuration("1" + p); err != nil { + return err + } + bp.precision = p + return nil +} + +func (bp *batchpoints) SetDatabase(db string) { + bp.database = db +} + +func (bp *batchpoints) SetWriteConsistency(wc string) { + bp.writeConsistency = wc +} + +func (bp *batchpoints) SetRetentionPolicy(rp string) { + bp.retentionPolicy = rp +} + +// Point represents a single data point. +type Point struct { + pt models.Point +} + +// NewPoint returns a point with the given timestamp. If a timestamp is not +// given, then data is sent to the database without a timestamp, in which case +// the server will assign local time upon reception. NOTE: it is recommended to +// send data with a timestamp. +func NewPoint( + name string, + tags map[string]string, + fields map[string]interface{}, + t ...time.Time, +) (*Point, error) { + var T time.Time + if len(t) > 0 { + T = t[0] + } + + pt, err := models.NewPoint(name, models.NewTags(tags), fields, T) + if err != nil { + return nil, err + } + return &Point{ + pt: pt, + }, nil +} + +// String returns a line-protocol string of the Point. +func (p *Point) String() string { + return p.pt.String() +} + +// PrecisionString returns a line-protocol string of the Point, +// with the timestamp formatted for the given precision. +func (p *Point) PrecisionString(precision string) string { + return p.pt.PrecisionString(precision) +} + +// Name returns the measurement name of the point. +func (p *Point) Name() string { + return string(p.pt.Name()) +} + +// Tags returns the tags associated with the point. +func (p *Point) Tags() map[string]string { + return p.pt.Tags().Map() +} + +// Time return the timestamp for the point. +func (p *Point) Time() time.Time { + return p.pt.Time() +} + +// UnixNano returns timestamp of the point in nanoseconds since Unix epoch. +func (p *Point) UnixNano() int64 { + return p.pt.UnixNano() +} + +// Fields returns the fields for the point. +func (p *Point) Fields() (map[string]interface{}, error) { + return p.pt.Fields() +} + +// NewPointFrom returns a point from the provided models.Point. +func NewPointFrom(pt models.Point) *Point { + return &Point{pt: pt} +} + +func (c *client) Write(bp BatchPoints) error { + var b bytes.Buffer + + var w io.Writer + if c.encoding == GzipEncoding { + w = gzip.NewWriter(&b) + } else { + w = &b + } + + for _, p := range bp.Points() { + if p == nil { + continue + } + if _, err := io.WriteString(w, p.pt.PrecisionString(bp.Precision())); err != nil && err != io.EOF { + return err + } + + if _, err := w.Write([]byte{'\n'}); err != nil && err != io.EOF { + return err + } + } + + // gzip writer should be closed to flush data into underlying buffer + if c, ok := w.(io.Closer); ok { + if err := c.Close(); err != nil && err != io.EOF { + return err + } + } + + u := c.url + u.Path = path.Join(u.Path, "write") + + req, err := http.NewRequest("POST", u.String(), &b) + if err == io.EOF { + err = nil + } + if err != nil { + return err + } + if c.encoding != DefaultEncoding { + req.Header.Set("Content-Encoding", string(c.encoding)) + } + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.useragent) + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + params := req.URL.Query() + params.Set("db", bp.Database()) + params.Set("rp", bp.RetentionPolicy()) + params.Set("precision", bp.Precision()) + params.Set("consistency", bp.WriteConsistency()) + req.URL.RawQuery = params.Encode() + + resp, err := c.httpClient.Do(req) + if err == io.EOF { + err = nil + } + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err == io.EOF { + err = nil + } + if err != nil { + return err + } + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + var err = errors.New(string(body)) + if err == io.EOF { + err = nil + } + return err + } + + return nil +} + +// Query defines a query to send to the server. +type Query struct { + Command string + Database string + RetentionPolicy string + Precision string + Chunked bool + ChunkSize int + Parameters map[string]interface{} +} + +// Params is a type alias to the query parameters. +type Params map[string]interface{} + +// NewQuery returns a query object. +// The database and precision arguments can be empty strings if they are not needed for the query. +func NewQuery(command, database, precision string) Query { + return Query{ + Command: command, + Database: database, + Precision: precision, + Parameters: make(map[string]interface{}), + } +} + +// NewQueryWithRP returns a query object. +// The database, retention policy, and precision arguments can be empty strings if they are not needed +// for the query. Setting the retention policy only works on InfluxDB versions 1.6 or greater. +func NewQueryWithRP(command, database, retentionPolicy, precision string) Query { + return Query{ + Command: command, + Database: database, + RetentionPolicy: retentionPolicy, + Precision: precision, + Parameters: make(map[string]interface{}), + } +} + +// NewQueryWithParameters returns a query object. +// The database and precision arguments can be empty strings if they are not needed for the query. +// parameters is a map of the parameter names used in the command to their values. +func NewQueryWithParameters(command, database, precision string, parameters map[string]interface{}) Query { + return Query{ + Command: command, + Database: database, + Precision: precision, + Parameters: parameters, + } +} + +// Response represents a list of statement results. +type Response struct { + Results []Result + Err string `json:"error,omitempty"` +} + +// Error returns the first error from any statement. +// It returns nil if no errors occurred on any statements. +func (r *Response) Error() error { + if r.Err != "" { + return errors.New(r.Err) + } + for _, result := range r.Results { + if result.Err != "" { + return errors.New(result.Err) + } + } + return nil +} + +// Message represents a user message. +type Message struct { + Level string + Text string +} + +// Result represents a resultset returned from a single statement. +type Result struct { + StatementId int `json:"statement_id"` + Series []models.Row + Messages []*Message + Err string `json:"error,omitempty"` +} + +// Query sends a command to the server and returns the Response. +func (c *client) Query(q Query) (*Response, error) { + req, err := c.createDefaultRequest(q) + if err != nil { + return nil, err + } + params := req.URL.Query() + if q.Chunked { + params.Set("chunked", "true") + if q.ChunkSize > 0 { + params.Set("chunk_size", strconv.Itoa(q.ChunkSize)) + } + req.URL.RawQuery = params.Encode() + } + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err := checkResponse(resp); err != nil { + return nil, err + } + + var response Response + if q.Chunked { + cr := NewChunkedResponse(resp.Body) + for { + r, err := cr.NextResponse() + if err != nil { + if err == io.EOF { + break + } + // If we got an error while decoding the response, send that back. + return nil, err + } + + if r == nil { + break + } + + response.Results = append(response.Results, r.Results...) + if r.Err != "" { + response.Err = r.Err + break + } + } + } else { + dec := json.NewDecoder(resp.Body) + dec.UseNumber() + decErr := dec.Decode(&response) + + // ignore this error if we got an invalid status code + if decErr != nil && decErr.Error() == "EOF" && resp.StatusCode != http.StatusOK { + decErr = nil + } + // If we got a valid decode error, send that back + if decErr != nil { + return nil, fmt.Errorf("unable to decode json: received status code %d err: %s", resp.StatusCode, decErr) + } + } + + // If we don't have an error in our json response, and didn't get statusOK + // then send back an error + if resp.StatusCode != http.StatusOK && response.Error() == nil { + return &response, fmt.Errorf("received status code %d from server", resp.StatusCode) + } + return &response, nil +} + +// QueryAsChunk sends a command to the server and returns the Response. +func (c *client) QueryAsChunk(q Query) (*ChunkedResponse, error) { + req, err := c.createDefaultRequest(q) + if err != nil { + return nil, err + } + params := req.URL.Query() + params.Set("chunked", "true") + if q.ChunkSize > 0 { + params.Set("chunk_size", strconv.Itoa(q.ChunkSize)) + } + req.URL.RawQuery = params.Encode() + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + if err := checkResponse(resp); err != nil { + return nil, err + } + return NewChunkedResponse(resp.Body), nil +} + +func checkResponse(resp *http.Response) error { + // If we lack a X-Influxdb-Version header, then we didn't get a response from influxdb + // but instead some other service. If the error code is also a 500+ code, then some + // downstream loadbalancer/proxy/etc had an issue and we should report that. + if resp.Header.Get("X-Influxdb-Version") == "" && resp.StatusCode >= http.StatusInternalServerError { + body, err := ioutil.ReadAll(resp.Body) + if err != nil || len(body) == 0 { + return fmt.Errorf("received status code %d from downstream server", resp.StatusCode) + } + + return fmt.Errorf("received status code %d from downstream server, with response body: %q", resp.StatusCode, body) + } + + // If we get an unexpected content type, then it is also not from influx direct and therefore + // we want to know what we received and what status code was returned for debugging purposes. + if cType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")); cType != "application/json" { + // Read up to 1kb of the body to help identify downstream errors and limit the impact of things + // like downstream serving a large file + body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil || len(body) == 0 { + return fmt.Errorf("expected json response, got empty body, with status: %v", resp.StatusCode) + } + + return fmt.Errorf("expected json response, got %q, with status: %v and response body: %q", cType, resp.StatusCode, body) + } + return nil +} + +func (c *client) createDefaultRequest(q Query) (*http.Request, error) { + u := c.url + u.Path = path.Join(u.Path, "query") + + jsonParameters, err := json.Marshal(q.Parameters) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", u.String(), nil) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "") + req.Header.Set("User-Agent", c.useragent) + + if c.username != "" { + req.SetBasicAuth(c.username, c.password) + } + + params := req.URL.Query() + params.Set("q", q.Command) + params.Set("db", q.Database) + if q.RetentionPolicy != "" { + params.Set("rp", q.RetentionPolicy) + } + params.Set("params", string(jsonParameters)) + + if q.Precision != "" { + params.Set("epoch", q.Precision) + } + req.URL.RawQuery = params.Encode() + + return req, nil + +} + +// duplexReader reads responses and writes it to another writer while +// satisfying the reader interface. +type duplexReader struct { + r io.ReadCloser + w io.Writer +} + +func (r *duplexReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil { + r.w.Write(p[:n]) + } + return n, err +} + +// Close closes the response. +func (r *duplexReader) Close() error { + return r.r.Close() +} + +// ChunkedResponse represents a response from the server that +// uses chunking to stream the output. +type ChunkedResponse struct { + dec *json.Decoder + duplex *duplexReader + buf bytes.Buffer +} + +// NewChunkedResponse reads a stream and produces responses from the stream. +func NewChunkedResponse(r io.Reader) *ChunkedResponse { + rc, ok := r.(io.ReadCloser) + if !ok { + rc = ioutil.NopCloser(r) + } + resp := &ChunkedResponse{} + resp.duplex = &duplexReader{r: rc, w: &resp.buf} + resp.dec = json.NewDecoder(resp.duplex) + resp.dec.UseNumber() + return resp +} + +// NextResponse reads the next line of the stream and returns a response. +func (r *ChunkedResponse) NextResponse() (*Response, error) { + var response Response + if err := r.dec.Decode(&response); err != nil { + if err == io.EOF { + return nil, err + } + // A decoding error happened. This probably means the server crashed + // and sent a last-ditch error message to us. Ensure we have read the + // entirety of the connection to get any remaining error text. + io.Copy(ioutil.Discard, r.duplex) + return nil, errors.New(strings.TrimSpace(r.buf.String())) + } + + r.buf.Reset() + return &response, nil +} + +// Close closes the response. +func (r *ChunkedResponse) Close() error { + return r.duplex.Close() +} diff --git a/influxdb/client/v2/params.go b/influxdb/client/v2/params.go new file mode 100644 index 0000000..238ef66 --- /dev/null +++ b/influxdb/client/v2/params.go @@ -0,0 +1,86 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "encoding/json" + "time" +) + +type ( + // Identifier is an identifier value. + Identifier string + + // StringValue is a string literal. + StringValue string + + // RegexValue is a regexp literal. + RegexValue string + + // NumberValue is a number literal. + NumberValue float64 + + // IntegerValue is an integer literal. + IntegerValue int64 + + // BooleanValue is a boolean literal. + BooleanValue bool + + // TimeValue is a time literal. + TimeValue time.Time + + // DurationValue is a duration literal. + DurationValue time.Duration +) + +func (v Identifier) MarshalJSON() ([]byte, error) { + m := map[string]string{"identifier": string(v)} + return json.Marshal(m) +} + +func (v StringValue) MarshalJSON() ([]byte, error) { + m := map[string]string{"string": string(v)} + return json.Marshal(m) +} + +func (v RegexValue) MarshalJSON() ([]byte, error) { + m := map[string]string{"regex": string(v)} + return json.Marshal(m) +} + +func (v NumberValue) MarshalJSON() ([]byte, error) { + m := map[string]float64{"number": float64(v)} + return json.Marshal(m) +} + +func (v IntegerValue) MarshalJSON() ([]byte, error) { + m := map[string]int64{"integer": int64(v)} + return json.Marshal(m) +} + +func (v BooleanValue) MarshalJSON() ([]byte, error) { + m := map[string]bool{"boolean": bool(v)} + return json.Marshal(m) +} + +func (v TimeValue) MarshalJSON() ([]byte, error) { + t := time.Time(v) + m := map[string]string{"string": t.Format(time.RFC3339Nano)} + return json.Marshal(m) +} + +func (v DurationValue) MarshalJSON() ([]byte, error) { + m := map[string]int64{"duration": int64(v)} + return json.Marshal(m) +} diff --git a/influxdb/client/v2/udp.go b/influxdb/client/v2/udp.go new file mode 100644 index 0000000..c548a0e --- /dev/null +++ b/influxdb/client/v2/udp.go @@ -0,0 +1,129 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "fmt" + "io" + "net" + "time" +) + +const ( + // UDPPayloadSize is a reasonable default payload size for UDP packets that + // could be travelling over the internet. + UDPPayloadSize = 512 +) + +// UDPConfig is the config data needed to create a UDP Client. +type UDPConfig struct { + // Addr should be of the form "host:port" + // or "[ipv6-host%zone]:port". + Addr string + + // PayloadSize is the maximum size of a UDP client message, optional + // Tune this based on your network. Defaults to UDPPayloadSize. + PayloadSize int +} + +// NewUDPClient returns a client interface for writing to an InfluxDB UDP +// service from the given config. +func NewUDPClient(conf UDPConfig) (Client, error) { + var udpAddr *net.UDPAddr + udpAddr, err := net.ResolveUDPAddr("udp", conf.Addr) + if err != nil { + return nil, err + } + + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, err + } + + payloadSize := conf.PayloadSize + if payloadSize == 0 { + payloadSize = UDPPayloadSize + } + + return &udpclient{ + conn: conn, + payloadSize: payloadSize, + }, nil +} + +// Close releases the udpclient's resources. +func (uc *udpclient) Close() error { + return uc.conn.Close() +} + +type udpclient struct { + conn io.WriteCloser + payloadSize int +} + +func (uc *udpclient) Write(bp BatchPoints) error { + var b = make([]byte, 0, uc.payloadSize) // initial buffer size, it will grow as needed + var d, _ = time.ParseDuration("1" + bp.Precision()) + + var delayedError error + + var checkBuffer = func(n int) { + if len(b) > 0 && len(b)+n > uc.payloadSize { + if _, err := uc.conn.Write(b); err != nil { + delayedError = err + } + b = b[:0] + } + } + + for _, p := range bp.Points() { + p.pt.Round(d) + pointSize := p.pt.StringSize() + 1 // include newline in size + //point := p.pt.RoundedString(d) + "\n" + + checkBuffer(pointSize) + + if p.Time().IsZero() || pointSize <= uc.payloadSize { + b = p.pt.AppendString(b) + b = append(b, '\n') + continue + } + + points := p.pt.Split(uc.payloadSize - 1) // account for newline character + for _, sp := range points { + checkBuffer(sp.StringSize() + 1) + b = sp.AppendString(b) + b = append(b, '\n') + } + } + + if len(b) > 0 { + if _, err := uc.conn.Write(b); err != nil { + return err + } + } + return delayedError +} + +func (uc *udpclient) Query(q Query) (*Response, error) { + return nil, fmt.Errorf("Querying via UDP is not supported") +} + +func (uc *udpclient) QueryAsChunk(q Query) (*ChunkedResponse, error) { + return nil, fmt.Errorf("Querying via UDP is not supported") +} + +func (uc *udpclient) Ping(timeout time.Duration) (time.Duration, string, error) { + return 0, "", nil +} diff --git a/influxdb/config.go b/influxdb/config.go new file mode 100644 index 0000000..ea763ad --- /dev/null +++ b/influxdb/config.go @@ -0,0 +1,48 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package influxdb + +// Config configuration +type Config struct { + Enable bool `yaml:"enable"` //Service switch + Address string `yaml:"address"` + Port int `yaml:"port"` + UDPAddress string `yaml:"udp_address"` //influxdb UDP address of the database,ip:port + Database string `yaml:"database"` //Database name + Precision string `yaml:"precision"` //Accuracy n, u, ms, s, m or h + UserName string `yaml:"username"` + Password string `yaml:"password"` + MaxIdleConns int `yaml:"max-idle-conns"` + MaxIdleConnsPerHost int `yaml:"max-idle-conns-per-host"` + IdleConnTimeout int `yaml:"idle-conn-timeout"` +} + +// CustomConfig Custom configuration +type CustomConfig struct { + Enabled bool `yaml:"enabled"` //Service switch + Address string `yaml:"address"` + Port int `yaml:"port"` + UDPAddress string `yaml:"udp_address"` //influxdb UDP address of the database,ip:port + Database string `yaml:"database"` //Database name + Precision string `yaml:"precision"` //Accuracy n, u, ms, s, m or h + UserName string `yaml:"username"` + Password string `yaml:"password"` + ReadUserName string `yaml:"read-username"` + ReadPassword string `yaml:"read-password"` + MaxIdleConns int `yaml:"max-idle-conns"` + MaxIdleConnsPerHost int `yaml:"max-idle-conns-per-host"` + IdleConnTimeout int `yaml:"idle-conn-timeout"` + FlushSize int `yaml:"flush-size"` + FlushTime int `yaml:"flush-time"` +} diff --git a/influxdb/metrics.go b/influxdb/metrics.go new file mode 100644 index 0000000..c5b4152 --- /dev/null +++ b/influxdb/metrics.go @@ -0,0 +1,141 @@ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package influxdb + +import ( + "errors" + "fmt" + client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2" + "github.com/ztdbp/ZACA/pkg/logger" + "io" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Metrics ... +type Metrics struct { + mu sync.Mutex + conf *CustomConfig + batchPoints client.BatchPoints + point chan *client.Point + flushTimer *time.Ticker + InfluxDBHttpClient *HTTPClient + counter uint64 +} + +// MetricsData ... +type MetricsData struct { + Measurement string `json:"measurement"` + Fields map[string]interface{} `json:"fields"` + Tags map[string]string `json:"tags"` +} + +// Response ... +type Response struct { + State int `json:"state"` + Data struct{} `json:"data"` + Msg string `json:"msg"` +} + +// NewMetrics ... +func NewMetrics(influxDBHttpClient *HTTPClient, conf *CustomConfig) (metrics *Metrics) { + bp, err := client.NewBatchPoints(influxDBHttpClient.BatchPointsConfig) + if err != nil { + logger.Named("metrics").Errorf("custom-influxdb client.NewBatchPoints err: %v", err) + return + } + metrics = &Metrics{ + conf: conf, + batchPoints: bp, + point: make(chan *client.Point, 16), + flushTimer: time.NewTicker(time.Duration(conf.FlushTime) * time.Second), + InfluxDBHttpClient: influxDBHttpClient, + } + go metrics.worker() + return +} + +func (mt *Metrics) AddPoint(metricsData *MetricsData) { + if mt == nil { + return + } + //atomic.AddUint64(&mt.counter, 1) + pt, err := client.NewPoint(metricsData.Measurement, metricsData.Tags, metricsData.Fields, time.Now()) + if err != nil { + logger.Named("metrics").Errorf("custom-influxdb client.NewPoint err: %s", err) + return + } + mt.point <- pt +} + +func (mt *Metrics) worker() { + for { + select { + case p, ok := <-mt.point: + if !ok { + mt.flush() + return + } + mt.batchPoints.AddPoint(p) + // When the number of points reaches 50, send data + if mt.batchPoints.GetPointsNum() >= mt.conf.FlushSize { + mt.flush() + } + case <-mt.flushTimer.C: + mt.flush() + } + } +} + +func (mt *Metrics) flush() { + mt.mu.Lock() + defer mt.mu.Unlock() + if mt.batchPoints.GetPointsNum() == 0 { + return + } + err := mt.Write() + if err != nil { + if strings.Contains(err.Error(), io.EOF.Error()) { + err = nil + } else { + logger.Named("metric").Errorf("custom-influxdb client.Write err: %s", err) + } + } + defer mt.InfluxDBHttpClient.FluxDBHttpClose() + // Clear all points + mt.batchPoints.ClearPoints() +} + +// Write data timeout processing +func (mt *Metrics) Write() error { + ch := make(chan error, 1) + go func() { + ch <- mt.InfluxDBHttpClient.FluxDBHttpWrite(mt.batchPoints) + }() + select { + case err := <-ch: + return err + case <-time.After(800 * time.Millisecond): + return errors.New("write timeout") + } +} + +func (mt *Metrics) count() { + for { + time.Sleep(time.Second) + fmt.Println("Counter:", atomic.LoadUint64(&mt.counter)) + } +} diff --git a/logger/hook/hook.go b/logger/hook/hook.go new file mode 100644 index 0000000..087e598 --- /dev/null +++ b/logger/hook/hook.go @@ -0,0 +1,165 @@ +// Copyright 2022-present The ZTDBP Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hook + +import ( + "fmt" + "os" + + "github.com/LyricTian/queue" + "github.com/sirupsen/logrus" +) + +var defaultOptions = options{ + maxQueues: 512, + maxWorkers: 1, + levels: []logrus.Level{ + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, + logrus.WarnLevel, + logrus.InfoLevel, + logrus.DebugLevel, + logrus.TraceLevel, + }, +} + +// ExecCloser write the logrus entry to the store and close the store +type ExecCloser interface { + Exec(entry *logrus.Entry) error + Close() error +} + +// FilterHandle a filter handler +type FilterHandle func(*logrus.Entry) *logrus.Entry + +type options struct { + maxQueues int + maxWorkers int + extra map[string]interface{} + filter FilterHandle + levels []logrus.Level +} + +// SetMaxQueues set the number of buffers +func SetMaxQueues(maxQueues int) Option { + return func(o *options) { + o.maxQueues = maxQueues + } +} + +// SetMaxWorkers set the number of worker threads +func SetMaxWorkers(maxWorkers int) Option { + return func(o *options) { + o.maxWorkers = maxWorkers + } +} + +// SetExtra set extended parameters +func SetExtra(extra map[string]interface{}) Option { + return func(o *options) { + o.extra = extra + } +} + +// SetFilter set the entry filter +func SetFilter(filter FilterHandle) Option { + return func(o *options) { + o.filter = filter + } +} + +// SetLevels set the available log level +func SetLevels(levels ...logrus.Level) Option { + return func(o *options) { + if len(levels) == 0 { + return + } + o.levels = levels + } +} + +// Option a hook parameter options +type Option func(*options) + +// New creates a hook to be added to an instance of logger +func New(exec ExecCloser, opt ...Option) *Hook { + opts := defaultOptions + for _, o := range opt { + o(&opts) + } + + q := queue.NewQueue(opts.maxQueues, opts.maxWorkers) + q.Run() + + return &Hook{ + opts: opts, + q: q, + e: exec, + } +} + +// Hook to send logs to a mongo database +type Hook struct { + opts options + q *queue.Queue + e ExecCloser +} + +// Levels returns the available logging levels +func (h *Hook) Levels() []logrus.Level { + return h.opts.levels +} + +// Fire is called when a log event is fired +func (h *Hook) Fire(entry *logrus.Entry) error { + entry = h.copyEntry(entry) + h.q.Push(queue.NewJob(entry, func(v interface{}) { + h.exec(v.(*logrus.Entry)) + })) + return nil +} + +func (h *Hook) copyEntry(e *logrus.Entry) *logrus.Entry { + entry := logrus.NewEntry(e.Logger) + entry.Data = make(logrus.Fields) + entry.Time = e.Time + entry.Level = e.Level + entry.Message = e.Message + for k, v := range e.Data { + entry.Data[k] = v + } + return entry +} + +func (h *Hook) exec(entry *logrus.Entry) { + for k, v := range h.opts.extra { + if _, ok := entry.Data[k]; !ok { + entry.Data[k] = v + } + } + + if filter := h.opts.filter; filter != nil { + entry = filter(entry) + } + + err := h.e.Exec(entry) + if err != nil { + fmt.Fprintf(os.Stderr, "[logrus-hook] execution error: %s", err.Error()) + } +} + +// Flush waits for the log queue to be empty +func (h *Hook) Flush() { + h.q.Terminate() + h.e.Close() +} diff --git a/logger/hook/redis/redis.go b/logger/hook/redis/redis.go new file mode 100644 index 0000000..5142f02 --- /dev/null +++ b/logger/hook/redis/redis.go @@ -0,0 +1,64 @@ +// Copyright 2022-present The ZTDBP Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "fmt" + "github.com/ztdbp/ZASentinel/pkg/util/json" + + "github.com/go-redis/redis" + "github.com/sirupsen/logrus" +) + +type Config struct { + Addr string + Password string + Key string +} + +func New(c *Config) *Hook { + redisdb := redis.NewClient(&redis.Options{ + Addr: c.Addr, // use default Addr + Password: c.Password, // no password set + DB: 0, // use default DB + }) + + _, err := redisdb.Ping().Result() + if err != nil { + fmt.Println("error creating message for REDIS:", err) + panic(err) + } + return &Hook{ + cli: redisdb, + key: c.Key, + } +} + +type Hook struct { + cli *redis.Client + key string +} + +func (h *Hook) Exec(entry *logrus.Entry) error { + fields := make(map[string]interface{}) + for k, v := range entry.Data { + fields[k] = v + } + fields["level"] = entry.Level.String() + fields["message"] = entry.Message + b, _ := json.Marshal(fields) + return h.cli.RPush(h.key, string(b)).Err() +} + +func (h *Hook) Close() error { + return h.cli.Close() +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..bed90c1 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,155 @@ +// Copyright 2022-present The ZTDBP Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logger + +import ( + "context" + "io" + + "github.com/sirupsen/logrus" +) + +// Define key +const ( + TraceIDKey = "trace_id" + UserIDKey = "user_id" + TagKey = "tag" + VersionKey = "version" + StackKey = "stack" +) + +var version string + +type Logger = logrus.Logger + +type Entry = logrus.Entry + +type Hook = logrus.Hook + +func StandardLogger() *Logger { + return logrus.StandardLogger() +} + +func SetLevel(level int) { + logrus.SetLevel(logrus.Level(level)) +} + +func SetFormatter(format string) { + switch format { + case "json": + logrus.SetFormatter(new(logrus.JSONFormatter)) + default: + logrus.SetFormatter(new(logrus.TextFormatter)) + } +} + +func SetOutput(out io.Writer) { + logrus.SetOutput(out) +} + +func SetVersion(v string) { + version = v +} + +func AddHook(hook Hook) { + logrus.AddHook(hook) +} + +type ( + traceIDKey struct{} + userIDKey struct{} + tagKey struct{} + stackKey struct{} +) + +func NewTraceIDContext(ctx context.Context, traceID string) context.Context { + return context.WithValue(ctx, traceIDKey{}, traceID) +} + +func FromTraceIDContext(ctx context.Context) string { + v := ctx.Value(traceIDKey{}) + if v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func NewUserIDContext(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, userIDKey{}, userID) +} + +func FromUserIDContext(ctx context.Context) string { + v := ctx.Value(userIDKey{}) + if v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func NewTagContext(ctx context.Context, tag string) context.Context { + return context.WithValue(ctx, tagKey{}, tag) +} + +func FromTagContext(ctx context.Context) string { + v := ctx.Value(tagKey{}) + if v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func NewStackContext(ctx context.Context, stack error) context.Context { + return context.WithValue(ctx, stackKey{}, stack) +} + +func FromStackContext(ctx context.Context) error { + v := ctx.Value(stackKey{}) + if v != nil { + if s, ok := v.(error); ok { + return s + } + } + return nil +} + +func WithErrorStack(ctx context.Context, err error) *Entry { + if ctx == nil { + ctx = context.Background() + } + return WithContext(NewStackContext(ctx, err)) +} + +func WithContext(ctx context.Context) *Entry { + if ctx == nil { + ctx = context.Background() + } + + return logrus.WithContext(ctx) +} + +// Define logrus alias +var ( + Tracef = logrus.Tracef + Debugf = logrus.Debugf + Infof = logrus.Infof + Warnf = logrus.Warnf + Errorf = logrus.Errorf + Fatalf = logrus.Fatalf + Panicf = logrus.Panicf + Printf = logrus.Printf +) diff --git a/memorycacher/README.md b/memorycacher/README.md new file mode 100644 index 0000000..7411ca1 --- /dev/null +++ b/memorycacher/README.md @@ -0,0 +1,83 @@ + + +# Base on go-cache + +go-cache is an in-memory key:value store/cache similar to memcached that is +suitable for applications running on a single machine. Its major advantage is +that, being essentially a thread-safe `map[string]interface{}` with expiration +times, it doesn't need to serialize or transmit its contents over the network. + +Any object can be stored, for a given duration or forever, and the cache can be +safely used by multiple goroutines. + +Although go-cache isn't meant to be used as a persistent datastore, the entire +cache can be saved to and loaded from a file (using `c.Items()` to retrieve the +items map to serialize, and `NewFrom()` to create a cache from a deserialized +one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.) + +### Usage + +```go +import ( + "fmt" + "memorycache" + "time" +) + +func main() { + // Create a cache with a default expiration time of 5 minutes, and which + // purges expired items every 10 minutes + c := memorycache.New(5*time.Minute, 10*time.Minute) + + // Set the value of the key "foo" to "bar", with the default expiration time + c.Set("foo", "bar", cache.DefaultExpiration) + + // Set the value of the key "baz" to 42, with no expiration time + // (the item won't be removed until it is re-set, or removed using + // c.Delete("baz") + c.Set("baz", 42, cache.NoExpiration) + + // Get the string associated with the key "foo" from the cache + foo, found := c.Get("foo") + if found { + fmt.Println(foo) + } + + // Since Go is statically typed, and cache values can be anything, type + // assertion is needed when values are being passed to functions that don't + // take arbitrary types, (i.e. interface{}). The simplest way to do this for + // values which will only be used once--e.g. for passing to another + // function--is: + foo, found := c.Get("foo") + if found { + MyFunction(foo.(string)) + } + + // This gets tedious if the value is used several times in the same function. + // You might do either of the following instead: + if x, found := c.Get("foo"); found { + foo := x.(string) + // ... + } + // or + var foo string + if x, found := c.Get("foo"); found { + foo = x.(string) + } + // ... + // foo can then be passed around freely as a string + + // Want performance? Store pointers! + c.Set("foo", &MyStruct, cache.DefaultExpiration) + if x, found := c.Get("foo"); found { + foo := x.(*MyStruct) + // ... + } +} +``` \ No newline at end of file diff --git a/memorycacher/cache.go b/memorycacher/cache.go new file mode 100644 index 0000000..56f3939 --- /dev/null +++ b/memorycacher/cache.go @@ -0,0 +1,1234 @@ +/* + * @Author: patrickmn,gitsrc + * @Date: 2020-07-09 13:17:30 + * @LastEditors: gitsrc + * @LastEditTime: 2020-07-09 13:22:16 + * @FilePath: /ServiceCar/utils/memorycache/cache.go + */ + +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package memorycacher + +import ( + "encoding/gob" + "errors" + "fmt" + "io" + "os" + "runtime" + "sync" + "time" +) + +type Item struct { + Object interface{} + Expiration int64 +} + +// const ( +// cleanStatusPending = 0 +// cleanStatusRunning = 1 +// ) + +var ( + maxItemsCountErr = errors.New("reach max items count.") +) + +// Returns true if the item has expired. +func (item Item) Expired() bool { + if item.Expiration == 0 { + return false + } + return time.Now().UnixNano() > item.Expiration +} + +const ( + // For use with functions that take an expiration time. + NoExpiration time.Duration = -1 + // For use with functions that take an expiration time. Equivalent to + // passing in the same expiration duration as was given to New() or + // NewFrom() when the cache was created (e.g. 5 minutes.) + DefaultExpiration time.Duration = 0 +) + +type Cache struct { + *cache + // If this is confusing, see the comment at the bottom of New() +} + +type cache struct { + defaultExpiration time.Duration + maxItemsCount int //Maximum number of items + items map[string]Item + mu sync.RWMutex + onEvicted func(string, interface{}) + lastCleanTime time.Time //Last cleaning time + janitor *janitor +} + +//Determine whether the amount of schema inside the map schema structure has reached the maximum number limit +func (c *cache) IsReachMaxItemsCount() bool { + return c.ItemCount() >= c.maxItemsCount +} + +// Add an item to the cache, replacing any existing item. If the duration is 0 +// (DefaultExpiration), the cache's default expiration time is used. If it is -1 +// (NoExpiration), the item never expires. +func (c *cache) Set(k string, x interface{}, d time.Duration) { + if c.IsReachMaxItemsCount() { + c.ShoudClean() + return + } + + // "Inlining" of set + var e int64 + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + e = time.Now().Add(d).UnixNano() + } + c.mu.Lock() + c.items[k] = Item{ + Object: x, + Expiration: e, + } + // TODO: Calls to mu.Unlock are currently not deferred because defer + // adds ~200 ns (as of go1.) + c.mu.Unlock() +} + +func (c *cache) set(k string, x interface{}, d time.Duration) { + var e int64 + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + e = time.Now().Add(d).UnixNano() + } + c.items[k] = Item{ + Object: x, + Expiration: e, + } +} + +// Add an item to the cache, replacing any existing item, using the default +// expiration. +func (c *cache) SetDefault(k string, x interface{}) { + c.Set(k, x, DefaultExpiration) +} + +// Add an item to the cache only if an item doesn't already exist for the given +// key, or if the existing item has expired. Returns an error otherwise. +func (c *cache) Add(k string, x interface{}, d time.Duration) error { + if c.IsReachMaxItemsCount() { + c.ShoudClean() + return maxItemsCountErr + } + + c.mu.Lock() + _, found := c.get(k) + if found { + c.mu.Unlock() + return fmt.Errorf("Item %s already exists", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Set a new value for the cache key only if it already exists, and the existing +// item hasn't expired. Returns an error otherwise. +func (c *cache) Replace(k string, x interface{}, d time.Duration) error { + c.mu.Lock() + _, found := c.get(k) + if !found { + c.mu.Unlock() + return fmt.Errorf("Item %s doesn't exist", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Get an item from the cache. Returns the item or nil, and a bool indicating +// whether the key was found. +func (c *cache) Get(k string) (interface{}, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, false + } + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, false + } + } + c.mu.RUnlock() + return item.Object, true +} + +// GetWithExpiration returns an item and its expiration time from the cache. +// It returns the item or nil, the expiration time if one is set (if the item +// never expires a zero value for time.Time is returned), and a bool indicating +// whether the key was found. +func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + // Return the item and the expiration time + c.mu.RUnlock() + return item.Object, time.Unix(0, item.Expiration), true + } + + // If expiration <= 0 (i.e. no expiration time set) then return the item + // and a zeroed time.Time + c.mu.RUnlock() + return item.Object, time.Time{}, true +} + +func (c *cache) get(k string) (interface{}, bool) { + item, found := c.items[k] + if !found { + return nil, false + } + // "Inlining" of Expired + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + return nil, false + } + } + return item.Object, true +} + +//send map clean to cache janitor +func (c *cache) ShoudClean() { + if c.janitor.shoudClean == nil { + return + } + select { + case c.janitor.shoudClean <- true: + default: + } +} + +// Increment an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to increment it by n. To retrieve the incremented value, use one +// of the specialized methods, e.g. IncrementInt64. +func (c *cache) Increment(k string, n int64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) + int(n) + case int8: + v.Object = v.Object.(int8) + int8(n) + case int16: + v.Object = v.Object.(int16) + int16(n) + case int32: + v.Object = v.Object.(int32) + int32(n) + case int64: + v.Object = v.Object.(int64) + n + case uint: + v.Object = v.Object.(uint) + uint(n) + case uintptr: + v.Object = v.Object.(uintptr) + uintptr(n) + case uint8: + v.Object = v.Object.(uint8) + uint8(n) + case uint16: + v.Object = v.Object.(uint16) + uint16(n) + case uint32: + v.Object = v.Object.(uint32) + uint32(n) + case uint64: + v.Object = v.Object.(uint64) + uint64(n) + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to increment it by n. Pass a negative number to decrement the +// value. To retrieve the incremented value, use one of the specialized methods, +// e.g. IncrementFloat64. +func (c *cache) IncrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint8 by n. Returns an error if the item's value +// is not an uint8, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float64 by n. Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to decrement it by n. To retrieve the decremented value, use one +// of the specialized methods, e.g. DecrementInt64. +func (c *cache) Decrement(k string, n int64) error { + // TODO: Implement Increment and Decrement more cleanly. + // (Cannot do Increment(k, n*-1) for uints.) + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item not found") + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) - int(n) + case int8: + v.Object = v.Object.(int8) - int8(n) + case int16: + v.Object = v.Object.(int16) - int16(n) + case int32: + v.Object = v.Object.(int32) - int32(n) + case int64: + v.Object = v.Object.(int64) - n + case uint: + v.Object = v.Object.(uint) - uint(n) + case uintptr: + v.Object = v.Object.(uintptr) - uintptr(n) + case uint8: + v.Object = v.Object.(uint8) - uint8(n) + case uint16: + v.Object = v.Object.(uint16) - uint16(n) + case uint32: + v.Object = v.Object.(uint32) - uint32(n) + case uint64: + v.Object = v.Object.(uint64) - uint64(n) + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to decrement it by n. Pass a negative number to decrement the +// value. To retrieve the decremented value, use one of the specialized methods, +// e.g. DecrementFloat64. +func (c *cache) DecrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint8 by n. Returns an error if the item's value is +// not an uint8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float64 by n. Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Delete an item from the cache. Does nothing if the key is not in the cache. +func (c *cache) Delete(k string) { + c.mu.Lock() + v, evicted := c.delete(k) + c.mu.Unlock() + if evicted { + c.onEvicted(k, v) + } +} + +func (c *cache) delete(k string) (interface{}, bool) { + if c.onEvicted != nil { + if v, found := c.items[k]; found { + delete(c.items, k) + return v.Object, true + } + } + delete(c.items, k) + return nil, false +} + +type keyAndValue struct { + key string + value interface{} +} + +// Delete all expired items from the cache. +func (c *cache) DeleteExpired() { + var evictedItems []keyAndValue + nowTime := time.Now() + now := nowTime.UnixNano() + c.mu.Lock() + for k, v := range c.items { + // "Inlining" of expired + if v.Expiration > 0 && now > v.Expiration { + ov, evicted := c.delete(k) + if evicted { + evictedItems = append(evictedItems, keyAndValue{k, ov}) + } + } + } + c.lastCleanTime = nowTime + c.mu.Unlock() + for _, v := range evictedItems { + c.onEvicted(v.key, v.value) + } +} + +// Sets an (optional) function that is called with the key and value when an +// item is evicted from the cache. (Including when it is deleted manually, but +// not when it is overwritten.) Set to nil to disable. +func (c *cache) OnEvicted(f func(string, interface{})) { + c.mu.Lock() + c.onEvicted = f + c.mu.Unlock() +} + +// Write the cache's items (using Gob) to an io.Writer. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Save(w io.Writer) (err error) { + enc := gob.NewEncoder(w) + defer func() { + if x := recover(); x != nil { + err = fmt.Errorf("Error registering item types with Gob library") + } + }() + c.mu.RLock() + defer c.mu.RUnlock() + for _, v := range c.items { + gob.Register(v.Object) + } + err = enc.Encode(&c.items) + return +} + +// Save the cache's items to the given filename, creating the file if it +// doesn't exist, and overwriting it if it does. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) SaveFile(fname string) error { + fp, err := os.Create(fname) + if err != nil { + return err + } + err = c.Save(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Add (Gob-serialized) cache items from an io.Reader, excluding any items with +// keys that already exist (and haven't expired) in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Load(r io.Reader) error { + dec := gob.NewDecoder(r) + items := map[string]Item{} + err := dec.Decode(&items) + if err == nil { + c.mu.Lock() + defer c.mu.Unlock() + for k, v := range items { + ov, found := c.items[k] + if !found || ov.Expired() { + c.items[k] = v + } + } + } + return err +} + +// Load and add cache items from the given filename, excluding any items with +// keys that already exist in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) LoadFile(fname string) error { + fp, err := os.Open(fname) + if err != nil { + return err + } + err = c.Load(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Copies all unexpired items in the cache into a new map and returns it. +func (c *cache) Items() map[string]Item { + c.mu.RLock() + defer c.mu.RUnlock() + m := make(map[string]Item, len(c.items)) + now := time.Now().UnixNano() + for k, v := range c.items { + // "Inlining" of Expired + if v.Expiration > 0 { + if now > v.Expiration { + continue + } + } + m[k] = v + } + return m +} + +// Returns the number of items in the cache. This may include items that have +// expired, but have not yet been cleaned up. +func (c *cache) ItemCount() int { + c.mu.RLock() + n := len(c.items) + c.mu.RUnlock() + return n +} + +// Delete all items from the cache. +func (c *cache) Flush() { + c.mu.Lock() + c.items = map[string]Item{} + c.mu.Unlock() +} + +type janitor struct { + Interval time.Duration + stop chan bool + shoudClean chan bool +} + +func (j *janitor) Run(c *cache) { + ticker := time.NewTicker(j.Interval) + for { + select { + case <-ticker.C: + c.DeleteExpired() + case <-j.shoudClean: + c.mu.RLock() + lastCleanTime := c.lastCleanTime + c.mu.RUnlock() + + if lastCleanTime.Add(time.Second * 1).Before(time.Now()) { + c.DeleteExpired() + } + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(c *Cache) { + c.janitor.stop <- true +} + +func runJanitor(c *cache, ci time.Duration) { + j := &janitor{ + Interval: ci, + stop: make(chan bool), + shoudClean: make(chan bool), + } + c.janitor = j + go j.Run(c) +} + +func newCache(de time.Duration, maxItemsCount int, m map[string]Item) *cache { + if de == 0 { + de = -1 + } + c := &cache{ + defaultExpiration: de, + maxItemsCount: maxItemsCount, + items: m, + lastCleanTime: time.Now(), + } + return c +} + +func newCacheWithJanitor(de time.Duration, ci time.Duration, maxItemsCount int, m map[string]Item) *Cache { + c := newCache(de, maxItemsCount, m) + // This trick ensures that the janitor goroutine (which--granted it + // was enabled--is running DeleteExpired on c forever) does not keep + // the returned C object from being garbage collected. When it is + // garbage collected, the finalizer stops the janitor goroutine, after + // which c can be collected. + C := &Cache{c} + if ci > 0 { + runJanitor(c, ci) + runtime.SetFinalizer(C, stopJanitor) + } + return C +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +func New(defaultExpiration, cleanupInterval time.Duration, maxItemsCount int) *Cache { + items := make(map[string]Item) + return newCacheWithJanitor(defaultExpiration, cleanupInterval, maxItemsCount, items) +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +// +// NewFrom() also accepts an items map which will serve as the underlying map +// for the cache. This is useful for starting from a deserialized cache +// (serialized using e.g. gob.Encode() on c.Items()), or passing in e.g. +// make(map[string]Item, 500) to improve startup performance when the cache +// is expected to reach a certain minimum size. +// +// Only the cache's methods synchronize access to this map, so it is not +// recommended to keep any references to the map around after creating a cache. +// If need be, the map can be accessed at a later point using c.Items() (subject +// to the same caveat.) +// +// Note regarding serialization: When using e.g. gob, make sure to +// gob.Register() the individual types stored in the cache before encoding a +// map retrieved with c.Items(), and to register those same types before +// decoding a blob containing an items map. +func NewFrom(defaultExpiration, cleanupInterval time.Duration, maxItemsCount int, items map[string]Item) *Cache { + return newCacheWithJanitor(defaultExpiration, cleanupInterval, maxItemsCount, items) +} diff --git a/memorycacher/cache_test.go b/memorycacher/cache_test.go new file mode 100644 index 0000000..209a237 --- /dev/null +++ b/memorycacher/cache_test.go @@ -0,0 +1,1796 @@ +/* + * @Author: patrickmn,gitsrc + * @Date: 2020-07-09 13:17:30 + * @LastEditors: gitsrc + * @LastEditTime: 2020-07-10 10:06:28 + * @FilePath: /ServiceCar/utils/memorycacher/cache_test.go + */ + +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package memorycacher + +import ( + "bytes" + "io/ioutil" + "runtime" + "strconv" + "sync" + "testing" + "time" +) + +type TestStruct struct { + Num int + Children []*TestStruct +} + +const ( + maxItemsCount = 10000 +) + +func TestCache(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + + a, found := tc.Get("a") + if found || a != nil { + t.Error("Getting A found value that shouldn't exist:", a) + } + + b, found := tc.Get("b") + if found || b != nil { + t.Error("Getting B found value that shouldn't exist:", b) + } + + c, found := tc.Get("c") + if found || c != nil { + t.Error("Getting C found value that shouldn't exist:", c) + } + + tc.Set("a", 1, DefaultExpiration) + tc.Set("b", "b", DefaultExpiration) + tc.Set("c", 3.5, DefaultExpiration) + + x, found := tc.Get("a") + if !found { + t.Error("a was not found while getting a2") + } + if x == nil { + t.Error("x for a is nil") + } else if a2 := x.(int); a2+2 != 3 { + t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) + } + + x, found = tc.Get("b") + if !found { + t.Error("b was not found while getting b2") + } + if x == nil { + t.Error("x for b is nil") + } else if b2 := x.(string); b2+"B" != "bB" { + t.Error("b2 (which should be b) plus B does not equal bB; value:", b2) + } + + x, found = tc.Get("c") + if !found { + t.Error("c was not found while getting c2") + } + if x == nil { + t.Error("x for c is nil") + } else if c2 := x.(float64); c2+1.2 != 4.7 { + t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2) + } +} + +func TestCacheTimes(t *testing.T) { + var found bool + + tc := New(50*time.Millisecond, 1*time.Millisecond, maxItemsCount) + tc.Set("a", 1, DefaultExpiration) + tc.Set("b", 2, NoExpiration) + tc.Set("c", 3, 20*time.Millisecond) + tc.Set("d", 4, 70*time.Millisecond) + + <-time.After(25 * time.Millisecond) + _, found = tc.Get("c") + if found { + t.Error("Found c when it should have been automatically deleted") + } + + <-time.After(30 * time.Millisecond) + _, found = tc.Get("a") + if found { + t.Error("Found a when it should have been automatically deleted") + } + + _, found = tc.Get("b") + if !found { + t.Error("Did not find b even though it was set to never expire") + } + + _, found = tc.Get("d") + if !found { + t.Error("Did not find d even though it was set to expire later than the default") + } + + <-time.After(20 * time.Millisecond) + _, found = tc.Get("d") + if found { + t.Error("Found d when it should have been automatically deleted (later than the default)") + } +} + +func TestNewFrom(t *testing.T) { + m := map[string]Item{ + "a": Item{ + Object: 1, + Expiration: 0, + }, + "b": Item{ + Object: 2, + Expiration: 0, + }, + } + tc := NewFrom(DefaultExpiration, 0, maxItemsCount, m) + a, found := tc.Get("a") + if !found { + t.Fatal("Did not find a") + } + if a.(int) != 1 { + t.Fatal("a is not 1") + } + b, found := tc.Get("b") + if !found { + t.Fatal("Did not find b") + } + if b.(int) != 2 { + t.Fatal("b is not 2") + } +} + +func TestStorePointerToStruct(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("foo", &TestStruct{Num: 1}, DefaultExpiration) + x, found := tc.Get("foo") + if !found { + t.Fatal("*TestStruct was not found for foo") + } + foo := x.(*TestStruct) + foo.Num++ + + y, found := tc.Get("foo") + if !found { + t.Fatal("*TestStruct was not found for foo (second time)") + } + bar := y.(*TestStruct) + if bar.Num != 2 { + t.Fatal("TestStruct.Num is not 2") + } +} + +func TestIncrementWithInt(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint", 1, DefaultExpiration) + err := tc.Increment("tint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint") + if !found { + t.Error("tint was not found") + } + if x.(int) != 3 { + t.Error("tint is not 3:", x) + } +} + +func TestIncrementWithInt8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint8", int8(1), DefaultExpiration) + err := tc.Increment("tint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint8") + if !found { + t.Error("tint8 was not found") + } + if x.(int8) != 3 { + t.Error("tint8 is not 3:", x) + } +} + +func TestIncrementWithInt16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint16", int16(1), DefaultExpiration) + err := tc.Increment("tint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint16") + if !found { + t.Error("tint16 was not found") + } + if x.(int16) != 3 { + t.Error("tint16 is not 3:", x) + } +} + +func TestIncrementWithInt32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint32", int32(1), DefaultExpiration) + err := tc.Increment("tint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint32") + if !found { + t.Error("tint32 was not found") + } + if x.(int32) != 3 { + t.Error("tint32 is not 3:", x) + } +} + +func TestIncrementWithInt64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint64", int64(1), DefaultExpiration) + err := tc.Increment("tint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint64") + if !found { + t.Error("tint64 was not found") + } + if x.(int64) != 3 { + t.Error("tint64 is not 3:", x) + } +} + +func TestIncrementWithUint(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint", uint(1), DefaultExpiration) + err := tc.Increment("tuint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tuint") + if !found { + t.Error("tuint was not found") + } + if x.(uint) != 3 { + t.Error("tuint is not 3:", x) + } +} + +func TestIncrementWithUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuintptr", uintptr(1), DefaultExpiration) + err := tc.Increment("tuintptr", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + + x, found := tc.Get("tuintptr") + if !found { + t.Error("tuintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("tuintptr is not 3:", x) + } +} + +func TestIncrementWithUint8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint8", uint8(1), DefaultExpiration) + err := tc.Increment("tuint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tuint8") + if !found { + t.Error("tuint8 was not found") + } + if x.(uint8) != 3 { + t.Error("tuint8 is not 3:", x) + } +} + +func TestIncrementWithUint16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint16", uint16(1), DefaultExpiration) + err := tc.Increment("tuint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + + x, found := tc.Get("tuint16") + if !found { + t.Error("tuint16 was not found") + } + if x.(uint16) != 3 { + t.Error("tuint16 is not 3:", x) + } +} + +func TestIncrementWithUint32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint32", uint32(1), DefaultExpiration) + err := tc.Increment("tuint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tuint32") + if !found { + t.Error("tuint32 was not found") + } + if x.(uint32) != 3 { + t.Error("tuint32 is not 3:", x) + } +} + +func TestIncrementWithUint64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint64", uint64(1), DefaultExpiration) + err := tc.Increment("tuint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + + x, found := tc.Get("tuint64") + if !found { + t.Error("tuint64 was not found") + } + if x.(uint64) != 3 { + t.Error("tuint64 is not 3:", x) + } +} + +func TestIncrementWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float32", float32(1.5), DefaultExpiration) + err := tc.Increment("float32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3.5:", x) + } +} + +func TestIncrementWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float64", float64(1.5), DefaultExpiration) + err := tc.Increment("float64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3.5:", x) + } +} + +func TestIncrementFloatWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float32", float32(1.5), DefaultExpiration) + err := tc.IncrementFloat("float32", 2) + if err != nil { + t.Error("Error incrementfloating:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3.5:", x) + } +} + +func TestIncrementFloatWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float64", float64(1.5), DefaultExpiration) + err := tc.IncrementFloat("float64", 2) + if err != nil { + t.Error("Error incrementfloating:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3.5:", x) + } +} + +func TestDecrementWithInt(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int", int(5), DefaultExpiration) + err := tc.Decrement("int", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int") + if !found { + t.Error("int was not found") + } + if x.(int) != 3 { + t.Error("int is not 3:", x) + } +} + +func TestDecrementWithInt8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int8", int8(5), DefaultExpiration) + err := tc.Decrement("int8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int8") + if !found { + t.Error("int8 was not found") + } + if x.(int8) != 3 { + t.Error("int8 is not 3:", x) + } +} + +func TestDecrementWithInt16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int16", int16(5), DefaultExpiration) + err := tc.Decrement("int16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int16") + if !found { + t.Error("int16 was not found") + } + if x.(int16) != 3 { + t.Error("int16 is not 3:", x) + } +} + +func TestDecrementWithInt32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int32", int32(5), DefaultExpiration) + err := tc.Decrement("int32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int32") + if !found { + t.Error("int32 was not found") + } + if x.(int32) != 3 { + t.Error("int32 is not 3:", x) + } +} + +func TestDecrementWithInt64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int64", int64(5), DefaultExpiration) + err := tc.Decrement("int64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int64") + if !found { + t.Error("int64 was not found") + } + if x.(int64) != 3 { + t.Error("int64 is not 3:", x) + } +} + +func TestDecrementWithUint(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint", uint(5), DefaultExpiration) + err := tc.Decrement("uint", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint") + if !found { + t.Error("uint was not found") + } + if x.(uint) != 3 { + t.Error("uint is not 3:", x) + } +} + +func TestDecrementWithUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uintptr", uintptr(5), DefaultExpiration) + err := tc.Decrement("uintptr", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uintptr") + if !found { + t.Error("uintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("uintptr is not 3:", x) + } +} + +func TestDecrementWithUint8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint8", uint8(5), DefaultExpiration) + err := tc.Decrement("uint8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint8") + if !found { + t.Error("uint8 was not found") + } + if x.(uint8) != 3 { + t.Error("uint8 is not 3:", x) + } +} + +func TestDecrementWithUint16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint16", uint16(5), DefaultExpiration) + err := tc.Decrement("uint16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint16") + if !found { + t.Error("uint16 was not found") + } + if x.(uint16) != 3 { + t.Error("uint16 is not 3:", x) + } +} + +func TestDecrementWithUint32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint32", uint32(5), DefaultExpiration) + err := tc.Decrement("uint32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint32") + if !found { + t.Error("uint32 was not found") + } + if x.(uint32) != 3 { + t.Error("uint32 is not 3:", x) + } +} + +func TestDecrementWithUint64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint64", uint64(5), DefaultExpiration) + err := tc.Decrement("uint64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint64") + if !found { + t.Error("uint64 was not found") + } + if x.(uint64) != 3 { + t.Error("uint64 is not 3:", x) + } +} + +func TestDecrementWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float32", float32(5.5), DefaultExpiration) + err := tc.Decrement("float32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3:", x) + } +} + +func TestDecrementWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float64", float64(5.5), DefaultExpiration) + err := tc.Decrement("float64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3:", x) + } +} + +func TestDecrementFloatWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float32", float32(5.5), DefaultExpiration) + err := tc.DecrementFloat("float32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3:", x) + } +} + +func TestDecrementFloatWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float64", float64(5.5), DefaultExpiration) + err := tc.DecrementFloat("float64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3:", x) + } +} + +func TestIncrementInt(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint", 1, DefaultExpiration) + n, err := tc.IncrementInt("tint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint") + if !found { + t.Error("tint was not found") + } + if x.(int) != 3 { + t.Error("tint is not 3:", x) + } +} + +func TestIncrementInt8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint8", int8(1), DefaultExpiration) + n, err := tc.IncrementInt8("tint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint8") + if !found { + t.Error("tint8 was not found") + } + if x.(int8) != 3 { + t.Error("tint8 is not 3:", x) + } +} + +func TestIncrementInt16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint16", int16(1), DefaultExpiration) + n, err := tc.IncrementInt16("tint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint16") + if !found { + t.Error("tint16 was not found") + } + if x.(int16) != 3 { + t.Error("tint16 is not 3:", x) + } +} + +func TestIncrementInt32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint32", int32(1), DefaultExpiration) + n, err := tc.IncrementInt32("tint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint32") + if !found { + t.Error("tint32 was not found") + } + if x.(int32) != 3 { + t.Error("tint32 is not 3:", x) + } +} + +func TestIncrementInt64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tint64", int64(1), DefaultExpiration) + n, err := tc.IncrementInt64("tint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint64") + if !found { + t.Error("tint64 was not found") + } + if x.(int64) != 3 { + t.Error("tint64 is not 3:", x) + } +} + +func TestIncrementUint(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint", uint(1), DefaultExpiration) + n, err := tc.IncrementUint("tuint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint") + if !found { + t.Error("tuint was not found") + } + if x.(uint) != 3 { + t.Error("tuint is not 3:", x) + } +} + +func TestIncrementUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuintptr", uintptr(1), DefaultExpiration) + n, err := tc.IncrementUintptr("tuintptr", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuintptr") + if !found { + t.Error("tuintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("tuintptr is not 3:", x) + } +} + +func TestIncrementUint8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint8", uint8(1), DefaultExpiration) + n, err := tc.IncrementUint8("tuint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint8") + if !found { + t.Error("tuint8 was not found") + } + if x.(uint8) != 3 { + t.Error("tuint8 is not 3:", x) + } +} + +func TestIncrementUint16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint16", uint16(1), DefaultExpiration) + n, err := tc.IncrementUint16("tuint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint16") + if !found { + t.Error("tuint16 was not found") + } + if x.(uint16) != 3 { + t.Error("tuint16 is not 3:", x) + } +} + +func TestIncrementUint32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint32", uint32(1), DefaultExpiration) + n, err := tc.IncrementUint32("tuint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint32") + if !found { + t.Error("tuint32 was not found") + } + if x.(uint32) != 3 { + t.Error("tuint32 is not 3:", x) + } +} + +func TestIncrementUint64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("tuint64", uint64(1), DefaultExpiration) + n, err := tc.IncrementUint64("tuint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint64") + if !found { + t.Error("tuint64 was not found") + } + if x.(uint64) != 3 { + t.Error("tuint64 is not 3:", x) + } +} + +func TestIncrementFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float32", float32(1.5), DefaultExpiration) + n, err := tc.IncrementFloat32("float32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3.5 { + t.Error("Returned number is not 3.5:", n) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3.5:", x) + } +} + +func TestIncrementFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float64", float64(1.5), DefaultExpiration) + n, err := tc.IncrementFloat64("float64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3.5 { + t.Error("Returned number is not 3.5:", n) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3.5:", x) + } +} + +func TestDecrementInt8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int8", int8(5), DefaultExpiration) + n, err := tc.DecrementInt8("int8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int8") + if !found { + t.Error("int8 was not found") + } + if x.(int8) != 3 { + t.Error("int8 is not 3:", x) + } +} + +func TestDecrementInt16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int16", int16(5), DefaultExpiration) + n, err := tc.DecrementInt16("int16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int16") + if !found { + t.Error("int16 was not found") + } + if x.(int16) != 3 { + t.Error("int16 is not 3:", x) + } +} + +func TestDecrementInt32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int32", int32(5), DefaultExpiration) + n, err := tc.DecrementInt32("int32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int32") + if !found { + t.Error("int32 was not found") + } + if x.(int32) != 3 { + t.Error("int32 is not 3:", x) + } +} + +func TestDecrementInt64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int64", int64(5), DefaultExpiration) + n, err := tc.DecrementInt64("int64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int64") + if !found { + t.Error("int64 was not found") + } + if x.(int64) != 3 { + t.Error("int64 is not 3:", x) + } +} + +func TestDecrementUint(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint", uint(5), DefaultExpiration) + n, err := tc.DecrementUint("uint", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint") + if !found { + t.Error("uint was not found") + } + if x.(uint) != 3 { + t.Error("uint is not 3:", x) + } +} + +func TestDecrementUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uintptr", uintptr(5), DefaultExpiration) + n, err := tc.DecrementUintptr("uintptr", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uintptr") + if !found { + t.Error("uintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("uintptr is not 3:", x) + } +} + +func TestDecrementUint8(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint8", uint8(5), DefaultExpiration) + n, err := tc.DecrementUint8("uint8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint8") + if !found { + t.Error("uint8 was not found") + } + if x.(uint8) != 3 { + t.Error("uint8 is not 3:", x) + } +} + +func TestDecrementUint16(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint16", uint16(5), DefaultExpiration) + n, err := tc.DecrementUint16("uint16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint16") + if !found { + t.Error("uint16 was not found") + } + if x.(uint16) != 3 { + t.Error("uint16 is not 3:", x) + } +} + +func TestDecrementUint32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint32", uint32(5), DefaultExpiration) + n, err := tc.DecrementUint32("uint32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint32") + if !found { + t.Error("uint32 was not found") + } + if x.(uint32) != 3 { + t.Error("uint32 is not 3:", x) + } +} + +func TestDecrementUint64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint64", uint64(5), DefaultExpiration) + n, err := tc.DecrementUint64("uint64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint64") + if !found { + t.Error("uint64 was not found") + } + if x.(uint64) != 3 { + t.Error("uint64 is not 3:", x) + } +} + +func TestDecrementFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float32", float32(5), DefaultExpiration) + n, err := tc.DecrementFloat32("float32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3 { + t.Error("float32 is not 3:", x) + } +} + +func TestDecrementFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("float64", float64(5), DefaultExpiration) + n, err := tc.DecrementFloat64("float64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3 { + t.Error("float64 is not 3:", x) + } +} + +func TestAdd(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + err := tc.Add("foo", "bar", DefaultExpiration) + if err != nil { + t.Error("Couldn't add foo even though it shouldn't exist") + } + err = tc.Add("foo", "baz", DefaultExpiration) + if err == nil { + t.Error("Successfully added another foo when it should have returned an error") + } +} + +func TestReplace(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + err := tc.Replace("foo", "bar", DefaultExpiration) + if err == nil { + t.Error("Replaced foo when it shouldn't exist") + } + tc.Set("foo", "bar", DefaultExpiration) + err = tc.Replace("foo", "bar", DefaultExpiration) + if err != nil { + t.Error("Couldn't replace existing key foo") + } +} + +func TestDelete(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("foo", "bar", DefaultExpiration) + tc.Delete("foo") + x, found := tc.Get("foo") + if found { + t.Error("foo was found, but it should have been deleted") + } + if x != nil { + t.Error("x is not nil:", x) + } +} + +func TestItemCount(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("foo", "1", DefaultExpiration) + tc.Set("bar", "2", DefaultExpiration) + tc.Set("baz", "3", DefaultExpiration) + if n := tc.ItemCount(); n != 3 { + t.Errorf("Item count is not 3: %d", n) + } +} + +func TestFlush(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("foo", "bar", DefaultExpiration) + tc.Set("baz", "yes", DefaultExpiration) + tc.Flush() + x, found := tc.Get("foo") + if found { + t.Error("foo was found, but it should have been deleted") + } + if x != nil { + t.Error("x is not nil:", x) + } + x, found = tc.Get("baz") + if found { + t.Error("baz was found, but it should have been deleted") + } + if x != nil { + t.Error("x is not nil:", x) + } +} + +func TestIncrementOverflowInt(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("int8", int8(127), DefaultExpiration) + err := tc.Increment("int8", 1) + if err != nil { + t.Error("Error incrementing int8:", err) + } + x, _ := tc.Get("int8") + int8 := x.(int8) + if int8 != -128 { + t.Error("int8 did not overflow as expected; value:", int8) + } + +} + +func TestIncrementOverflowUint(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint8", uint8(255), DefaultExpiration) + err := tc.Increment("uint8", 1) + if err != nil { + t.Error("Error incrementing int8:", err) + } + x, _ := tc.Get("uint8") + uint8 := x.(uint8) + if uint8 != 0 { + t.Error("uint8 did not overflow as expected; value:", uint8) + } +} + +func TestDecrementUnderflowUint(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("uint8", uint8(0), DefaultExpiration) + err := tc.Decrement("uint8", 1) + if err != nil { + t.Error("Error decrementing int8:", err) + } + x, _ := tc.Get("uint8") + uint8 := x.(uint8) + if uint8 != 255 { + t.Error("uint8 did not underflow as expected; value:", uint8) + } +} + +func TestOnEvicted(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("foo", 3, DefaultExpiration) + if tc.onEvicted != nil { + t.Fatal("tc.onEvicted is not nil") + } + works := false + tc.OnEvicted(func(k string, v interface{}) { + if k == "foo" && v.(int) == 3 { + works = true + } + tc.Set("bar", 4, DefaultExpiration) + }) + tc.Delete("foo") + x, _ := tc.Get("bar") + if !works { + t.Error("works bool not true") + } + if x.(int) != 4 { + t.Error("bar was not 4") + } +} + +func TestCacheSerialization(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + testFillAndSerialize(t, tc) + + // Check if gob.Register behaves properly even after multiple gob.Register + // on c.Items (many of which will be the same type) + testFillAndSerialize(t, tc) +} + +func testFillAndSerialize(t *testing.T, tc *Cache) { + tc.Set("a", "a", DefaultExpiration) + tc.Set("b", "b", DefaultExpiration) + tc.Set("c", "c", DefaultExpiration) + tc.Set("expired", "foo", 1*time.Millisecond) + tc.Set("*struct", &TestStruct{Num: 1}, DefaultExpiration) + tc.Set("[]struct", []TestStruct{ + {Num: 2}, + {Num: 3}, + }, DefaultExpiration) + tc.Set("[]*struct", []*TestStruct{ + &TestStruct{Num: 4}, + &TestStruct{Num: 5}, + }, DefaultExpiration) + tc.Set("structception", &TestStruct{ + Num: 42, + Children: []*TestStruct{ + &TestStruct{Num: 6174}, + &TestStruct{Num: 4716}, + }, + }, DefaultExpiration) + + fp := &bytes.Buffer{} + err := tc.Save(fp) + if err != nil { + t.Fatal("Couldn't save cache to fp:", err) + } + + oc := New(DefaultExpiration, 0, maxItemsCount) + err = oc.Load(fp) + if err != nil { + t.Fatal("Couldn't load cache from fp:", err) + } + + a, found := oc.Get("a") + if !found { + t.Error("a was not found") + } + if a.(string) != "a" { + t.Error("a is not a") + } + + b, found := oc.Get("b") + if !found { + t.Error("b was not found") + } + if b.(string) != "b" { + t.Error("b is not b") + } + + c, found := oc.Get("c") + if !found { + t.Error("c was not found") + } + if c.(string) != "c" { + t.Error("c is not c") + } + + <-time.After(5 * time.Millisecond) + _, found = oc.Get("expired") + if found { + t.Error("expired was found") + } + + s1, found := oc.Get("*struct") + if !found { + t.Error("*struct was not found") + } + if s1.(*TestStruct).Num != 1 { + t.Error("*struct.Num is not 1") + } + + s2, found := oc.Get("[]struct") + if !found { + t.Error("[]struct was not found") + } + s2r := s2.([]TestStruct) + if len(s2r) != 2 { + t.Error("Length of s2r is not 2") + } + if s2r[0].Num != 2 { + t.Error("s2r[0].Num is not 2") + } + if s2r[1].Num != 3 { + t.Error("s2r[1].Num is not 3") + } + + s3, found := oc.get("[]*struct") + if !found { + t.Error("[]*struct was not found") + } + s3r := s3.([]*TestStruct) + if len(s3r) != 2 { + t.Error("Length of s3r is not 2") + } + if s3r[0].Num != 4 { + t.Error("s3r[0].Num is not 4") + } + if s3r[1].Num != 5 { + t.Error("s3r[1].Num is not 5") + } + + s4, found := oc.get("structception") + if !found { + t.Error("structception was not found") + } + s4r := s4.(*TestStruct) + if len(s4r.Children) != 2 { + t.Error("Length of s4r.Children is not 2") + } + if s4r.Children[0].Num != 6174 { + t.Error("s4r.Children[0].Num is not 6174") + } + if s4r.Children[1].Num != 4716 { + t.Error("s4r.Children[1].Num is not 4716") + } +} + +func TestFileSerialization(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Add("a", "a", DefaultExpiration) + tc.Add("b", "b", DefaultExpiration) + f, err := ioutil.TempFile("", "go-cache-cache.dat") + if err != nil { + t.Fatal("Couldn't create cache file:", err) + } + fname := f.Name() + f.Close() + tc.SaveFile(fname) + + oc := New(DefaultExpiration, 0, maxItemsCount) + oc.Add("a", "aa", 0) // this should not be overwritten + err = oc.LoadFile(fname) + if err != nil { + t.Error(err) + } + a, found := oc.Get("a") + if !found { + t.Error("a was not found") + } + astr := a.(string) + if astr != "aa" { + if astr == "a" { + t.Error("a was overwritten") + } else { + t.Error("a is not aa") + } + } + b, found := oc.Get("b") + if !found { + t.Error("b was not found") + } + if b.(string) != "b" { + t.Error("b is not b") + } +} + +func TestSerializeUnserializable(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + ch := make(chan bool, 1) + ch <- true + tc.Set("chan", ch, DefaultExpiration) + fp := &bytes.Buffer{} + err := tc.Save(fp) // this should fail gracefully + if err.Error() != "gob NewTypeObject can't handle type: chan bool" { + t.Error("Error from Save was not gob NewTypeObject can't handle type chan bool:", err) + } +} + +func BenchmarkCacheGetExpiring(b *testing.B) { + benchmarkCacheGet(b, 5*time.Minute) +} + +func BenchmarkCacheGetNotExpiring(b *testing.B) { + benchmarkCacheGet(b, NoExpiration) +} + +func benchmarkCacheGet(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := New(exp, 0, maxItemsCount) + tc.Set("foo", "bar", DefaultExpiration) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Get("foo") + } +} + +func BenchmarkRWMutexMapGet(b *testing.B) { + b.StopTimer() + m := map[string]string{ + "foo": "bar", + } + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.RLock() + _, _ = m["foo"] + mu.RUnlock() + } +} + +func BenchmarkRWMutexInterfaceMapGetStruct(b *testing.B) { + b.StopTimer() + s := struct{ name string }{name: "foo"} + m := map[interface{}]string{ + s: "bar", + } + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.RLock() + _, _ = m[s] + mu.RUnlock() + } +} + +func BenchmarkRWMutexInterfaceMapGetString(b *testing.B) { + b.StopTimer() + m := map[interface{}]string{ + "foo": "bar", + } + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.RLock() + _, _ = m["foo"] + mu.RUnlock() + } +} + +func BenchmarkCacheGetConcurrentExpiring(b *testing.B) { + benchmarkCacheGetConcurrent(b, 5*time.Minute) +} + +func BenchmarkCacheGetConcurrentNotExpiring(b *testing.B) { + benchmarkCacheGetConcurrent(b, NoExpiration) +} + +func benchmarkCacheGetConcurrent(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := New(exp, 0, maxItemsCount) + tc.Set("foo", "bar", DefaultExpiration) + wg := new(sync.WaitGroup) + workers := runtime.NumCPU() + each := b.N / workers + wg.Add(workers) + b.StartTimer() + for i := 0; i < workers; i++ { + go func() { + for j := 0; j < each; j++ { + tc.Get("foo") + } + wg.Done() + }() + } + wg.Wait() +} + +func BenchmarkRWMutexMapGetConcurrent(b *testing.B) { + b.StopTimer() + m := map[string]string{ + "foo": "bar", + } + mu := sync.RWMutex{} + wg := new(sync.WaitGroup) + workers := runtime.NumCPU() + each := b.N / workers + wg.Add(workers) + b.StartTimer() + for i := 0; i < workers; i++ { + go func() { + for j := 0; j < each; j++ { + mu.RLock() + _, _ = m["foo"] + mu.RUnlock() + } + wg.Done() + }() + } + wg.Wait() +} + +func BenchmarkCacheGetManyConcurrentExpiring(b *testing.B) { + benchmarkCacheGetManyConcurrent(b, 5*time.Minute) +} + +func BenchmarkCacheGetManyConcurrentNotExpiring(b *testing.B) { + benchmarkCacheGetManyConcurrent(b, NoExpiration) +} + +func benchmarkCacheGetManyConcurrent(b *testing.B, exp time.Duration) { + // This is the same as BenchmarkCacheGetConcurrent, but its result + // can be compared against BenchmarkShardedCacheGetManyConcurrent + // in sharded_test.go. + b.StopTimer() + n := 10000 + tc := New(exp, 0, maxItemsCount) + keys := make([]string, n) + for i := 0; i < n; i++ { + k := "foo" + strconv.Itoa(i) + keys[i] = k + tc.Set(k, "bar", DefaultExpiration) + } + each := b.N / n + wg := new(sync.WaitGroup) + wg.Add(n) + for _, v := range keys { + go func(k string) { + for j := 0; j < each; j++ { + tc.Get(k) + } + wg.Done() + }(v) + } + b.StartTimer() + wg.Wait() +} + +func BenchmarkCacheSetExpiring(b *testing.B) { + benchmarkCacheSet(b, 5*time.Minute) +} + +func BenchmarkCacheSetNotExpiring(b *testing.B) { + benchmarkCacheSet(b, NoExpiration) +} + +func benchmarkCacheSet(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := New(exp, 0, maxItemsCount) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Set("foo", "bar", DefaultExpiration) + } +} + +func BenchmarkRWMutexMapSet(b *testing.B) { + b.StopTimer() + m := map[string]string{} + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m["foo"] = "bar" + mu.Unlock() + } +} + +func BenchmarkCacheSetDelete(b *testing.B) { + b.StopTimer() + tc := New(DefaultExpiration, 0, maxItemsCount) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Set("foo", "bar", DefaultExpiration) + tc.Delete("foo") + } +} + +func BenchmarkRWMutexMapSetDelete(b *testing.B) { + b.StopTimer() + m := map[string]string{} + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m["foo"] = "bar" + mu.Unlock() + mu.Lock() + delete(m, "foo") + mu.Unlock() + } +} + +func BenchmarkCacheSetDeleteSingleLock(b *testing.B) { + b.StopTimer() + tc := New(DefaultExpiration, 0, maxItemsCount) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.mu.Lock() + tc.set("foo", "bar", DefaultExpiration) + tc.delete("foo") + tc.mu.Unlock() + } +} + +func BenchmarkRWMutexMapSetDeleteSingleLock(b *testing.B) { + b.StopTimer() + m := map[string]string{} + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m["foo"] = "bar" + delete(m, "foo") + mu.Unlock() + } +} + +func BenchmarkIncrementInt(b *testing.B) { + b.StopTimer() + tc := New(DefaultExpiration, 0, maxItemsCount) + tc.Set("foo", 0, DefaultExpiration) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.IncrementInt("foo", 1) + } +} + +func BenchmarkDeleteExpiredLoop(b *testing.B) { + b.StopTimer() + tc := New(5*time.Minute, 0, maxItemsCount) + tc.mu.Lock() + for i := 0; i < 100000; i++ { + tc.set(strconv.Itoa(i), "bar", DefaultExpiration) + } + tc.mu.Unlock() + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.DeleteExpired() + } +} + +func TestGetWithExpiration(t *testing.T) { + tc := New(DefaultExpiration, 0, maxItemsCount) + + a, expiration, found := tc.GetWithExpiration("a") + if found || a != nil || !expiration.IsZero() { + t.Error("Getting A found value that shouldn't exist:", a) + } + + b, expiration, found := tc.GetWithExpiration("b") + if found || b != nil || !expiration.IsZero() { + t.Error("Getting B found value that shouldn't exist:", b) + } + + c, expiration, found := tc.GetWithExpiration("c") + if found || c != nil || !expiration.IsZero() { + t.Error("Getting C found value that shouldn't exist:", c) + } + + tc.Set("a", 1, DefaultExpiration) + tc.Set("b", "b", DefaultExpiration) + tc.Set("c", 3.5, DefaultExpiration) + tc.Set("d", 1, NoExpiration) + tc.Set("e", 1, 50*time.Millisecond) + + x, expiration, found := tc.GetWithExpiration("a") + if !found { + t.Error("a was not found while getting a2") + } + if x == nil { + t.Error("x for a is nil") + } else if a2 := x.(int); a2+2 != 3 { + t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) + } + if !expiration.IsZero() { + t.Error("expiration for a is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("b") + if !found { + t.Error("b was not found while getting b2") + } + if x == nil { + t.Error("x for b is nil") + } else if b2 := x.(string); b2+"B" != "bB" { + t.Error("b2 (which should be b) plus B does not equal bB; value:", b2) + } + if !expiration.IsZero() { + t.Error("expiration for b is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("c") + if !found { + t.Error("c was not found while getting c2") + } + if x == nil { + t.Error("x for c is nil") + } else if c2 := x.(float64); c2+1.2 != 4.7 { + t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2) + } + if !expiration.IsZero() { + t.Error("expiration for c is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("d") + if !found { + t.Error("d was not found while getting d2") + } + if x == nil { + t.Error("x for d is nil") + } else if d2 := x.(int); d2+2 != 3 { + t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2) + } + if !expiration.IsZero() { + t.Error("expiration for d is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("e") + if !found { + t.Error("e was not found while getting e2") + } + if x == nil { + t.Error("x for e is nil") + } else if e2 := x.(int); e2+2 != 3 { + t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2) + } + if expiration.UnixNano() != tc.items["e"].Expiration { + t.Error("expiration for e is not the correct time") + } + if expiration.UnixNano() < time.Now().UnixNano() { + t.Error("expiration for e is in the past") + } +} diff --git a/memorycacher/sharded.go b/memorycacher/sharded.go new file mode 100644 index 0000000..e507068 --- /dev/null +++ b/memorycacher/sharded.go @@ -0,0 +1,221 @@ +/* + * @Author: patrickmn,gitsrc + * @Date: 2020-07-09 13:17:30 + * @LastEditors: gitsrc + * @LastEditTime: 2020-07-09 13:22:41 + * @FilePath: /ServiceCar/utils/memorycache/sharded.go + */ + +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package memorycacher + +import ( + "crypto/rand" + "math" + "math/big" + insecurerand "math/rand" + "os" + "runtime" + "time" +) + +// This is an experimental and unexported (for now) attempt at making a cache +// with better algorithmic complexity than the standard one, namely by +// preventing write locks of the entire cache when an item is added. As of the +// time of writing, the overhead of selecting buckets results in cache +// operations being about twice as slow as for the standard cache with small +// total cache sizes, and faster for larger ones. +// +// See cache_test.go for a few benchmarks. + +type unexportedShardedCache struct { + *shardedCache +} + +type shardedCache struct { + seed uint32 + m uint32 + cs []*cache + lastCleanTime time.Time //Last cleanup time + janitor *shardedJanitor +} + +// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead. +func djb33(seed uint32, k string) uint32 { + var ( + l = uint32(len(k)) + d = 5381 + seed + l + i = uint32(0) + ) + // Why is all this 5x faster than a for loop? + if l >= 4 { + for i < l-4 { + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + d = (d * 33) ^ uint32(k[i+3]) + i += 4 + } + } + switch l - i { + case 1: + case 2: + d = (d * 33) ^ uint32(k[i]) + case 3: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + case 4: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + } + return d ^ (d >> 16) +} + +func (sc *shardedCache) bucket(k string) *cache { + return sc.cs[djb33(sc.seed, k)%sc.m] +} + +func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) { + sc.bucket(k).Set(k, x, d) +} + +func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Add(k, x, d) +} + +func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Replace(k, x, d) +} + +func (sc *shardedCache) Get(k string) (interface{}, bool) { + return sc.bucket(k).Get(k) +} + +func (sc *shardedCache) Increment(k string, n int64) error { + return sc.bucket(k).Increment(k, n) +} + +func (sc *shardedCache) IncrementFloat(k string, n float64) error { + return sc.bucket(k).IncrementFloat(k, n) +} + +func (sc *shardedCache) Decrement(k string, n int64) error { + return sc.bucket(k).Decrement(k, n) +} + +func (sc *shardedCache) Delete(k string) { + sc.bucket(k).Delete(k) +} + +func (sc *shardedCache) DeleteExpired() { + for _, v := range sc.cs { + v.DeleteExpired() + } +} + +// Returns the items in the cache. This may include items that have expired, +// but have not yet been cleaned up. If this is significant, the Expiration +// fields of the items should be checked. Note that explicit synchronization +// is needed to use a cache and its corresponding Items() return values at +// the same time, as the maps are shared. +func (sc *shardedCache) Items() []map[string]Item { + res := make([]map[string]Item, len(sc.cs)) + for i, v := range sc.cs { + res[i] = v.Items() + } + return res +} + +func (sc *shardedCache) Flush() { + for _, v := range sc.cs { + v.Flush() + } +} + +type shardedJanitor struct { + Interval time.Duration + shoudClean chan bool //Signal should be cleaned up + stop chan bool +} + +func (j *shardedJanitor) Run(sc *shardedCache) { + j.stop = make(chan bool) + tick := time.Tick(j.Interval) + for { + select { + case <-tick: + sc.DeleteExpired() + case <-j.shoudClean: //If received should Clean signal + + sc.DeleteExpired() + case <-j.stop: + return + } + } +} + +func stopShardedJanitor(sc *unexportedShardedCache) { + sc.janitor.stop <- true +} + +func runShardedJanitor(sc *shardedCache, ci time.Duration) { + j := &shardedJanitor{ + Interval: ci, + shoudClean: make(chan bool), + } + sc.janitor = j + go j.Run(sc) +} + +func newShardedCache(n int, de time.Duration, maxItemsCount int) *shardedCache { + max := big.NewInt(0).SetUint64(uint64(math.MaxUint32)) + rnd, err := rand.Int(rand.Reader, max) + var seed uint32 + if err != nil { + os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. Continuing with an insecure seed.\n")) + seed = insecurerand.Uint32() + } else { + seed = uint32(rnd.Uint64()) + } + sc := &shardedCache{ + seed: seed, + m: uint32(n), + cs: make([]*cache, n), + } + for i := 0; i < n; i++ { + c := &cache{ + defaultExpiration: de, + items: map[string]Item{}, + maxItemsCount: maxItemsCount, + lastCleanTime: time.Now(), + } + sc.cs[i] = c + } + return sc +} + +func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int, maxItemsCount int) *unexportedShardedCache { + if defaultExpiration == 0 { + defaultExpiration = -1 + } + sc := newShardedCache(shards, defaultExpiration, maxItemsCount) + SC := &unexportedShardedCache{sc} + if cleanupInterval > 0 { + runShardedJanitor(sc, cleanupInterval) + runtime.SetFinalizer(SC, stopShardedJanitor) + } + return SC +} diff --git a/memorycacher/sharded_test.go b/memorycacher/sharded_test.go new file mode 100644 index 0000000..19fe6cc --- /dev/null +++ b/memorycacher/sharded_test.go @@ -0,0 +1,105 @@ +/* + * @Author: patrickmn,gitsrc + * @Date: 2020-07-09 13:17:30 + * @LastEditors: gitsrc + * @LastEditTime: 2020-07-10 10:06:42 + * @FilePath: /ServiceCar/utils/memorycacher/sharded_test.go + */ +/* +Copyright 2022-present The ZTDBP Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package memorycacher + +import ( + "strconv" + "sync" + "testing" + "time" +) + +// func TestDjb33(t *testing.T) { +// } + +var shardedKeys = []string{ + "f", + "fo", + "foo", + "barf", + "barfo", + "foobar", + "bazbarf", + "bazbarfo", + "bazbarfoo", + "foobarbazq", + "foobarbazqu", + "foobarbazquu", + "foobarbazquux", +} + +func TestShardedCache(t *testing.T) { + tc := unexportedNewSharded(DefaultExpiration, time.Minute*2, 13, 100) + for _, v := range shardedKeys { + tc.Set(v, "value", DefaultExpiration) + } +} + +func BenchmarkShardedCacheGetExpiring(b *testing.B) { + benchmarkShardedCacheGet(b, 5*time.Minute) +} + +func BenchmarkShardedCacheGetNotExpiring(b *testing.B) { + benchmarkShardedCacheGet(b, NoExpiration) +} + +func benchmarkShardedCacheGet(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := unexportedNewSharded(exp, 0, 10, 0) + tc.Set("foobarba", "zquux", DefaultExpiration) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Get("foobarba") + } +} + +func BenchmarkShardedCacheGetManyConcurrentExpiring(b *testing.B) { + benchmarkShardedCacheGetManyConcurrent(b, 5*time.Minute) +} + +func BenchmarkShardedCacheGetManyConcurrentNotExpiring(b *testing.B) { + benchmarkShardedCacheGetManyConcurrent(b, NoExpiration) +} + +func benchmarkShardedCacheGetManyConcurrent(b *testing.B, exp time.Duration) { + b.StopTimer() + n := 10000 + tsc := unexportedNewSharded(exp, 0, 20, 100000) + keys := make([]string, n) + for i := 0; i < n; i++ { + k := "foo" + strconv.Itoa(i) + keys[i] = k + tsc.Set(k, "bar", DefaultExpiration) + } + each := b.N / n + wg := new(sync.WaitGroup) + wg.Add(n) + for _, v := range keys { + go func(k string) { + for j := 0; j < each; j++ { + tsc.Get(k) + } + wg.Done() + }(v) + } + b.StartTimer() + wg.Wait() +}