diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1e9c4bf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+.idea
+cert
+.vscode
+bin
+.DS_Store
+.history
+vendor
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..b7f21fc
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/zta-tools.iml" filepath="$PROJECT_DIR$/.idea/zta-tools.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/zta-tools.iml b/.idea/zta-tools.iml
new file mode 100644
index 0000000..5e764c4
--- /dev/null
+++ b/.idea/zta-tools.iml
@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="WEB_MODULE" version="4">
+  <component name="Go" enabled="true" />
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceRoots" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/attrmgr/attrmgr.go b/attrmgr/attrmgr.go
new file mode 100644
index 0000000..7c692ca
--- /dev/null
+++ b/attrmgr/attrmgr.go
@@ -0,0 +1,231 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package attrmgr
+
+import (
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/asn1"
+ "encoding/json"
+ "fmt"
+
+ "github.com/pkg/errors"
+)
+
+var (
+ // AttrOID is the ASN.1 object identifier for an attribute extension in an
+ // X509 certificate
+ AttrOID = asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6, 7, 8, 1}
+ // AttrOIDString is the string version of AttrOID
+ AttrOIDString = "1.2.3.4.5.6.7.8.1"
+)
+
+// Attribute is a name/value pair
+type Attribute interface {
+ // GetName returns the name of the attribute
+ GetName() string
+ // GetValue returns the value of the attribute
+ GetValue() interface{}
+}
+
+// AttributeRequest is a request for an attribute
+type AttributeRequest interface {
+ // GetName returns the name of an attribute
+ GetName() string
+ // IsRequired returns true if the attribute is required
+ IsRequired() bool
+}
+
+// New constructs an attribute manager
+func New() *Mgr { return &Mgr{} }
+
+// Mgr is the attribute manager and is the main object for this package
+type Mgr struct{}
+
+// ProcessAttributeRequestsForCert adds attributes to an X509 certificate, given
+// attribute requests and attributes.
+func (mgr *Mgr) ProcessAttributeRequestsForCert(requests []AttributeRequest, attributes []Attribute, cert *x509.Certificate) error {
+ attrs, err := mgr.ProcessAttributeRequests(requests, attributes)
+ if err != nil {
+ return err
+ }
+ return mgr.AddAttributesToCert(attrs, cert)
+}
+
+// ProcessAttributeRequests takes an array of attribute requests and an identity's attributes
+// and returns an Attributes object containing the requested attributes.
+func (mgr *Mgr) ProcessAttributeRequests(requests []AttributeRequest, attributes []Attribute) (*Attributes, error) {
+ attrsMap := map[string]interface{}{}
+ attrs := &Attributes{Attrs: attrsMap}
+ missingRequiredAttrs := []string{}
+ // For each of the attribute requests
+ for _, req := range requests {
+ // Get the attribute
+ name := req.GetName()
+ attr := getAttrByName(name, attributes)
+ if attr == nil {
+ if req.IsRequired() {
+ // Didn't find attribute and it was required; return error below
+ missingRequiredAttrs = append(missingRequiredAttrs, name)
+ }
+ // Skip attribute requests which aren't required
+ continue
+ }
+ attrsMap[name] = attr.GetValue()
+ }
+ if len(missingRequiredAttrs) > 0 {
+ return nil, errors.Errorf("The following required attributes are missing: %+v",
+ missingRequiredAttrs)
+ }
+ return attrs, nil
+}
+
+// ToPkixExtension marshals the attributes to JSON and wraps them in a pkix.Extension under AttrOID.
+func (mgr *Mgr) ToPkixExtension(attrs *Attributes) (pkix.Extension, error) {
+ buf, err := json.Marshal(attrs)
+ if err != nil {
+ return pkix.Extension{}, errors.Wrap(err, "Failed to marshal attributes")
+ }
+ ext := pkix.Extension{
+ Id: AttrOID,
+ Critical: false,
+ Value: buf,
+ }
+ return ext, nil
+}
+
+// AddAttributesToCertRequest adds public attribute info to an X509 certificate request.
+func (mgr *Mgr) AddAttributesToCertRequest(attrs *Attributes, cert *x509.CertificateRequest) error {
+ buf, err := json.Marshal(attrs)
+ if err != nil {
+ return errors.Wrap(err, "Failed to marshal attributes")
+ }
+ ext := pkix.Extension{
+ Id: AttrOID,
+ Critical: false,
+ Value: buf,
+ }
+ cert.Extensions = append(cert.Extensions, ext)
+ return nil
+}
+
+// AddAttributesToCert adds public attribute info to an X509 certificate.
+func (mgr *Mgr) AddAttributesToCert(attrs *Attributes, cert *x509.Certificate) error {
+ buf, err := json.Marshal(attrs)
+ if err != nil {
+ return errors.Wrap(err, "Failed to marshal attributes")
+ }
+ ext := pkix.Extension{
+ Id: AttrOID,
+ Critical: false,
+ Value: buf,
+ }
+ cert.Extensions = append(cert.Extensions, ext)
+ return nil
+}
+
+// GetAttributesFromCert gets the attributes from a certificate.
+func (mgr *Mgr) GetAttributesFromCert(cert *x509.Certificate) (*Attributes, error) {
+ // Get certificate attributes from the certificate if it exists
+ buf, err := getAttributesFromCert(cert)
+ if err != nil {
+ return nil, err
+ }
+ // Unmarshal into attributes object
+ attrs := &Attributes{}
+ if buf != nil {
+ err := json.Unmarshal(buf, attrs)
+ if err != nil {
+ return nil, errors.Wrap(err, "Failed to unmarshal attributes from certificate")
+ }
+ }
+ return attrs, nil
+}
+
+// Attributes contains attribute names and values
+type Attributes struct {
+ Attrs map[string]interface{} `json:"attrs"`
+}
+
+// Names returns the names of the attributes
+func (a *Attributes) Names() []string {
+ i := 0
+ names := make([]string, len(a.Attrs))
+ for name := range a.Attrs {
+ names[i] = name
+ i++
+ }
+ return names
+}
+
+// Contains returns true if the named attribute is found
+func (a *Attributes) Contains(name string) bool {
+ _, ok := a.Attrs[name]
+ return ok
+}
+
+// Value returns an attribute's value
+func (a *Attributes) Value(name string) (interface{}, bool, error) {
+ attr, ok := a.Attrs[name]
+ return attr, ok, nil
+}
+
+// True returns nil if the value of attribute 'name' is true;
+// otherwise, an appropriate error is returned.
+func (a *Attributes) True(name string) error {
+ val, ok, err := a.Value(name)
+ if err != nil {
+ return err
+ }
+ if !ok {
+ return fmt.Errorf("Attribute '%s' was not found", name)
+ }
+ if val != "true" {
+ return fmt.Errorf("Attribute '%s' is not true", name)
+ }
+ return nil
+}
+
+// Get the attribute info from a certificate extension, or return nil if not found
+func getAttributesFromCert(cert *x509.Certificate) ([]byte, error) {
+ for _, ext := range cert.Extensions {
+ if isAttrOID(ext.Id) {
+ return ext.Value, nil
+ }
+ }
+ return nil, nil
+}
+
+// Is the object ID equal to the attribute info object ID?
+func isAttrOID(oid asn1.ObjectIdentifier) bool {
+ if len(oid) != len(AttrOID) {
+ return false
+ }
+ for idx, val := range oid {
+ if val != AttrOID[idx] {
+ return false
+ }
+ }
+ return true
+}
+
+// Get an attribute from 'attrs' by its name, or nil if not found
+func getAttrByName(name string, attrs []Attribute) Attribute {
+ for _, attr := range attrs {
+ if attr.GetName() == name {
+ return attr
+ }
+ }
+ return nil
+}
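
For reviewers, a minimal round-trip sketch of the attrmgr API added above. The `simpleAttr` type and the attribute values are illustrative, not part of the package, and the import path is assumed from the module layout in go.mod:

```go
package main

import (
	"crypto/x509"
	"fmt"

	"github.com/ztalab/zta-tools/attrmgr"
)

// simpleAttr satisfies both Attribute and AttributeRequest.
type simpleAttr struct {
	name  string
	value interface{}
}

func (a simpleAttr) GetName() string       { return a.name }
func (a simpleAttr) GetValue() interface{} { return a.value }
func (a simpleAttr) IsRequired() bool      { return true }

func main() {
	mgr := attrmgr.New()
	cert := &x509.Certificate{}

	attrs := []attrmgr.Attribute{simpleAttr{name: "role", value: "admin"}}
	reqs := []attrmgr.AttributeRequest{simpleAttr{name: "role"}}

	// Marshals the requested attributes to JSON and appends them to
	// cert.Extensions under AttrOID (1.2.3.4.5.6.7.8.1).
	if err := mgr.ProcessAttributeRequestsForCert(reqs, attrs, cert); err != nil {
		fmt.Println("error:", err)
		return
	}

	// Round-trip: read the attributes back out of the certificate.
	got, err := mgr.GetAttributesFromCert(cert)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(got.Names(), got.Contains("role")) // [role] true
}
```
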
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..9e5baef
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,24 @@
+module github.com/ztalab/zta-tools
+
+go 1.19
+
+require (
+ github.com/LyricTian/queue v1.3.0
+ github.com/go-redis/redis v6.15.9+incompatible
+ github.com/pkg/errors v0.9.1
+ github.com/sirupsen/logrus v1.9.0
+ github.com/ztdbp/ZACA v0.0.0-20230130085917-758df0add0c1
+ github.com/ztdbp/ZASentinel v0.0.0-20230117034106-d7354bb47cc5
+)
+
+require (
+ github.com/garyburd/redigo v1.6.3 // indirect
+ github.com/json-iterator/go v1.1.12 // indirect
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+ github.com/modern-go/reflect2 v1.0.2 // indirect
+ github.com/nxadm/tail v1.4.8 // indirect
+ go.uber.org/atomic v1.9.0 // indirect
+ go.uber.org/multierr v1.8.0 // indirect
+ go.uber.org/zap v1.24.0 // indirect
+ golang.org/x/sys v0.4.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..4b4e578
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,59 @@
+github.com/LyricTian/queue v1.3.0 h1:1xEFZlteW6iu5Qbrz7ZsiSKMKaxY1bQHsbx0jrB1pDA=
+github.com/LyricTian/queue v1.3.0/go.mod h1:pbkoplz/zRToCay3pRjz75P8fQAgvkRKJdEzVUQYhXY=
+github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
+github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI=
+github.com/garyburd/redigo v1.6.3 h1:HCeeRluvAgMusMomi1+6Y5dmFOdYV/JzoRrrbFlkGIc=
+github.com/garyburd/redigo v1.6.3/go.mod h1:rTb6epsqigu3kYKBnaF028A7Tf/Aw5s0cqA47doKKqw=
+github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg=
+github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
+github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
+github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
+github.com/onsi/gomega v1.19.0 h1:4ieX6qQjPP/BfC3mpsAtIGGlxTWPeA3Inl/7DtXw1tw=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
+github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
+github.com/ztdbp/ZACA v0.0.0-20230130085917-758df0add0c1 h1:++PKKjmPGMivF1f6wTyZmu4YkBfVhxQW3I1uH9AKrzo=
+github.com/ztdbp/ZACA v0.0.0-20230130085917-758df0add0c1/go.mod h1:DM+b+eGl8VRwit889SUB6ylBgfsrI9r7EUn3axB0K5I=
+github.com/ztdbp/ZASentinel v0.0.0-20230117034106-d7354bb47cc5 h1:NMoBEvuJ2/edm+GAg0dJyPhdvEy3qEjX2+s4YssyXF4=
+github.com/ztdbp/ZASentinel v0.0.0-20230117034106-d7354bb47cc5/go.mod h1:HLhrcqY4QjHdVR/VN+3CYPOH6cnBFjq2lOj0SLc3/Hg=
+go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
+go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
+go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8=
+go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
+go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
+go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
+golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw=
+golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
+golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/influxdb/client.go b/influxdb/client.go
new file mode 100644
index 0000000..b1c2f92
--- /dev/null
+++ b/influxdb/client.go
@@ -0,0 +1,64 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package influxdb
+
+import (
+	_ "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client" // blank import retained to work around a go mod resolution bug
+ client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
+ "github.com/ztdbp/ZACA/pkg/logger"
+)
+
+// UDPClient is a UDP client for writing to InfluxDB.
+type UDPClient struct {
+ Conf *Config
+ BatchPointsConfig client.BatchPointsConfig
+ client client.Client
+}
+
+func (p *UDPClient) newUDPV1Client() (*UDPClient, error) {
+	udpClient, err := client.NewUDPClient(client.UDPConfig{
+		Addr: p.Conf.UDPAddress,
+	})
+	if err != nil {
+		logger.Errorf("InfluxDBUDPClient err: %v", err)
+		return nil, err
+	}
+	p.client = udpClient
+	return p, nil
+}
+
+// FluxDBUDPWrite writes a batch of points to InfluxDB over UDP.
+func (p *UDPClient) FluxDBUDPWrite(bp client.BatchPoints) error {
+	c, err := p.newUDPV1Client()
+	if err != nil {
+		return err
+	}
+	return c.client.Write(bp)
+}
+
+// HTTPClient is an HTTP client for writing to InfluxDB.
+type HTTPClient struct {
+ Client client.Client
+ BatchPointsConfig client.BatchPointsConfig
+}
+
+// FluxDBHttpWrite writes a batch of points to InfluxDB over HTTP.
+func (p *HTTPClient) FluxDBHttpWrite(bp client.BatchPoints) (err error) {
+ return p.Client.Write(bp)
+}
+
+// FluxDBHttpClose closes the underlying HTTP client.
+func (p *HTTPClient) FluxDBHttpClose() (err error) {
+ return p.Client.Close()
+}
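
A usage sketch for the wrappers above. It assumes the vendored v2 package exposes the same constructors as upstream influxdb1-client/v2 (`NewHTTPClient`, `NewBatchPoints`, `NewPoint`); the address and database name are placeholders:

```go
package main

import (
	"log"
	"time"

	"github.com/ztalab/zta-tools/influxdb"
	client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
)

func main() {
	c, err := client.NewHTTPClient(client.HTTPConfig{Addr: "http://127.0.0.1:8086"})
	if err != nil {
		log.Fatal(err)
	}
	h := &influxdb.HTTPClient{
		Client:            c,
		BatchPointsConfig: client.BatchPointsConfig{Database: "metrics", Precision: "s"},
	}
	defer h.FluxDBHttpClose()

	bp, err := client.NewBatchPoints(h.BatchPointsConfig)
	if err != nil {
		log.Fatal(err)
	}
	pt, err := client.NewPoint("cpu_load",
		map[string]string{"host": "node-1"},
		map[string]interface{}{"value": 0.64},
		time.Now())
	if err != nil {
		log.Fatal(err)
	}
	bp.AddPoint(pt)

	if err := h.FluxDBHttpWrite(bp); err != nil {
		log.Fatal(err)
	}
}
```
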
diff --git a/influxdb/client/influxdb.go b/influxdb/client/influxdb.go
new file mode 100644
index 0000000..d605624
--- /dev/null
+++ b/influxdb/client/influxdb.go
@@ -0,0 +1,881 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package client // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client"
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/url"
+ "path"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models"
+)
+
+const (
+ // DefaultHost is the default host used to connect to an InfluxDB instance
+ DefaultHost = "localhost"
+
+ // DefaultPort is the default port used to connect to an InfluxDB instance
+ DefaultPort = 8086
+
+ // DefaultTimeout is the default connection timeout used to connect to an InfluxDB instance
+ DefaultTimeout = 0
+)
+
+// Query is used to send a command to the server. Both Command and Database are required.
+type Query struct {
+ Command string
+ Database string
+
+ // RetentionPolicy tells the server which retention policy to use by default.
+ // This option is only effective when querying a server of version 1.6.0 or later.
+ RetentionPolicy string
+
+ // Chunked tells the server to send back chunked responses. This places
+ // less load on the server by sending back chunks of the response rather
+ // than waiting for the entire response all at once.
+ Chunked bool
+
+ // ChunkSize sets the maximum number of rows that will be returned per
+ // chunk. Chunks are either divided based on their series or if they hit
+ // the chunk size limit.
+ //
+ // Chunked must be set to true for this option to be used.
+ ChunkSize int
+
+ // NodeID sets the data node to use for the query results. This option only
+ // has any effect in the enterprise version of the software where there can be
+ // more than one data node and is primarily useful for analyzing differences in
+ // data. The default behavior is to automatically select the appropriate data
+ // nodes to retrieve all of the data. On a database where the number of data nodes
+ // is greater than the replication factor, it is expected that setting this option
+ // will only retrieve partial data.
+ NodeID int
+}
+
+// ParseConnectionString will parse a string to create a valid connection URL
+func ParseConnectionString(path string, ssl bool) (url.URL, error) {
+ var host string
+ var port int
+
+ h, p, err := net.SplitHostPort(path)
+ if err != nil {
+ if path == "" {
+ host = DefaultHost
+ } else {
+ host = path
+ }
+ // If they didn't specify a port, always use the default port
+ port = DefaultPort
+ } else {
+ host = h
+ port, err = strconv.Atoi(p)
+ if err != nil {
+			return url.URL{}, fmt.Errorf("invalid port number %q: %s", path, err)
+ }
+ }
+
+ u := url.URL{
+ Scheme: "http",
+ Host: host,
+ }
+ if ssl {
+ u.Scheme = "https"
+ if port != 443 {
+ u.Host = net.JoinHostPort(host, strconv.Itoa(port))
+ }
+ } else if port != 80 {
+ u.Host = net.JoinHostPort(host, strconv.Itoa(port))
+ }
+
+ return u, nil
+}
+
+// Config is used to specify what server to connect to.
+// URL: the URL of the server to connect to.
+// Username/Password are optional. They will be passed via basic auth if provided.
+// UserAgent: if not provided, defaults to "InfluxDBClient".
+// Timeout: if not provided, defaults to 0 (no timeout).
+type Config struct {
+ URL url.URL
+ UnixSocket string
+ Username string
+ Password string
+ UserAgent string
+ Timeout time.Duration
+ Precision string
+ WriteConsistency string
+ UnsafeSsl bool
+ Proxy func(req *http.Request) (*url.URL, error)
+ TLS *tls.Config
+}
+
+// NewConfig will create a config to be used in connecting to the client
+func NewConfig() Config {
+ return Config{
+ Timeout: DefaultTimeout,
+ }
+}
+
+// Client is used to make calls to the server.
+type Client struct {
+ url url.URL
+ unixSocket string
+ username string
+ password string
+ httpClient *http.Client
+ userAgent string
+ precision string
+}
+
+const (
+	// ConsistencyOne requires at least one data node to acknowledge a write.
+ ConsistencyOne = "one"
+
+ // ConsistencyAll requires all data nodes to acknowledge a write.
+ ConsistencyAll = "all"
+
+ // ConsistencyQuorum requires a quorum of data nodes to acknowledge a write.
+ ConsistencyQuorum = "quorum"
+
+	// ConsistencyAny allows for hinted handoff; potentially no write has happened yet.
+ ConsistencyAny = "any"
+)
+
+// NewClient will instantiate and return a connected client to issue commands to the server.
+func NewClient(c Config) (*Client, error) {
+ tlsConfig := new(tls.Config)
+ if c.TLS != nil {
+ tlsConfig = c.TLS.Clone()
+ }
+ tlsConfig.InsecureSkipVerify = c.UnsafeSsl
+
+ tr := &http.Transport{
+ Proxy: c.Proxy,
+ TLSClientConfig: tlsConfig,
+ }
+
+ if c.UnixSocket != "" {
+ // No need for compression in local communications.
+ tr.DisableCompression = true
+
+ tr.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", c.UnixSocket)
+ }
+ }
+
+ client := Client{
+ url: c.URL,
+ unixSocket: c.UnixSocket,
+ username: c.Username,
+ password: c.Password,
+ httpClient: &http.Client{Timeout: c.Timeout, Transport: tr},
+ userAgent: c.UserAgent,
+ precision: c.Precision,
+ }
+ if client.userAgent == "" {
+ client.userAgent = "InfluxDBClient"
+ }
+ return &client, nil
+}
+
+// SetAuth will update the username and password
+func (c *Client) SetAuth(u, p string) {
+ c.username = u
+ c.password = p
+}
+
+// SetPrecision will update the precision
+func (c *Client) SetPrecision(precision string) {
+ c.precision = precision
+}
+
+// Query sends a command to the server and returns the Response
+func (c *Client) Query(q Query) (*Response, error) {
+ return c.QueryContext(context.Background(), q)
+}
+
+// QueryContext sends a command to the server and returns the Response
+// It uses a context that can be cancelled by the command line client
+func (c *Client) QueryContext(ctx context.Context, q Query) (*Response, error) {
+ u := c.url
+ u.Path = path.Join(u.Path, "query")
+
+ values := u.Query()
+ values.Set("q", q.Command)
+ values.Set("db", q.Database)
+ if q.RetentionPolicy != "" {
+ values.Set("rp", q.RetentionPolicy)
+ }
+ if q.Chunked {
+ values.Set("chunked", "true")
+ if q.ChunkSize > 0 {
+ values.Set("chunk_size", strconv.Itoa(q.ChunkSize))
+ }
+ }
+ if q.NodeID > 0 {
+ values.Set("node_id", strconv.Itoa(q.NodeID))
+ }
+ if c.precision != "" {
+ values.Set("epoch", c.precision)
+ }
+ u.RawQuery = values.Encode()
+
+ req, err := http.NewRequest("POST", u.String(), nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("User-Agent", c.userAgent)
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+
+ req = req.WithContext(ctx)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var response Response
+ if q.Chunked {
+ cr := NewChunkedResponse(resp.Body)
+ for {
+ r, err := cr.NextResponse()
+ if err != nil {
+ // If we got an error while decoding the response, send that back.
+ return nil, err
+ }
+
+ if r == nil {
+ break
+ }
+
+ response.Results = append(response.Results, r.Results...)
+ if r.Err != nil {
+ response.Err = r.Err
+ break
+ }
+ }
+ } else {
+ dec := json.NewDecoder(resp.Body)
+ dec.UseNumber()
+ if err := dec.Decode(&response); err != nil {
+ // Ignore EOF errors if we got an invalid status code.
+ if !(err == io.EOF && resp.StatusCode != http.StatusOK) {
+ return nil, err
+ }
+ }
+ }
+
+ // If we don't have an error in our json response, and didn't get StatusOK,
+ // then send back an error.
+ if resp.StatusCode != http.StatusOK && response.Error() == nil {
+ return &response, fmt.Errorf("received status code %d from server", resp.StatusCode)
+ }
+ return &response, nil
+}
+
+// Write takes BatchPoints and allows for writing of multiple points with defaults
+// If successful, error is nil and Response is nil
+// If an error occurs, Response may contain additional information if populated.
+func (c *Client) Write(bp BatchPoints) (*Response, error) {
+ u := c.url
+ u.Path = path.Join(u.Path, "write")
+
+ var b bytes.Buffer
+ for _, p := range bp.Points {
+ err := checkPointTypes(p)
+ if err != nil {
+ return nil, err
+ }
+ if p.Raw != "" {
+ if _, err := b.WriteString(p.Raw); err != nil {
+ return nil, err
+ }
+ } else {
+ for k, v := range bp.Tags {
+ if p.Tags == nil {
+ p.Tags = make(map[string]string, len(bp.Tags))
+ }
+ p.Tags[k] = v
+ }
+
+ if _, err := b.WriteString(p.MarshalString()); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := b.WriteByte('\n'); err != nil {
+ return nil, err
+ }
+ }
+
+ req, err := http.NewRequest("POST", u.String(), &b)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "")
+ req.Header.Set("User-Agent", c.userAgent)
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+
+ precision := bp.Precision
+ if precision == "" {
+ precision = c.precision
+ }
+
+ params := req.URL.Query()
+ params.Set("db", bp.Database)
+ params.Set("rp", bp.RetentionPolicy)
+ params.Set("precision", precision)
+ params.Set("consistency", bp.WriteConsistency)
+ req.URL.RawQuery = params.Encode()
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var response Response
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
+		err := errors.New(string(body))
+ response.Err = err
+ return &response, err
+ }
+
+ return nil, nil
+}
+
+// WriteLineProtocol takes a string with line returns to delimit each write
+// If successful, error is nil and Response is nil
+// If an error occurs, Response may contain additional information if populated.
+func (c *Client) WriteLineProtocol(data, database, retentionPolicy, precision, writeConsistency string) (*Response, error) {
+ u := c.url
+ u.Path = path.Join(u.Path, "write")
+
+ r := strings.NewReader(data)
+
+ req, err := http.NewRequest("POST", u.String(), r)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "")
+ req.Header.Set("User-Agent", c.userAgent)
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+ params := req.URL.Query()
+ params.Set("db", database)
+ params.Set("rp", retentionPolicy)
+ params.Set("precision", precision)
+ params.Set("consistency", writeConsistency)
+ req.URL.RawQuery = params.Encode()
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var response Response
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
+		err := errors.New(string(body))
+ response.Err = err
+ return &response, err
+ }
+
+ return nil, nil
+}
+
+// Ping will check to see if the server is up
+// Ping returns how long the request took, the version of the server it connected to, and an error if one occurred.
+func (c *Client) Ping() (time.Duration, string, error) {
+ now := time.Now()
+
+ u := c.url
+ u.Path = path.Join(u.Path, "ping")
+
+ req, err := http.NewRequest("GET", u.String(), nil)
+ if err != nil {
+ return 0, "", err
+ }
+ req.Header.Set("User-Agent", c.userAgent)
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return 0, "", err
+ }
+ defer resp.Body.Close()
+
+ version := resp.Header.Get("X-Influxdb-Version")
+ return time.Since(now), version, nil
+}
+
+// Structs
+
+// Message represents a user message.
+type Message struct {
+ Level string `json:"level,omitempty"`
+ Text string `json:"text,omitempty"`
+}
+
+// Result represents a resultset returned from a single statement.
+type Result struct {
+ Series []models.Row
+ Messages []*Message
+ Err error
+}
+
+// MarshalJSON encodes the result into JSON.
+func (r *Result) MarshalJSON() ([]byte, error) {
+ // Define a struct that outputs "error" as a string.
+ var o struct {
+ Series []models.Row `json:"series,omitempty"`
+ Messages []*Message `json:"messages,omitempty"`
+ Err string `json:"error,omitempty"`
+ }
+
+ // Copy fields to output struct.
+ o.Series = r.Series
+ o.Messages = r.Messages
+ if r.Err != nil {
+ o.Err = r.Err.Error()
+ }
+
+ return json.Marshal(&o)
+}
+
+// UnmarshalJSON decodes the data into the Result struct
+func (r *Result) UnmarshalJSON(b []byte) error {
+ var o struct {
+ Series []models.Row `json:"series,omitempty"`
+ Messages []*Message `json:"messages,omitempty"`
+ Err string `json:"error,omitempty"`
+ }
+
+ dec := json.NewDecoder(bytes.NewBuffer(b))
+ dec.UseNumber()
+ err := dec.Decode(&o)
+ if err != nil {
+ return err
+ }
+ r.Series = o.Series
+ r.Messages = o.Messages
+ if o.Err != "" {
+ r.Err = errors.New(o.Err)
+ }
+ return nil
+}
+
+// Response represents a list of statement results.
+type Response struct {
+ Results []Result
+ Err error
+}
+
+// MarshalJSON encodes the response into JSON.
+func (r *Response) MarshalJSON() ([]byte, error) {
+ // Define a struct that outputs "error" as a string.
+ var o struct {
+ Results []Result `json:"results,omitempty"`
+ Err string `json:"error,omitempty"`
+ }
+
+ // Copy fields to output struct.
+ o.Results = r.Results
+ if r.Err != nil {
+ o.Err = r.Err.Error()
+ }
+
+ return json.Marshal(&o)
+}
+
+// UnmarshalJSON decodes the data into the Response struct
+func (r *Response) UnmarshalJSON(b []byte) error {
+ var o struct {
+ Results []Result `json:"results,omitempty"`
+ Err string `json:"error,omitempty"`
+ }
+
+ dec := json.NewDecoder(bytes.NewBuffer(b))
+ dec.UseNumber()
+ err := dec.Decode(&o)
+ if err != nil {
+ return err
+ }
+ r.Results = o.Results
+ if o.Err != "" {
+ r.Err = errors.New(o.Err)
+ }
+ return nil
+}
+
+// Error returns the first error from any statement.
+// Returns nil if no errors occurred on any statements.
+func (r *Response) Error() error {
+ if r.Err != nil {
+ return r.Err
+ }
+ for _, result := range r.Results {
+ if result.Err != nil {
+ return result.Err
+ }
+ }
+ return nil
+}
+
+// duplexReader reads responses and writes them to another writer while
+// satisfying the reader interface.
+type duplexReader struct {
+ r io.Reader
+ w io.Writer
+}
+
+func (r *duplexReader) Read(p []byte) (n int, err error) {
+ n, err = r.r.Read(p)
+ if err == nil {
+ r.w.Write(p[:n])
+ }
+ return n, err
+}
+
+// ChunkedResponse represents a response from the server that
+// uses chunking to stream the output.
+type ChunkedResponse struct {
+ dec *json.Decoder
+ duplex *duplexReader
+ buf bytes.Buffer
+}
+
+// NewChunkedResponse reads a stream and produces responses from the stream.
+func NewChunkedResponse(r io.Reader) *ChunkedResponse {
+ resp := &ChunkedResponse{}
+ resp.duplex = &duplexReader{r: r, w: &resp.buf}
+ resp.dec = json.NewDecoder(resp.duplex)
+ resp.dec.UseNumber()
+ return resp
+}
+
+// NextResponse reads the next line of the stream and returns a response.
+func (r *ChunkedResponse) NextResponse() (*Response, error) {
+ var response Response
+ if err := r.dec.Decode(&response); err != nil {
+ if err == io.EOF {
+ return nil, nil
+ }
+ // A decoding error happened. This probably means the server crashed
+ // and sent a last-ditch error message to us. Ensure we have read the
+ // entirety of the connection to get any remaining error text.
+ io.Copy(ioutil.Discard, r.duplex)
+ return nil, errors.New(strings.TrimSpace(r.buf.String()))
+ }
+ r.buf.Reset()
+ return &response, nil
+}
+
+// Point defines the fields that will be written to the database
+// Measurement, Time, and Fields are required
+// Precision can be specified if the time is in epoch format (integer).
+// Valid values for Precision are n, u, ms, s, m, and h
+type Point struct {
+ Measurement string
+ Tags map[string]string
+ Time time.Time
+ Fields map[string]interface{}
+ Precision string
+ Raw string
+}
+
+// MarshalJSON will format the time in RFC3339Nano.
+// Precision is ignored, as it is only used for writing, not reading;
+// in other words, times are always sent back in nanosecond precision.
+func (p *Point) MarshalJSON() ([]byte, error) {
+ point := struct {
+ Measurement string `json:"measurement,omitempty"`
+ Tags map[string]string `json:"tags,omitempty"`
+ Time string `json:"time,omitempty"`
+ Fields map[string]interface{} `json:"fields,omitempty"`
+ Precision string `json:"precision,omitempty"`
+ }{
+ Measurement: p.Measurement,
+ Tags: p.Tags,
+ Fields: p.Fields,
+ Precision: p.Precision,
+ }
+ // Let it omit empty if it's really zero
+ if !p.Time.IsZero() {
+ point.Time = p.Time.UTC().Format(time.RFC3339Nano)
+ }
+ return json.Marshal(&point)
+}
+
+// MarshalString renders string representation of a Point with specified
+// precision. The default precision is nanoseconds.
+func (p *Point) MarshalString() string {
+ pt, err := models.NewPoint(p.Measurement, models.NewTags(p.Tags), p.Fields, p.Time)
+ if err != nil {
+ return "# ERROR: " + err.Error() + " " + p.Measurement
+ }
+ if p.Precision == "" || p.Precision == "ns" || p.Precision == "n" {
+ return pt.String()
+ }
+ return pt.PrecisionString(p.Precision)
+}
+
+// UnmarshalJSON decodes the data into the Point struct
+func (p *Point) UnmarshalJSON(b []byte) error {
+ var normal struct {
+ Measurement string `json:"measurement"`
+ Tags map[string]string `json:"tags"`
+ Time time.Time `json:"time"`
+ Precision string `json:"precision"`
+ Fields map[string]interface{} `json:"fields"`
+ }
+ var epoch struct {
+ Measurement string `json:"measurement"`
+ Tags map[string]string `json:"tags"`
+ Time *int64 `json:"time"`
+ Precision string `json:"precision"`
+ Fields map[string]interface{} `json:"fields"`
+ }
+
+ if err := func() error {
+ var err error
+ dec := json.NewDecoder(bytes.NewBuffer(b))
+ dec.UseNumber()
+ if err = dec.Decode(&epoch); err != nil {
+ return err
+ }
+ // Convert from epoch to time.Time, but only if Time
+ // was actually set.
+ var ts time.Time
+ if epoch.Time != nil {
+ ts, err = EpochToTime(*epoch.Time, epoch.Precision)
+ if err != nil {
+ return err
+ }
+ }
+ p.Measurement = epoch.Measurement
+ p.Tags = epoch.Tags
+ p.Time = ts
+ p.Precision = epoch.Precision
+ p.Fields = normalizeFields(epoch.Fields)
+ return nil
+ }(); err == nil {
+ return nil
+ }
+
+ dec := json.NewDecoder(bytes.NewBuffer(b))
+ dec.UseNumber()
+ if err := dec.Decode(&normal); err != nil {
+ return err
+ }
+ normal.Time = SetPrecision(normal.Time, normal.Precision)
+ p.Measurement = normal.Measurement
+ p.Tags = normal.Tags
+ p.Time = normal.Time
+ p.Precision = normal.Precision
+ p.Fields = normalizeFields(normal.Fields)
+
+ return nil
+}
+
+// Remove any notion of json.Number
+func normalizeFields(fields map[string]interface{}) map[string]interface{} {
+ newFields := map[string]interface{}{}
+
+ for k, v := range fields {
+ switch v := v.(type) {
+ case json.Number:
+ jv, e := v.Float64()
+ if e != nil {
+ panic(fmt.Sprintf("unable to convert json.Number to float64: %s", e))
+ }
+ newFields[k] = jv
+ default:
+ newFields[k] = v
+ }
+ }
+ return newFields
+}
+
+// BatchPoints is used to send batched data in a single write.
+// Database and Points are required
+// If no retention policy is specified, it will use the database's default retention policy.
+// If tags are specified, they will be merged with all points; a batch-level tag overwrites a point's tag of the same name.
+// If time is specified, it will be applied to any point with an empty time.
+// Precision can be specified if the time is in epoch format (integer).
+// Valid values for Precision are n, u, ms, s, m, and h
+type BatchPoints struct {
+ Points []Point `json:"points,omitempty"`
+ Database string `json:"database,omitempty"`
+ RetentionPolicy string `json:"retentionPolicy,omitempty"`
+ Tags map[string]string `json:"tags,omitempty"`
+ Time time.Time `json:"time,omitempty"`
+ Precision string `json:"precision,omitempty"`
+ WriteConsistency string `json:"-"`
+}
+
+// UnmarshalJSON decodes the data into the BatchPoints struct
+func (bp *BatchPoints) UnmarshalJSON(b []byte) error {
+ var normal struct {
+ Points []Point `json:"points"`
+ Database string `json:"database"`
+ RetentionPolicy string `json:"retentionPolicy"`
+ Tags map[string]string `json:"tags"`
+ Time time.Time `json:"time"`
+ Precision string `json:"precision"`
+ }
+ var epoch struct {
+ Points []Point `json:"points"`
+ Database string `json:"database"`
+ RetentionPolicy string `json:"retentionPolicy"`
+ Tags map[string]string `json:"tags"`
+ Time *int64 `json:"time"`
+ Precision string `json:"precision"`
+ }
+
+ if err := func() error {
+ var err error
+ if err = json.Unmarshal(b, &epoch); err != nil {
+ return err
+ }
+ // Convert from epoch to time.Time
+ var ts time.Time
+ if epoch.Time != nil {
+ ts, err = EpochToTime(*epoch.Time, epoch.Precision)
+ if err != nil {
+ return err
+ }
+ }
+ bp.Points = epoch.Points
+ bp.Database = epoch.Database
+ bp.RetentionPolicy = epoch.RetentionPolicy
+ bp.Tags = epoch.Tags
+ bp.Time = ts
+ bp.Precision = epoch.Precision
+ return nil
+ }(); err == nil {
+ return nil
+ }
+
+ if err := json.Unmarshal(b, &normal); err != nil {
+ return err
+ }
+ normal.Time = SetPrecision(normal.Time, normal.Precision)
+ bp.Points = normal.Points
+ bp.Database = normal.Database
+ bp.RetentionPolicy = normal.RetentionPolicy
+ bp.Tags = normal.Tags
+ bp.Time = normal.Time
+ bp.Precision = normal.Precision
+
+ return nil
+}
+
+// utility functions
+
+// Addr provides the current url as a string of the server the client is connected to.
+func (c *Client) Addr() string {
+ if c.unixSocket != "" {
+ return c.unixSocket
+ }
+ return c.url.String()
+}
+
+// checkPointTypes ensures no unsupported types are submitted to influxdb, returning error if they are found.
+func checkPointTypes(p Point) error {
+ for _, v := range p.Fields {
+ switch v.(type) {
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool, string, nil:
+			// supported type: continue checking the remaining fields
+ default:
+ return fmt.Errorf("unsupported point type: %T", v)
+ }
+ }
+ return nil
+}
+
+// helper functions
+
+// EpochToTime takes a unix epoch time and uses precision to return back a time.Time
+func EpochToTime(epoch int64, precision string) (time.Time, error) {
+ if precision == "" {
+ precision = "s"
+ }
+ var t time.Time
+ switch precision {
+ case "h":
+ t = time.Unix(0, epoch*int64(time.Hour))
+ case "m":
+ t = time.Unix(0, epoch*int64(time.Minute))
+ case "s":
+ t = time.Unix(0, epoch*int64(time.Second))
+ case "ms":
+ t = time.Unix(0, epoch*int64(time.Millisecond))
+ case "u":
+ t = time.Unix(0, epoch*int64(time.Microsecond))
+ case "n":
+ t = time.Unix(0, epoch)
+ default:
+ return time.Time{}, fmt.Errorf("Unknown precision %q", precision)
+ }
+ return t, nil
+}
+
+// SetPrecision will round a time to the specified precision
+func SetPrecision(t time.Time, precision string) time.Time {
+ switch precision {
+ case "n":
+ case "u":
+ return t.Round(time.Microsecond)
+ case "ms":
+ return t.Round(time.Millisecond)
+ case "s":
+ return t.Round(time.Second)
+ case "m":
+ return t.Round(time.Minute)
+ case "h":
+ return t.Round(time.Hour)
+ }
+ return t
+}
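
A sketch of writing a point with the v1 client defined above; the import path is assumed from this repository's layout, and the address and database name are placeholders:

```go
package main

import (
	"log"
	"time"

	client "github.com/ztalab/zta-tools/influxdb/client"
)

func main() {
	u, err := client.ParseConnectionString("localhost:8086", false)
	if err != nil {
		log.Fatal(err)
	}
	conn, err := client.NewClient(client.Config{URL: u})
	if err != nil {
		log.Fatal(err)
	}

	bp := client.BatchPoints{
		Database: "metrics",
		Points: []client.Point{{
			Measurement: "cpu",
			Tags:        map[string]string{"host": "node-1"},
			Fields:      map[string]interface{}{"value": 0.64},
			Time:        time.Now(),
		}},
	}
	// On success both return values are nil; on failure the Response may
	// carry the error body returned by the server.
	if resp, err := conn.Write(bp); err != nil {
		log.Fatalf("write failed: %v (response: %+v)", err, resp)
	}
}
```
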
diff --git a/influxdb/client/models/inline_fnv.go b/influxdb/client/models/inline_fnv.go
new file mode 100644
index 0000000..4c89b4b
--- /dev/null
+++ b/influxdb/client/models/inline_fnv.go
@@ -0,0 +1,45 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models"
+
+// from stdlib hash/fnv/fnv.go
+const (
+ prime64 = 1099511628211
+ offset64 = 14695981039346656037
+)
+
+// InlineFNV64a is an alloc-free port of the standard library's fnv64a.
+// See https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function.
+type InlineFNV64a uint64
+
+// NewInlineFNV64a returns a new instance of InlineFNV64a.
+func NewInlineFNV64a() InlineFNV64a {
+ return offset64
+}
+
+// Write adds data to the running hash.
+func (s *InlineFNV64a) Write(data []byte) (int, error) {
+ hash := uint64(*s)
+ for _, c := range data {
+ hash ^= uint64(c)
+ hash *= prime64
+ }
+ *s = InlineFNV64a(hash)
+ return len(data), nil
+}
+
+// Sum64 returns the uint64 of the current resulting hash.
+func (s *InlineFNV64a) Sum64() uint64 {
+ return uint64(*s)
+}
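
A quick equivalence check of the alloc-free port above against the standard library's fnv64a; the import path is assumed from this repository's layout:

```go
package main

import (
	"fmt"
	"hash/fnv"

	"github.com/ztalab/zta-tools/influxdb/client/models"
)

func main() {
	key := []byte("cpu,host=node-1")

	h := models.NewInlineFNV64a()
	h.Write(key) // Write never returns a non-nil error

	std := fnv.New64a()
	std.Write(key)

	fmt.Println(h.Sum64() == std.Sum64()) // true: same FNV-1a digest
}
```
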
diff --git a/influxdb/client/models/inline_strconv_parse.go b/influxdb/client/models/inline_strconv_parse.go
new file mode 100644
index 0000000..9f763b0
--- /dev/null
+++ b/influxdb/client/models/inline_strconv_parse.go
@@ -0,0 +1,57 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models"
+
+import (
+ "reflect"
+ "strconv"
+ "unsafe"
+)
+
+// parseIntBytes is a zero-alloc wrapper around strconv.ParseInt.
+func parseIntBytes(b []byte, base int, bitSize int) (i int64, err error) {
+ s := unsafeBytesToString(b)
+ return strconv.ParseInt(s, base, bitSize)
+}
+
+// parseUintBytes is a zero-alloc wrapper around strconv.ParseUint.
+func parseUintBytes(b []byte, base int, bitSize int) (i uint64, err error) {
+ s := unsafeBytesToString(b)
+ return strconv.ParseUint(s, base, bitSize)
+}
+
+// parseFloatBytes is a zero-alloc wrapper around strconv.ParseFloat.
+func parseFloatBytes(b []byte, bitSize int) (float64, error) {
+ s := unsafeBytesToString(b)
+ return strconv.ParseFloat(s, bitSize)
+}
+
+// parseBoolBytes is a zero-alloc wrapper around strconv.ParseBool.
+func parseBoolBytes(b []byte) (bool, error) {
+ return strconv.ParseBool(unsafeBytesToString(b))
+}
+
+// unsafeBytesToString converts a []byte to a string without a heap allocation.
+//
+// It is unsafe, and is intended to prepare input to short-lived functions
+// that require strings.
+func unsafeBytesToString(in []byte) string {
+ src := *(*reflect.SliceHeader)(unsafe.Pointer(&in))
+ dst := reflect.StringHeader{
+ Data: src.Data,
+ Len: src.Len,
+ }
+ s := *(*string)(unsafe.Pointer(&dst))
+ return s
+}
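
These helpers are unexported, so as a reviewer's note: the trick relies on strconv not retaining its string argument, which makes aliasing the byte slice safe for the duration of the call. A standalone sketch of the same conversion (`bytesToString` here is illustrative, not part of the package):

```go
package main

import (
	"fmt"
	"reflect"
	"strconv"
	"unsafe"
)

// bytesToString aliases b as a string without copying; b must not be
// modified while the returned string is in use.
func bytesToString(b []byte) string {
	src := *(*reflect.SliceHeader)(unsafe.Pointer(&b))
	dst := reflect.StringHeader{Data: src.Data, Len: src.Len}
	return *(*string)(unsafe.Pointer(&dst))
}

func main() {
	n, err := strconv.ParseInt(bytesToString([]byte("-42")), 10, 64)
	fmt.Println(n, err) // -42 <nil>
}
```
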
diff --git a/influxdb/client/models/points.go b/influxdb/client/models/points.go
new file mode 100644
index 0000000..1694aa3
--- /dev/null
+++ b/influxdb/client/models/points.go
@@ -0,0 +1,2425 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models"
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+
+ "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/pkg/escape"
+)
+
+type escapeSet struct {
+ k [1]byte
+ esc [2]byte
+}
+
+var (
+ measurementEscapeCodes = [...]escapeSet{
+ {k: [1]byte{','}, esc: [2]byte{'\\', ','}},
+ {k: [1]byte{' '}, esc: [2]byte{'\\', ' '}},
+ }
+
+ tagEscapeCodes = [...]escapeSet{
+ {k: [1]byte{','}, esc: [2]byte{'\\', ','}},
+ {k: [1]byte{' '}, esc: [2]byte{'\\', ' '}},
+ {k: [1]byte{'='}, esc: [2]byte{'\\', '='}},
+ }
+
+ // ErrPointMustHaveAField is returned when operating on a point that does not have any fields.
+ ErrPointMustHaveAField = errors.New("point without fields is unsupported")
+
+ // ErrInvalidNumber is returned when a number is expected but not provided.
+ ErrInvalidNumber = errors.New("invalid number")
+
+ // ErrInvalidPoint is returned when a point cannot be parsed correctly.
+ ErrInvalidPoint = errors.New("point is invalid")
+)
+
+const (
+ // MaxKeyLength is the largest allowed size of the combined measurement and tag keys.
+ MaxKeyLength = 65535
+)
+
+// enableUint64Support will enable uint64 support if set to true.
+var enableUint64Support = false
+
+// EnableUintSupport manually enables uint support for the point parser.
+// This function will be removed in the future and only exists for unit tests during the
+// transition.
+func EnableUintSupport() {
+ enableUint64Support = true
+}
+
+// Point defines the values that will be written to the database.
+type Point interface {
+	// Name returns the measurement name for the point.
+ Name() []byte
+
+ // SetName updates the measurement name for the point.
+ SetName(string)
+
+ // Tags returns the tag set for the point.
+ Tags() Tags
+
+	// ForEachTag iterates over each tag invoking fn. If fn returns false, iteration stops.
+ ForEachTag(fn func(k, v []byte) bool)
+
+ // AddTag adds or replaces a tag value for a point.
+ AddTag(key, value string)
+
+ // SetTags replaces the tags for the point.
+ SetTags(tags Tags)
+
+ // HasTag returns true if the tag exists for the point.
+ HasTag(tag []byte) bool
+
+ // Fields returns the fields for the point.
+ Fields() (Fields, error)
+
+	// Time returns the timestamp for the point.
+ Time() time.Time
+
+ // SetTime updates the timestamp for the point.
+ SetTime(t time.Time)
+
+ // UnixNano returns the timestamp of the point as nanoseconds since Unix epoch.
+ UnixNano() int64
+
+ // HashID returns a non-cryptographic checksum of the point's key.
+ HashID() uint64
+
+ // Key returns the key (measurement joined with tags) of the point.
+ Key() []byte
+
+ // String returns a string representation of the point. If there is a
+ // timestamp associated with the point then it will be specified with the default
+ // precision of nanoseconds.
+ String() string
+
+ // MarshalBinary returns a binary representation of the point.
+ MarshalBinary() ([]byte, error)
+
+ // PrecisionString returns a string representation of the point. If there
+ // is a timestamp associated with the point then it will be specified in the
+ // given unit.
+ PrecisionString(precision string) string
+
+ // RoundedString returns a string representation of the point. If there
+ // is a timestamp associated with the point, then it will be rounded to the
+ // given duration.
+ RoundedString(d time.Duration) string
+
+ // Split will attempt to return multiple points with the same timestamp whose
+ // string representations are no longer than size. Points with a single field or
+ // a point without a timestamp may exceed the requested size.
+ Split(size int) []Point
+
+ // Round will round the timestamp of the point to the given duration.
+ Round(d time.Duration)
+
+ // StringSize returns the length of the string that would be returned by String().
+ StringSize() int
+
+ // AppendString appends the result of String() to the provided buffer and returns
+ // the result, potentially reducing string allocations.
+ AppendString(buf []byte) []byte
+
+	// FieldIterator returns a FieldIterator that can be used to traverse the
+ // fields of a point without constructing the in-memory map.
+ FieldIterator() FieldIterator
+}
+
+// FieldType represents the type of a field.
+type FieldType int
+
+const (
+ // Integer indicates the field's type is integer.
+ Integer FieldType = iota
+
+ // Float indicates the field's type is float.
+ Float
+
+ // Boolean indicates the field's type is boolean.
+ Boolean
+
+ // String indicates the field's type is string.
+ String
+
+ // Empty is used to indicate that there is no field.
+ Empty
+
+ // Unsigned indicates the field's type is an unsigned integer.
+ Unsigned
+)
+
+// FieldIterator provides a low-allocation interface to iterate through a point's fields.
+type FieldIterator interface {
+	// Next indicates whether there are any fields remaining.
+ Next() bool
+
+ // FieldKey returns the key of the current field.
+ FieldKey() []byte
+
+ // Type returns the FieldType of the current field.
+ Type() FieldType
+
+ // StringValue returns the string value of the current field.
+ StringValue() string
+
+ // IntegerValue returns the integer value of the current field.
+ IntegerValue() (int64, error)
+
+ // UnsignedValue returns the unsigned value of the current field.
+ UnsignedValue() (uint64, error)
+
+ // BooleanValue returns the boolean value of the current field.
+ BooleanValue() (bool, error)
+
+ // FloatValue returns the float value of the current field.
+ FloatValue() (float64, error)
+
+ // Reset resets the iterator to its initial state.
+ Reset()
+}
+
+// Points represents a sortable list of points by timestamp.
+type Points []Point
+
+// Len implements sort.Interface.
+func (a Points) Len() int { return len(a) }
+
+// Less implements sort.Interface.
+func (a Points) Less(i, j int) bool { return a[i].Time().Before(a[j].Time()) }
+
+// Swap implements sort.Interface.
+func (a Points) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+
+// point is the default implementation of Point.
+type point struct {
+ time time.Time
+
+ // text encoding of measurement and tags
+ // key must always be stored sorted by tags, if the original line was not sorted,
+ // we need to resort it
+ key []byte
+
+ // text encoding of field data
+ fields []byte
+
+ // text encoding of timestamp
+ ts []byte
+
+ // cached version of parsed fields from data
+ cachedFields map[string]interface{}
+
+ // cached version of parsed name from key
+ cachedName string
+
+ // cached version of parsed tags
+ cachedTags Tags
+
+ it fieldIterator
+}
+
+// type assertions
+var (
+ _ Point = (*point)(nil)
+ _ FieldIterator = (*point)(nil)
+)
+
+const (
+ // the number of characters for the largest possible int64 (9223372036854775807)
+ maxInt64Digits = 19
+
+ // the number of characters for the smallest possible int64 (-9223372036854775808)
+ minInt64Digits = 20
+
+ // the number of characters for the largest possible uint64 (18446744073709551615)
+ maxUint64Digits = 20
+
+ // the number of characters required for the largest float64 before a range check
+ // would occur during parsing
+ maxFloat64Digits = 25
+
+	// the number of characters required for the smallest float64 before a range check
+ // would occur during parsing
+ minFloat64Digits = 27
+)
+
+// ParsePoints returns a slice of Points from a text representation of a point
+// with each point separated by newlines. If any points fail to parse, a non-nil error
+// will be returned in addition to the points that parsed successfully.
+func ParsePoints(buf []byte) ([]Point, error) {
+ return ParsePointsWithPrecision(buf, time.Now().UTC(), "n")
+}
+
+// ParsePointsString is identical to ParsePoints but accepts a string.
+func ParsePointsString(buf string) ([]Point, error) {
+ return ParsePoints([]byte(buf))
+}
+
+// ParseKey returns the measurement name and tags from a point.
+//
+// NOTE: to minimize heap allocations, the returned Tags will refer to subslices of buf.
+// This can have the unintended effect of preventing buf from being garbage collected.
+func ParseKey(buf []byte) (string, Tags) {
+ name, tags := ParseKeyBytes(buf)
+ return string(name), tags
+}
+
+func ParseKeyBytes(buf []byte) ([]byte, Tags) {
+ return ParseKeyBytesWithTags(buf, nil)
+}
+
+func ParseKeyBytesWithTags(buf []byte, tags Tags) ([]byte, Tags) {
+ // Ignore the error because scanMeasurement returns "missing fields" which we ignore
+ // when just parsing a key
+ state, i, _ := scanMeasurement(buf, 0)
+
+ var name []byte
+ if state == tagKeyState {
+ tags = parseTags(buf, tags)
+ // scanMeasurement returns the location of the comma if there are tags, strip that off
+ name = buf[:i-1]
+ } else {
+ name = buf[:i]
+ }
+ return unescapeMeasurement(name), tags
+}
+
+func ParseTags(buf []byte) Tags {
+ return parseTags(buf, nil)
+}
+
+func ParseName(buf []byte) []byte {
+ // Ignore the error because scanMeasurement returns "missing fields" which we ignore
+ // when just parsing a key
+ state, i, _ := scanMeasurement(buf, 0)
+ var name []byte
+ if state == tagKeyState {
+ name = buf[:i-1]
+ } else {
+ name = buf[:i]
+ }
+
+ return unescapeMeasurement(name)
+}
+
+// ParsePointsWithPrecision is similar to ParsePoints, but allows the
+// caller to provide a precision for time.
+//
+// NOTE: to minimize heap allocations, the returned Points will refer to subslices of buf.
+// This can have the unintended effect of preventing buf from being garbage collected.
+func ParsePointsWithPrecision(buf []byte, defaultTime time.Time, precision string) ([]Point, error) {
+ points := make([]Point, 0, bytes.Count(buf, []byte{'\n'})+1)
+ var (
+ pos int
+ block []byte
+ failed []string
+ )
+ for pos < len(buf) {
+ pos, block = scanLine(buf, pos)
+ pos++
+
+ if len(block) == 0 {
+ continue
+ }
+
+ start := skipWhitespace(block, 0)
+
+ // If line is all whitespace, just skip it
+ if start >= len(block) {
+ continue
+ }
+
+ // lines which start with '#' are comments
+ if block[start] == '#' {
+ continue
+ }
+
+ // strip the newline if one is present
+ if block[len(block)-1] == '\n' {
+ block = block[:len(block)-1]
+ }
+
+ pt, err := parsePoint(block[start:], defaultTime, precision)
+ if err != nil {
+ failed = append(failed, fmt.Sprintf("unable to parse '%s': %v", string(block[start:]), err))
+ } else {
+ points = append(points, pt)
+ }
+
+ }
+ if len(failed) > 0 {
+ return points, fmt.Errorf("%s", strings.Join(failed, "\n"))
+ }
+ return points, nil
+
+}
+
+func parsePoint(buf []byte, defaultTime time.Time, precision string) (Point, error) {
+ // scan the first block which is measurement[,tag1=value1,tag2=value2...]
+ pos, key, err := scanKey(buf, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ // measurement name is required
+ if len(key) == 0 {
+ return nil, fmt.Errorf("missing measurement")
+ }
+
+ if len(key) > MaxKeyLength {
+ return nil, fmt.Errorf("max key length exceeded: %v > %v", len(key), MaxKeyLength)
+ }
+
+ // scan the second block is which is field1=value1[,field2=value2,...]
+ pos, fields, err := scanFields(buf, pos)
+ if err != nil {
+ return nil, err
+ }
+
+ // at least one field is required
+ if len(fields) == 0 {
+ return nil, fmt.Errorf("missing fields")
+ }
+
+ var maxKeyErr error
+ err = walkFields(fields, func(k, v []byte) bool {
+ if sz := seriesKeySize(key, k); sz > MaxKeyLength {
+ maxKeyErr = fmt.Errorf("max key length exceeded: %v > %v", sz, MaxKeyLength)
+ return false
+ }
+ return true
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ if maxKeyErr != nil {
+ return nil, maxKeyErr
+ }
+
+ // scan the last block which is an optional integer timestamp
+ pos, ts, err := scanTime(buf, pos)
+ if err != nil {
+ return nil, err
+ }
+
+ pt := &point{
+ key: key,
+ fields: fields,
+ ts: ts,
+ }
+
+ if len(ts) == 0 {
+ pt.time = defaultTime
+ pt.SetPrecision(precision)
+ } else {
+ ts, err := parseIntBytes(ts, 10, 64)
+ if err != nil {
+ return nil, err
+ }
+ pt.time, err = SafeCalcTime(ts, precision)
+ if err != nil {
+ return nil, err
+ }
+
+ // Determine if there are illegal non-whitespace characters after the
+ // timestamp block.
+ for pos < len(buf) {
+ if buf[pos] != ' ' {
+ return nil, ErrInvalidPoint
+ }
+ pos++
+ }
+ }
+ return pt, nil
+}
+
+// GetPrecisionMultiplier will return a multiplier for the precision specified.
+func GetPrecisionMultiplier(precision string) int64 {
+ d := time.Nanosecond
+ switch precision {
+ case "u":
+ d = time.Microsecond
+ case "ms":
+ d = time.Millisecond
+ case "s":
+ d = time.Second
+ case "m":
+ d = time.Minute
+ case "h":
+ d = time.Hour
+ }
+ return int64(d)
+}
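+
+// For example (illustrative; p is any Point), converting a point's
+// nanosecond timestamp to the requested precision, as PrecisionString
+// does below:
+//
+//	ms := p.UnixNano() / GetPrecisionMultiplier("ms")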
+
+// scanKey scans buf starting at i for the measurement and tag portion of the point.
+// It returns the ending position and the byte slice of key within buf. If there
+// are tags, they will be sorted if they are not already.
+func scanKey(buf []byte, i int) (int, []byte, error) {
+ start := skipWhitespace(buf, i)
+
+ i = start
+
+ // Tracks whether the tags are sorted; assume they are until shown otherwise
+ sorted := true
+
+ // indices holds the indexes within buf of the start of each tag. For example,
+ // a buf of 'cpu,host=a,region=b,zone=c' would have an indices slice of [4,11,20],
+ // which indicates that the first tag starts at buf[4], the second at buf[11], and
+ // the last at buf[20]
+ indices := make([]int, 100)
+
+ // commas tracks how many commas we've seen, i.e. how many entries in the
+ // (arbitrarily large) indices slice are actually in use.
+ commas := 0
+
+ // First scan the Point's measurement.
+ state, i, err := scanMeasurement(buf, i)
+ if err != nil {
+ return i, buf[start:i], err
+ }
+
+ // Optionally scan tags if needed.
+ if state == tagKeyState {
+ i, commas, indices, err = scanTags(buf, i, indices)
+ if err != nil {
+ return i, buf[start:i], err
+ }
+ }
+
+ // Now we know where the key region is within buf, and the location of tags, we
+ // need to determine if duplicate tags exist and if the tags are sorted. This iterates
+ // over the list comparing each tag in the sequence with each other.
+ for j := 0; j < commas-1; j++ {
+ // get the left and right tags
+ _, left := scanTo(buf[indices[j]:indices[j+1]-1], 0, '=')
+ _, right := scanTo(buf[indices[j+1]:indices[j+2]-1], 0, '=')
+
+ // If left is greater than right, the tags are not sorted. We do not have to
+ // continue because the fast path no longer applies.
+ // If the tags are equal, then there are duplicate tags, and we should abort.
+ // If the tags are not sorted, this pass may not find duplicate tags and we
+ // need to do a more exhaustive search later.
+ if cmp := bytes.Compare(left, right); cmp > 0 {
+ sorted = false
+ break
+ } else if cmp == 0 {
+ return i, buf[start:i], fmt.Errorf("duplicate tags")
+ }
+ }
+
+ // If the tags are not sorted, then sort them. This sort is inline and
+ // uses the tag indices we created earlier. The actual buffer is not sorted, the
+ // indices are using the buffer for value comparison. After the indices are sorted,
+ // the buffer is reconstructed from the sorted indices.
+ if !sorted && commas > 0 {
+ // Get the measurement name for later
+ measurement := buf[start : indices[0]-1]
+
+ // Sort the indices
+ indices := indices[:commas]
+ insertionSort(0, commas, buf, indices)
+
+ // Create a new key using the measurement and sorted indices
+ b := make([]byte, len(buf[start:i]))
+ pos := copy(b, measurement)
+ for _, i := range indices {
+ b[pos] = ','
+ pos++
+ _, v := scanToSpaceOr(buf, i, ',')
+ pos += copy(b[pos:], v)
+ }
+
+ // Check again for duplicate tags now that the tags are sorted.
+ for j := 0; j < commas-1; j++ {
+ // get the left and right tags
+ _, left := scanTo(buf[indices[j]:], 0, '=')
+ _, right := scanTo(buf[indices[j+1]:], 0, '=')
+
+ // If the tags are equal, then there are duplicate tags, and we should abort.
+ // If the tags are not sorted, this pass may not find duplicate tags and we
+ // need to do a more exhaustive search later.
+ if bytes.Equal(left, right) {
+ return i, b, fmt.Errorf("duplicate tags")
+ }
+ }
+
+ return i, b, nil
+ }
+
+ return i, buf[start:i], nil
+}
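+
+// For illustration, given the unsorted input
+//
+//	cpu,region=us-west,host=serverA value=1
+//
+// scanKey returns the key `cpu,host=serverA,region=us-west`, with the tags
+// reordered lexicographically by tag key.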
+
+// The following constants allow us to specify which state to move to
+// next when scanning sections of a Point.
+const (
+ tagKeyState = iota
+ tagValueState
+ fieldsState
+)
+
+// scanMeasurement examines the measurement part of a Point, returning
+// the next state to move to, and the current location in the buffer.
+func scanMeasurement(buf []byte, i int) (int, int, error) {
+ // Check first byte of measurement, anything except a comma is fine.
+ // It can't be a space, since whitespace is stripped prior to this
+ // function call.
+ if i >= len(buf) || buf[i] == ',' {
+ return -1, i, fmt.Errorf("missing measurement")
+ }
+
+ for {
+ i++
+ if i >= len(buf) {
+ // cpu
+ return -1, i, fmt.Errorf("missing fields")
+ }
+
+ if buf[i-1] == '\\' {
+ // Skip character (it's escaped).
+ continue
+ }
+
+ // Unescaped comma; move onto scanning the tags.
+ if buf[i] == ',' {
+ return tagKeyState, i + 1, nil
+ }
+
+ // Unescaped space; move onto scanning the fields.
+ if buf[i] == ' ' {
+ // cpu value=1.0
+ return fieldsState, i, nil
+ }
+ }
+}
+
+// scanTags examines all the tags in a Point, keeping track of and
+// returning the updated indices slice, number of commas and location
+// in buf where to start examining the Point fields.
+func scanTags(buf []byte, i int, indices []int) (int, int, []int, error) {
+ var (
+ err error
+ commas int
+ state = tagKeyState
+ )
+
+ for {
+ switch state {
+ case tagKeyState:
+ // Grow our indices slice if we have too many tags.
+ if commas >= len(indices) {
+ newIndices := make([]int, cap(indices)*2)
+ copy(newIndices, indices)
+ indices = newIndices
+ }
+ indices[commas] = i
+ commas++
+
+ i, err = scanTagsKey(buf, i)
+ state = tagValueState // tag value always follows a tag key
+ case tagValueState:
+ state, i, err = scanTagsValue(buf, i)
+ case fieldsState:
+ indices[commas] = i + 1
+ return i, commas, indices, nil
+ }
+
+ if err != nil {
+ return i, commas, indices, err
+ }
+ }
+}
+
+// scanTagsKey scans each character in a tag key.
+func scanTagsKey(buf []byte, i int) (int, error) {
+ // First character of the key.
+ if i >= len(buf) || buf[i] == ' ' || buf[i] == ',' || buf[i] == '=' {
+ // cpu,{'', ' ', ',', '='}
+ return i, fmt.Errorf("missing tag key")
+ }
+
+ // Examine each character in the tag key until we hit an unescaped
+ // equals (the tag value), or we hit an error (i.e., unescaped
+ // space or comma).
+ for {
+ i++
+
+ // Either we reached the end of the buffer or we hit an
+ // unescaped comma or space.
+ if i >= len(buf) ||
+ ((buf[i] == ' ' || buf[i] == ',') && buf[i-1] != '\\') {
+ // cpu,tag{'', ' ', ','}
+ return i, fmt.Errorf("missing tag value")
+ }
+
+ if buf[i] == '=' && buf[i-1] != '\\' {
+ // cpu,tag=
+ return i + 1, nil
+ }
+ }
+}
+
+// scanTagsValue scans each character in a tag value.
+func scanTagsValue(buf []byte, i int) (int, int, error) {
+ // Tag value cannot be empty.
+ if i >= len(buf) || buf[i] == ',' || buf[i] == ' ' {
+ // cpu,tag={',', ' '}
+ return -1, i, fmt.Errorf("missing tag value")
+ }
+
+ // Examine each character in the tag value until we hit an unescaped
+ // comma (move onto next tag key), an unescaped space (move onto
+ // fields), or we error out.
+ for {
+ i++
+ if i >= len(buf) {
+ // cpu,tag=value
+ return -1, i, fmt.Errorf("missing fields")
+ }
+
+ // An unescaped equals sign is an invalid tag value.
+ if buf[i] == '=' && buf[i-1] != '\\' {
+ // cpu,tag={'=', 'fo=o'}
+ return -1, i, fmt.Errorf("invalid tag format")
+ }
+
+ if buf[i] == ',' && buf[i-1] != '\\' {
+ // cpu,tag=foo,
+ return tagKeyState, i + 1, nil
+ }
+
+ // cpu,tag=foo value=1.0
+ // cpu,tag=foo\= value=1.0
+ if buf[i] == ' ' && buf[i-1] != '\\' {
+ return fieldsState, i, nil
+ }
+ }
+}
+
+func insertionSort(l, r int, buf []byte, indices []int) {
+ for i := l + 1; i < r; i++ {
+ for j := i; j > l && less(buf, indices, j, j-1); j-- {
+ indices[j], indices[j-1] = indices[j-1], indices[j]
+ }
+ }
+}
+
+func less(buf []byte, indices []int, i, j int) bool {
+ // This grabs the tag names for i & j, it ignores the values
+ _, a := scanTo(buf, indices[i], '=')
+ _, b := scanTo(buf, indices[j], '=')
+ return bytes.Compare(a, b) < 0
+}
+
+// scanFields scans buf, starting at i for the fields section of a point. It returns
+// the ending position and the byte slice of the fields within buf.
+func scanFields(buf []byte, i int) (int, []byte, error) {
+ start := skipWhitespace(buf, i)
+ i = start
+ quoted := false
+
+ // tracks how many '=' we've seen
+ equals := 0
+
+ // tracks how many commas we've seen
+ commas := 0
+
+ for {
+ // reached the end of buf?
+ if i >= len(buf) {
+ break
+ }
+
+ // escaped characters?
+ if buf[i] == '\\' && i+1 < len(buf) {
+ i += 2
+ continue
+ }
+
+ // If the value is quoted, scan until we get to the end quote.
+ // Quotes are only significant in the field value, not the field key,
+ // so only toggle quoting there.
+ if buf[i] == '"' && equals > commas {
+ quoted = !quoted
+ i++
+ continue
+ }
+
+ // If we see an =, ensure that there is at least one char before and after it
+ if buf[i] == '=' && !quoted {
+ equals++
+
+ // check for "... =123" but allow "a\ =123"
+ if buf[i-1] == ' ' && buf[i-2] != '\\' {
+ return i, buf[start:i], fmt.Errorf("missing field key")
+ }
+
+ // check for "...a=123,=456" but allow "a=123,a\,=456"
+ if buf[i-1] == ',' && buf[i-2] != '\\' {
+ return i, buf[start:i], fmt.Errorf("missing field key")
+ }
+
+ // check for "... value="
+ if i+1 >= len(buf) {
+ return i, buf[start:i], fmt.Errorf("missing field value")
+ }
+
+ // check for "... value=,value2=..."
+ if buf[i+1] == ',' || buf[i+1] == ' ' {
+ return i, buf[start:i], fmt.Errorf("missing field value")
+ }
+
+ if isNumeric(buf[i+1]) || buf[i+1] == '-' || buf[i+1] == 'N' || buf[i+1] == 'n' {
+ var err error
+ i, err = scanNumber(buf, i+1)
+ if err != nil {
+ return i, buf[start:i], err
+ }
+ continue
+ }
+ // If next byte is not a double-quote, the value must be a boolean
+ if buf[i+1] != '"' {
+ var err error
+ i, _, err = scanBoolean(buf, i+1)
+ if err != nil {
+ return i, buf[start:i], err
+ }
+ continue
+ }
+ }
+
+ if buf[i] == ',' && !quoted {
+ commas++
+ }
+
+ // reached end of block?
+ if buf[i] == ' ' && !quoted {
+ break
+ }
+ i++
+ }
+
+ if quoted {
+ return i, buf[start:i], fmt.Errorf("unbalanced quotes")
+ }
+
+ // check that all field sections had keys and values (e.g. prevent "a=1,b")
+ if equals == 0 || commas != equals-1 {
+ return i, buf[start:i], fmt.Errorf("invalid field format")
+ }
+
+ return i, buf[start:i], nil
+}
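+
+// For illustration (hypothetical input), when called just past the
+// measurement of
+//
+//	cpu value=0.64,count=2i,msg="hi" 1556813561
+//
+// scanFields returns the byte slice `value=0.64,count=2i,msg="hi"` and the
+// position just past the fields, where the optional timestamp follows.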
+
+// scanTime scans buf, starting at i for the time section of a point. It
+// returns the ending position and the byte slice of the timestamp within buf
+// and an error if the timestamp is not in the correct numeric format.
+func scanTime(buf []byte, i int) (int, []byte, error) {
+ start := skipWhitespace(buf, i)
+ i = start
+
+ for {
+ // reached the end of buf?
+ if i >= len(buf) {
+ break
+ }
+
+ // Reached end of block or trailing whitespace?
+ if buf[i] == '\n' || buf[i] == ' ' {
+ break
+ }
+
+ // Handle negative timestamps
+ if i == start && buf[i] == '-' {
+ i++
+ continue
+ }
+
+ // Timestamps should be integers, make sure they are so we don't need
+ // to actually parse the timestamp until needed.
+ if buf[i] < '0' || buf[i] > '9' {
+ return i, buf[start:i], fmt.Errorf("bad timestamp")
+ }
+ i++
+ }
+ return i, buf[start:i], nil
+}
+
+func isNumeric(b byte) bool {
+ return (b >= '0' && b <= '9') || b == '.'
+}
+
+// scanNumber returns the end position within buf, starting at i, after
+// scanning over buf for an integer or float. It returns an
+// error if an invalid number is scanned.
+func scanNumber(buf []byte, i int) (int, error) {
+ start := i
+ var isInt, isUnsigned bool
+
+ // Is negative number?
+ if i < len(buf) && buf[i] == '-' {
+ i++
+ // There must be more characters now, as just '-' is illegal.
+ if i == len(buf) {
+ return i, ErrInvalidNumber
+ }
+ }
+
+ // whether we've seen a decimal point
+ decimal := false
+
+ // indicates whether the number is a float in scientific notation
+ scientific := false
+
+ for {
+ if i >= len(buf) {
+ break
+ }
+
+ if buf[i] == ',' || buf[i] == ' ' {
+ break
+ }
+
+ if buf[i] == 'i' && i > start && !(isInt || isUnsigned) {
+ isInt = true
+ i++
+ continue
+ } else if buf[i] == 'u' && i > start && !(isInt || isUnsigned) {
+ isUnsigned = true
+ i++
+ continue
+ }
+
+ if buf[i] == '.' {
+ // Can't have more than 1 decimal (e.g. 1.1.1 should fail)
+ if decimal {
+ return i, ErrInvalidNumber
+ }
+ decimal = true
+ }
+
+ // `e` is valid for floats but not as the first char
+ if i > start && (buf[i] == 'e' || buf[i] == 'E') {
+ scientific = true
+ i++
+ continue
+ }
+
+ // + and - are only valid at this point if they follow an e (scientific notation)
+ if (buf[i] == '+' || buf[i] == '-') && (buf[i-1] == 'e' || buf[i-1] == 'E') {
+ i++
+ continue
+ }
+
+ // NaN is an unsupported value
+ if i+2 < len(buf) && (buf[i] == 'N' || buf[i] == 'n') {
+ return i, ErrInvalidNumber
+ }
+
+ if !isNumeric(buf[i]) {
+ return i, ErrInvalidNumber
+ }
+ i++
+ }
+
+ if (isInt || isUnsigned) && (decimal || scientific) {
+ return i, ErrInvalidNumber
+ }
+
+ numericDigits := i - start
+ if isInt {
+ numericDigits--
+ }
+ if decimal {
+ numericDigits--
+ }
+ if buf[start] == '-' {
+ numericDigits--
+ }
+
+ if numericDigits == 0 {
+ return i, ErrInvalidNumber
+ }
+
+ // It's more common that numbers will be within the min/max range for their type, but we need to prevent
+ // out-of-range numbers from being parsed successfully. This uses some simple heuristics to decide
+ // if we should parse the number to the actual type. It does not do it all the time because it incurs
+ // extra allocations, and we end up converting the type again when writing points to disk.
+ if isInt {
+ // Make sure the last char is an 'i' for integers (e.g. 9i10 is not valid)
+ if buf[i-1] != 'i' {
+ return i, ErrInvalidNumber
+ }
+ // Parse the int to check bounds the number of digits could be larger than the max range
+ // We subtract 1 from the index to remove the `i` from our tests
+ if len(buf[start:i-1]) >= maxInt64Digits || len(buf[start:i-1]) >= minInt64Digits {
+ if _, err := parseIntBytes(buf[start:i-1], 10, 64); err != nil {
+ return i, fmt.Errorf("unable to parse integer %s: %s", buf[start:i-1], err)
+ }
+ }
+ } else if isUnsigned {
+ // Return an error if uint64 support has not been enabled.
+ if !enableUint64Support {
+ return i, ErrInvalidNumber
+ }
+ // Make sure the last char is a 'u' for unsigned
+ if buf[i-1] != 'u' {
+ return i, ErrInvalidNumber
+ }
+ // Make sure the first char is not a '-' for unsigned
+ if buf[start] == '-' {
+ return i, ErrInvalidNumber
+ }
+ // Parse the uint to check bounds the number of digits could be larger than the max range
+ // We subtract 1 from the index to remove the `u` from our tests
+ if len(buf[start:i-1]) >= maxUint64Digits {
+ if _, err := parseUintBytes(buf[start:i-1], 10, 64); err != nil {
+ return i, fmt.Errorf("unable to parse unsigned %s: %s", buf[start:i-1], err)
+ }
+ }
+ } else {
+ // Parse the float to check bounds if it's scientific or the number of digits could be larger than the max range
+ if scientific || len(buf[start:i]) >= maxFloat64Digits || len(buf[start:i]) >= minFloat64Digits {
+ if _, err := parseFloatBytes(buf[start:i], 10); err != nil {
+ return i, fmt.Errorf("invalid float")
+ }
+ }
+ }
+
+ return i, nil
+}
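+
+// For illustration, scanNumber accepts forms such as 1, -1.5, 1e10, 42i
+// (integer) and 42u (unsigned, when enabled), and rejects malformed input
+// such as 1.1.1, 9i10, NaN, or a bare '-'.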
+
+// scanBoolean returns the end position within buf, starting at i, after
+// scanning over buf for a boolean. Valid values for a boolean are
+// t, T, true, TRUE, f, F, false, FALSE. It returns an error if an invalid
+// boolean is scanned.
+func scanBoolean(buf []byte, i int) (int, []byte, error) {
+ start := i
+
+ if i < len(buf) && (buf[i] != 't' && buf[i] != 'f' && buf[i] != 'T' && buf[i] != 'F') {
+ return i, buf[start:i], fmt.Errorf("invalid boolean")
+ }
+
+ i++
+ for {
+ if i >= len(buf) {
+ break
+ }
+
+ if buf[i] == ',' || buf[i] == ' ' {
+ break
+ }
+ i++
+ }
+
+ // Single char bool (t, T, f, F) is ok
+ if i-start == 1 {
+ return i, buf[start:i], nil
+ }
+
+ // length must be 4 for true or TRUE
+ if (buf[start] == 't' || buf[start] == 'T') && i-start != 4 {
+ return i, buf[start:i], fmt.Errorf("invalid boolean")
+ }
+
+ // length must be 5 for false or FALSE
+ if (buf[start] == 'f' || buf[start] == 'F') && i-start != 5 {
+ return i, buf[start:i], fmt.Errorf("invalid boolean")
+ }
+
+ // Otherwise
+ valid := false
+ switch buf[start] {
+ case 't':
+ valid = bytes.Equal(buf[start:i], []byte("true"))
+ case 'f':
+ valid = bytes.Equal(buf[start:i], []byte("false"))
+ case 'T':
+ valid = bytes.Equal(buf[start:i], []byte("TRUE")) || bytes.Equal(buf[start:i], []byte("True"))
+ case 'F':
+ valid = bytes.Equal(buf[start:i], []byte("FALSE")) || bytes.Equal(buf[start:i], []byte("False"))
+ }
+
+ if !valid {
+ return i, buf[start:i], fmt.Errorf("invalid boolean")
+ }
+
+ return i, buf[start:i], nil
+
+}
+
+// skipWhitespace returns the end position within buf, starting at i after
+// scanning over spaces in tags.
+func skipWhitespace(buf []byte, i int) int {
+ for i < len(buf) {
+ if buf[i] != ' ' && buf[i] != '\t' && buf[i] != 0 {
+ break
+ }
+ i++
+ }
+ return i
+}
+
+// scanLine returns the end position in buf and the next line found within
+// buf.
+func scanLine(buf []byte, i int) (int, []byte) {
+ start := i
+ quoted := false
+ fields := false
+
+ // tracks how many '=' and commas we've seen
+ // this duplicates some of the functionality in scanFields
+ equals := 0
+ commas := 0
+ for {
+ // reached the end of buf?
+ if i >= len(buf) {
+ break
+ }
+
+ // skip past escaped characters
+ if buf[i] == '\\' && i+2 < len(buf) {
+ i += 2
+ continue
+ }
+
+ if buf[i] == ' ' {
+ fields = true
+ }
+
+ // Once we're in the fields block, track unescaped '=', ',' and '"'
+ // so that newlines inside quoted string values don't end the line.
+ if fields {
+ if !quoted && buf[i] == '=' {
+ i++
+ equals++
+ continue
+ } else if !quoted && buf[i] == ',' {
+ i++
+ commas++
+ continue
+ } else if buf[i] == '"' && equals > commas {
+ i++
+ quoted = !quoted
+ continue
+ }
+ }
+
+ if buf[i] == '\n' && !quoted {
+ break
+ }
+
+ i++
+ }
+
+ return i, buf[start:i]
+}
+
+// scanTo returns the end position in buf and the next consecutive block
+// of bytes, starting from i and ending with stop byte, where stop byte
+// has not been escaped.
+func scanTo(buf []byte, i int, stop byte) (int, []byte) {
+ start := i
+ for {
+ // reached the end of buf?
+ if i >= len(buf) {
+ break
+ }
+
+ // Reached unescaped stop value?
+ if buf[i] == stop && (i == 0 || buf[i-1] != '\\') {
+ break
+ }
+ i++
+ }
+
+ return i, buf[start:i]
+}
+
+// scanToSpaceOr returns the end position in buf and the next consecutive block
+// of bytes, starting from i and ending with an unescaped stop byte or a space.
+func scanToSpaceOr(buf []byte, i int, stop byte) (int, []byte) {
+ start := i
+ if buf[i] == stop || buf[i] == ' ' {
+ return i, buf[start:i]
+ }
+
+ for {
+ i++
+ if buf[i-1] == '\\' {
+ continue
+ }
+
+ // reached the end of buf?
+ if i >= len(buf) {
+ return i, buf[start:i]
+ }
+
+ // reached end of block?
+ if buf[i] == stop || buf[i] == ' ' {
+ return i, buf[start:i]
+ }
+ }
+}
+
+func scanTagValue(buf []byte, i int) (int, []byte) {
+ start := i
+ for {
+ if i >= len(buf) {
+ break
+ }
+
+ if buf[i] == ',' && buf[i-1] != '\\' {
+ break
+ }
+ i++
+ }
+ if i > len(buf) {
+ return i, nil
+ }
+ return i, buf[start:i]
+}
+
+func scanFieldValue(buf []byte, i int) (int, []byte) {
+ start := i
+ quoted := false
+ for i < len(buf) {
+ // The only escape chars for a field value are a double-quote and backslash
+ if buf[i] == '\\' && i+1 < len(buf) && (buf[i+1] == '"' || buf[i+1] == '\\') {
+ i += 2
+ continue
+ }
+
+ // Quoted value? (e.g. string)
+ if buf[i] == '"' {
+ i++
+ quoted = !quoted
+ continue
+ }
+
+ if buf[i] == ',' && !quoted {
+ break
+ }
+ i++
+ }
+ return i, buf[start:i]
+}
+
+func EscapeMeasurement(in []byte) []byte {
+ for _, c := range measurementEscapeCodes {
+ if bytes.IndexByte(in, c.k[0]) != -1 {
+ in = bytes.Replace(in, c.k[:], c.esc[:], -1)
+ }
+ }
+ return in
+}
+
+func unescapeMeasurement(in []byte) []byte {
+ if bytes.IndexByte(in, '\\') == -1 {
+ return in
+ }
+
+ for i := range measurementEscapeCodes {
+ c := &measurementEscapeCodes[i]
+ if bytes.IndexByte(in, c.k[0]) != -1 {
+ in = bytes.Replace(in, c.esc[:], c.k[:], -1)
+ }
+ }
+ return in
+}
+
+func escapeTag(in []byte) []byte {
+ for i := range tagEscapeCodes {
+ c := &tagEscapeCodes[i]
+ if bytes.IndexByte(in, c.k[0]) != -1 {
+ in = bytes.Replace(in, c.k[:], c.esc[:], -1)
+ }
+ }
+ return in
+}
+
+func unescapeTag(in []byte) []byte {
+ if bytes.IndexByte(in, '\\') == -1 {
+ return in
+ }
+
+ for i := range tagEscapeCodes {
+ c := &tagEscapeCodes[i]
+ if bytes.IndexByte(in, c.k[0]) != -1 {
+ in = bytes.Replace(in, c.esc[:], c.k[:], -1)
+ }
+ }
+ return in
+}
+
+// escapeStringFieldReplacer replaces double quotes and backslashes
+// with the same character preceded by a backslash.
+// As of Go 1.7 this benchmarked better in allocations and CPU time
+// compared to iterating through a string byte-by-byte and appending to a new byte slice,
+// calling strings.Replace twice, and better than (*regexp.Regexp).ReplaceAllString.
+var escapeStringFieldReplacer = strings.NewReplacer(`"`, `\"`, `\`, `\\`)
+
+// EscapeStringField returns a copy of in with any double quotes or
+// backslashes with escaped values.
+func EscapeStringField(in string) string {
+ return escapeStringFieldReplacer.Replace(in)
+}
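+
+// For example (illustrative):
+//
+//	EscapeStringField(`say "hi"`) // returns `say \"hi\"`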
+
+// unescapeStringField returns a copy of in with any escaped double-quotes
+// or backslashes unescaped.
+func unescapeStringField(in string) string {
+ if strings.IndexByte(in, '\\') == -1 {
+ return in
+ }
+
+ var out []byte
+ i := 0
+ for {
+ if i >= len(in) {
+ break
+ }
+ // unescape backslashes
+ if in[i] == '\\' && i+1 < len(in) && in[i+1] == '\\' {
+ out = append(out, '\\')
+ i += 2
+ continue
+ }
+ // unescape double-quotes
+ if in[i] == '\\' && i+1 < len(in) && in[i+1] == '"' {
+ out = append(out, '"')
+ i += 2
+ continue
+ }
+ out = append(out, in[i])
+ i++
+
+ }
+ return string(out)
+}
+
+// NewPoint returns a new point with the given measurement name, tags, fields and timestamp. If
+// an unsupported field value (NaN, or +/-Inf) or out of range time is passed, this function
+// returns an error.
+func NewPoint(name string, tags Tags, fields Fields, t time.Time) (Point, error) {
+ key, err := pointKey(name, tags, fields, t)
+ if err != nil {
+ return nil, err
+ }
+
+ return &point{
+ key: key,
+ time: t,
+ fields: fields.MarshalBinary(),
+ }, nil
+}
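+
+// A minimal construction sketch (illustrative; the measurement, tag, and
+// field names are hypothetical):
+//
+//	pt, err := NewPoint(
+//	    "cpu",
+//	    NewTags(map[string]string{"host": "serverA"}),
+//	    Fields{"value": 0.64},
+//	    time.Unix(1556813561, 0),
+//	)
+//	if err != nil {
+//	    // NaN/Inf field values or an out-of-range time land here
+//	}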
+
+// pointKey checks some basic requirements for valid points, and returns the
+// key, along with a possible error.
+func pointKey(measurement string, tags Tags, fields Fields, t time.Time) ([]byte, error) {
+ if len(fields) == 0 {
+ return nil, ErrPointMustHaveAField
+ }
+
+ if !t.IsZero() {
+ if err := CheckTime(t); err != nil {
+ return nil, err
+ }
+ }
+
+ for key, value := range fields {
+ switch value := value.(type) {
+ case float64:
+ // Reject unsupported float64 values (NaN and +/-Inf)
+ if math.IsInf(value, 0) {
+ return nil, fmt.Errorf("+/-Inf is an unsupported value for field %s", key)
+ }
+ if math.IsNaN(value) {
+ return nil, fmt.Errorf("NaN is an unsupported value for field %s", key)
+ }
+ case float32:
+ // Reject unsupported float32 values (NaN and +/-Inf)
+ if math.IsInf(float64(value), 0) {
+ return nil, fmt.Errorf("+/-Inf is an unsupported value for field %s", key)
+ }
+ if math.IsNaN(float64(value)) {
+ return nil, fmt.Errorf("NaN is an unsupported value for field %s", key)
+ }
+ }
+ if len(key) == 0 {
+ return nil, fmt.Errorf("all fields must have non-empty names")
+ }
+ }
+
+ key := MakeKey([]byte(measurement), tags)
+ for field := range fields {
+ sz := seriesKeySize(key, []byte(field))
+ if sz > MaxKeyLength {
+ return nil, fmt.Errorf("max key length exceeded: %v > %v", sz, MaxKeyLength)
+ }
+ }
+
+ return key, nil
+}
+
+func seriesKeySize(key, field []byte) int {
+ // 4 is the length of the tsm1.fieldKeySeparator constant. It's inlined here to avoid a circular
+ // dependency.
+ return len(key) + 4 + len(field)
+}
+
+// NewPointFromBytes returns a new Point from a marshalled Point.
+func NewPointFromBytes(b []byte) (Point, error) {
+ p := &point{}
+ if err := p.UnmarshalBinary(b); err != nil {
+ return nil, err
+ }
+
+ // This does some basic validation to ensure there are fields and they
+ // can be unmarshalled as well.
+ iter := p.FieldIterator()
+ var hasField bool
+ for iter.Next() {
+ if len(iter.FieldKey()) == 0 {
+ continue
+ }
+ hasField = true
+ switch iter.Type() {
+ case Float:
+ _, err := iter.FloatValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ case Integer:
+ _, err := iter.IntegerValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ case Unsigned:
+ _, err := iter.UnsignedValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ case String:
+ // Skip since this won't return an error
+ case Boolean:
+ _, err := iter.BooleanValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ }
+ }
+
+ if !hasField {
+ return nil, ErrPointMustHaveAField
+ }
+
+ return p, nil
+}
+
+// MustNewPoint returns a new point with the given measurement name, tags, fields and timestamp. If
+// an unsupported field value (NaN or +/-Inf) or an out of range time is passed, this function panics.
+func MustNewPoint(name string, tags Tags, fields Fields, time time.Time) Point {
+ pt, err := NewPoint(name, tags, fields, time)
+ if err != nil {
+ panic(err.Error())
+ }
+ return pt
+}
+
+// Key returns the key (measurement joined with tags) of the point.
+func (p *point) Key() []byte {
+ return p.key
+}
+
+func (p *point) name() []byte {
+ _, name := scanTo(p.key, 0, ',')
+ return name
+}
+
+func (p *point) Name() []byte {
+ return escape.Unescape(p.name())
+}
+
+// SetName updates the measurement name for the point.
+func (p *point) SetName(name string) {
+ p.cachedName = ""
+ p.key = MakeKey([]byte(name), p.Tags())
+}
+
+// Time return the timestamp for the point.
+func (p *point) Time() time.Time {
+ return p.time
+}
+
+// SetTime updates the timestamp for the point.
+func (p *point) SetTime(t time.Time) {
+ p.time = t
+}
+
+// Round will round the timestamp of the point to the given duration.
+func (p *point) Round(d time.Duration) {
+ p.time = p.time.Round(d)
+}
+
+// Tags returns the tag set for the point.
+func (p *point) Tags() Tags {
+ if p.cachedTags != nil {
+ return p.cachedTags
+ }
+ p.cachedTags = parseTags(p.key, nil)
+ return p.cachedTags
+}
+
+func (p *point) ForEachTag(fn func(k, v []byte) bool) {
+ walkTags(p.key, fn)
+}
+
+func (p *point) HasTag(tag []byte) bool {
+ if len(p.key) == 0 {
+ return false
+ }
+
+ var exists bool
+ walkTags(p.key, func(key, value []byte) bool {
+ if bytes.Equal(tag, key) {
+ exists = true
+ return false
+ }
+ return true
+ })
+
+ return exists
+}
+
+func walkTags(buf []byte, fn func(key, value []byte) bool) {
+ if len(buf) == 0 {
+ return
+ }
+
+ pos, name := scanTo(buf, 0, ',')
+
+ // it's an empty key, so there are no tags
+ if len(name) == 0 {
+ return
+ }
+
+ hasEscape := bytes.IndexByte(buf, '\\') != -1
+ i := pos + 1
+ var key, value []byte
+ for {
+ if i >= len(buf) {
+ break
+ }
+ i, key = scanTo(buf, i, '=')
+ i, value = scanTagValue(buf, i+1)
+
+ if len(value) == 0 {
+ continue
+ }
+
+ if hasEscape {
+ if !fn(unescapeTag(key), unescapeTag(value)) {
+ return
+ }
+ } else {
+ if !fn(key, value) {
+ return
+ }
+ }
+
+ i++
+ }
+}
+
+// walkFields walks each field key and value via fn. If fn returns false, the iteration
+// is stopped. The values are the raw byte slices and not the converted types.
+func walkFields(buf []byte, fn func(key, value []byte) bool) error {
+ var i int
+ var key, val []byte
+ for len(buf) > 0 {
+ i, key = scanTo(buf, 0, '=')
+ if i > len(buf)-2 {
+ return fmt.Errorf("invalid value: field-key=%s", key)
+ }
+ buf = buf[i+1:]
+ i, val = scanFieldValue(buf, 0)
+ buf = buf[i:]
+ if !fn(key, val) {
+ break
+ }
+
+ // slice off comma
+ if len(buf) > 0 {
+ buf = buf[1:]
+ }
+ }
+ return nil
+}
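+
+// Illustrative callback usage (fieldsBuf is a hypothetical raw field block
+// such as `a=1i,b=2i`): collecting the raw field keys without decoding values.
+//
+//	var keys [][]byte
+//	_ = walkFields(fieldsBuf, func(k, v []byte) bool {
+//	    keys = append(keys, k)
+//	    return true
+//	})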
+
+// parseTags parses buf into the provided destination tags, returning destination
+// Tags, which may have a different length and capacity.
+func parseTags(buf []byte, dst Tags) Tags {
+ if len(buf) == 0 {
+ return nil
+ }
+
+ n := bytes.Count(buf, []byte(","))
+ if cap(dst) < n {
+ dst = make(Tags, n)
+ } else {
+ dst = dst[:n]
+ }
+
+ // Ensure existing behaviour when point has no tags and nil slice passed in.
+ if dst == nil {
+ dst = Tags{}
+ }
+
+ // Series keys can contain escaped commas, therefore the number of commas
+ // in a series key only gives an estimation of the upper bound on the number
+ // of tags.
+ var i int
+ walkTags(buf, func(key, value []byte) bool {
+ dst[i].Key, dst[i].Value = key, value
+ i++
+ return true
+ })
+ return dst[:i]
+}
+
+// MakeKey creates a key for a set of tags.
+func MakeKey(name []byte, tags Tags) []byte {
+ return AppendMakeKey(nil, name, tags)
+}
+
+// AppendMakeKey appends the key derived from name and tags to dst and returns the extended buffer.
+func AppendMakeKey(dst []byte, name []byte, tags Tags) []byte {
+ // unescape the name and then re-escape it to avoid double escaping.
+ // The key should always be stored in escaped form.
+ dst = append(dst, EscapeMeasurement(unescapeMeasurement(name))...)
+ dst = tags.AppendHashKey(dst)
+ return dst
+}
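+
+// For example (illustrative), MakeKey([]byte("cpu"),
+// Tags{{Key: []byte("host"), Value: []byte("serverA")}}) yields the escaped
+// series key `cpu,host=serverA`.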
+
+// SetTags replaces the tags for the point.
+func (p *point) SetTags(tags Tags) {
+ p.key = MakeKey(p.Name(), tags)
+ p.cachedTags = tags
+}
+
+// AddTag adds or replaces a tag value for a point.
+func (p *point) AddTag(key, value string) {
+ tags := p.Tags()
+ tags = append(tags, Tag{Key: []byte(key), Value: []byte(value)})
+ sort.Sort(tags)
+ p.cachedTags = tags
+ p.key = MakeKey(p.Name(), tags)
+}
+
+// Fields returns the fields for the point.
+func (p *point) Fields() (Fields, error) {
+ if p.cachedFields != nil {
+ return p.cachedFields, nil
+ }
+ cf, err := p.unmarshalBinary()
+ if err != nil {
+ return nil, err
+ }
+ p.cachedFields = cf
+ return p.cachedFields, nil
+}
+
+// SetPrecision truncates the point's timestamp to the specified precision.
+func (p *point) SetPrecision(precision string) {
+ switch precision {
+ case "n":
+ case "u":
+ p.SetTime(p.Time().Truncate(time.Microsecond))
+ case "ms":
+ p.SetTime(p.Time().Truncate(time.Millisecond))
+ case "s":
+ p.SetTime(p.Time().Truncate(time.Second))
+ case "m":
+ p.SetTime(p.Time().Truncate(time.Minute))
+ case "h":
+ p.SetTime(p.Time().Truncate(time.Hour))
+ }
+}
+
+// String returns the string representation of the point.
+func (p *point) String() string {
+ if p.Time().IsZero() {
+ return string(p.Key()) + " " + string(p.fields)
+ }
+ return string(p.Key()) + " " + string(p.fields) + " " + strconv.FormatInt(p.UnixNano(), 10)
+}
+
+// AppendString appends the string representation of the point to buf.
+func (p *point) AppendString(buf []byte) []byte {
+ buf = append(buf, p.key...)
+ buf = append(buf, ' ')
+ buf = append(buf, p.fields...)
+
+ if !p.time.IsZero() {
+ buf = append(buf, ' ')
+ buf = strconv.AppendInt(buf, p.UnixNano(), 10)
+ }
+
+ return buf
+}
+
+// StringSize returns the length of the string that would be returned by String().
+func (p *point) StringSize() int {
+ size := len(p.key) + len(p.fields) + 1
+
+ if !p.time.IsZero() {
+ digits := 1 // even "0" has one digit
+ t := p.UnixNano()
+ if t < 0 {
+ // account for negative sign, then negate
+ digits++
+ t = -t
+ }
+ for t > 9 { // already accounted for one digit
+ digits++
+ t /= 10
+ }
+ size += digits + 1 // digits and a space
+ }
+
+ return size
+}
+
+// MarshalBinary returns a binary representation of the point.
+func (p *point) MarshalBinary() ([]byte, error) {
+ if len(p.fields) == 0 {
+ return nil, ErrPointMustHaveAField
+ }
+
+ tb, err := p.time.MarshalBinary()
+ if err != nil {
+ return nil, err
+ }
+
+ b := make([]byte, 8+len(p.key)+len(p.fields)+len(tb))
+ i := 0
+
+ binary.BigEndian.PutUint32(b[i:], uint32(len(p.key)))
+ i += 4
+
+ i += copy(b[i:], p.key)
+
+ binary.BigEndian.PutUint32(b[i:i+4], uint32(len(p.fields)))
+ i += 4
+
+ i += copy(b[i:], p.fields)
+
+ copy(b[i:], tb)
+ return b, nil
+}
+
+// UnmarshalBinary decodes a binary representation of the point into a point struct.
+func (p *point) UnmarshalBinary(b []byte) error {
+ var n int
+
+ // Read key length.
+ if len(b) < 4 {
+ return io.ErrShortBuffer
+ }
+ n, b = int(binary.BigEndian.Uint32(b[:4])), b[4:]
+
+ // Read key.
+ if len(b) < n {
+ return io.ErrShortBuffer
+ }
+ p.key, b = b[:n], b[n:]
+
+ // Read fields length.
+ if len(b) < 4 {
+ return io.ErrShortBuffer
+ }
+ n, b = int(binary.BigEndian.Uint32(b[:4])), b[4:]
+
+ // Read fields.
+ if len(b) < n {
+ return io.ErrShortBuffer
+ }
+ p.fields, b = b[:n], b[n:]
+
+ // Read timestamp.
+ return p.time.UnmarshalBinary(b)
+}
+
+// PrecisionString returns a string representation of the point. If there
+// is a timestamp associated with the point then it will be specified in the
+// given unit.
+func (p *point) PrecisionString(precision string) string {
+ if p.Time().IsZero() {
+ return fmt.Sprintf("%s %s", p.Key(), string(p.fields))
+ }
+ return fmt.Sprintf("%s %s %d", p.Key(), string(p.fields),
+ p.UnixNano()/GetPrecisionMultiplier(precision))
+}
+
+// RoundedString returns a string representation of the point. If there
+// is a timestamp associated with the point, then it will be rounded to the
+// given duration.
+func (p *point) RoundedString(d time.Duration) string {
+ if p.Time().IsZero() {
+ return fmt.Sprintf("%s %s", p.Key(), string(p.fields))
+ }
+ return fmt.Sprintf("%s %s %d", p.Key(), string(p.fields),
+ p.time.Round(d).UnixNano())
+}
+
+func (p *point) unmarshalBinary() (Fields, error) {
+ iter := p.FieldIterator()
+ fields := make(Fields, 8)
+ for iter.Next() {
+ if len(iter.FieldKey()) == 0 {
+ continue
+ }
+ switch iter.Type() {
+ case Float:
+ v, err := iter.FloatValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ fields[string(iter.FieldKey())] = v
+ case Integer:
+ v, err := iter.IntegerValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ fields[string(iter.FieldKey())] = v
+ case Unsigned:
+ v, err := iter.UnsignedValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ fields[string(iter.FieldKey())] = v
+ case String:
+ fields[string(iter.FieldKey())] = iter.StringValue()
+ case Boolean:
+ v, err := iter.BooleanValue()
+ if err != nil {
+ return nil, fmt.Errorf("unable to unmarshal field %s: %s", string(iter.FieldKey()), err)
+ }
+ fields[string(iter.FieldKey())] = v
+ }
+ }
+ return fields, nil
+}
+
+// HashID returns a non-cryptographic checksum of the point's key.
+func (p *point) HashID() uint64 {
+ h := NewInlineFNV64a()
+ h.Write(p.key)
+ sum := h.Sum64()
+ return sum
+}
+
+// UnixNano returns the timestamp of the point as nanoseconds since Unix epoch.
+func (p *point) UnixNano() int64 {
+ return p.Time().UnixNano()
+}
+
+// Split will attempt to return multiple points with the same timestamp whose
+// string representations are no longer than size. Points with a single field or
+// a point without a timestamp may exceed the requested size.
+func (p *point) Split(size int) []Point {
+ if p.time.IsZero() || p.StringSize() <= size {
+ return []Point{p}
+ }
+
+ // key string, timestamp string, spaces
+ size -= len(p.key) + len(strconv.FormatInt(p.time.UnixNano(), 10)) + 2
+
+ var points []Point
+ var start, cur int
+
+ for cur < len(p.fields) {
+ end, _ := scanTo(p.fields, cur, '=')
+ end, _ = scanFieldValue(p.fields, end+1)
+
+ if cur > start && end-start > size {
+ points = append(points, &point{
+ key: p.key,
+ time: p.time,
+ fields: p.fields[start : cur-1],
+ })
+ start = cur
+ }
+
+ cur = end + 1
+ }
+
+ points = append(points, &point{
+ key: p.key,
+ time: p.time,
+ fields: p.fields[start:],
+ })
+
+ return points
+}
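+
+// For illustration, splitting a point whose string form is
+// `cpu a=1,b=2,c=3 100` with a small size limit yields points such as
+// `cpu a=1,b=2 100` and `cpu c=3 100`, all sharing the key and timestamp.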
+
+// Tag represents a single key/value tag pair.
+type Tag struct {
+ Key []byte
+ Value []byte
+}
+
+// NewTag returns a new Tag.
+func NewTag(key, value []byte) Tag {
+ return Tag{
+ Key: key,
+ Value: value,
+ }
+}
+
+// Size returns the size of the key and value.
+func (t Tag) Size() int { return len(t.Key) + len(t.Value) }
+
+// Clone returns a shallow copy of Tag.
+//
+// Tags associated with a Point created by ParsePointsWithPrecision will hold references to the byte slice that was parsed.
+// Use Clone to create a Tag with new byte slices that do not refer to the argument to ParsePointsWithPrecision.
+func (t Tag) Clone() Tag {
+ other := Tag{
+ Key: make([]byte, len(t.Key)),
+ Value: make([]byte, len(t.Value)),
+ }
+
+ copy(other.Key, t.Key)
+ copy(other.Value, t.Value)
+
+ return other
+}
+
+// String returns the string representation of the tag.
+func (t *Tag) String() string {
+ var buf bytes.Buffer
+ buf.WriteByte('{')
+ buf.WriteString(string(t.Key))
+ buf.WriteByte(' ')
+ buf.WriteString(string(t.Value))
+ buf.WriteByte('}')
+ return buf.String()
+}
+
+// Tags represents a sorted list of tags.
+type Tags []Tag
+
+// NewTags returns a new Tags from a map.
+func NewTags(m map[string]string) Tags {
+ if len(m) == 0 {
+ return nil
+ }
+ a := make(Tags, 0, len(m))
+ for k, v := range m {
+ a = append(a, NewTag([]byte(k), []byte(v)))
+ }
+ sort.Sort(a)
+ return a
+}
+
+// HashKey hashes all of a tag's keys.
+func (a Tags) HashKey() []byte {
+ return a.AppendHashKey(nil)
+}
+
+func (a Tags) needsEscape() bool {
+ for i := range a {
+ t := &a[i]
+ for j := range tagEscapeCodes {
+ c := &tagEscapeCodes[j]
+ if bytes.IndexByte(t.Key, c.k[0]) != -1 || bytes.IndexByte(t.Value, c.k[0]) != -1 {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// AppendHashKey appends the result of hashing all of a tag's keys and values to dst and returns the extended buffer.
+func (a Tags) AppendHashKey(dst []byte) []byte {
+ // Empty maps marshal to empty bytes.
+ if len(a) == 0 {
+ return dst
+ }
+
+ // Type invariant: Tags are sorted
+
+ sz := 0
+ var escaped Tags
+ if a.needsEscape() {
+ var tmp [20]Tag
+ if len(a) < len(tmp) {
+ escaped = tmp[:len(a)]
+ } else {
+ escaped = make(Tags, len(a))
+ }
+
+ for i := range a {
+ t := &a[i]
+ nt := &escaped[i]
+ nt.Key = escapeTag(t.Key)
+ nt.Value = escapeTag(t.Value)
+ sz += len(nt.Key) + len(nt.Value)
+ }
+ } else {
+ sz = a.Size()
+ escaped = a
+ }
+
+ sz += len(escaped) + (len(escaped) * 2) // separators
+
+ // Generate marshaled bytes.
+ if cap(dst)-len(dst) < sz {
+ nd := make([]byte, len(dst), len(dst)+sz)
+ copy(nd, dst)
+ dst = nd
+ }
+ buf := dst[len(dst) : len(dst)+sz]
+ idx := 0
+ for i := range escaped {
+ k := &escaped[i]
+ if len(k.Value) == 0 {
+ continue
+ }
+ buf[idx] = ','
+ idx++
+ copy(buf[idx:], k.Key)
+ idx += len(k.Key)
+ buf[idx] = '='
+ idx++
+ copy(buf[idx:], k.Value)
+ idx += len(k.Value)
+ }
+ return dst[:len(dst)+idx]
+}
+
+// String returns the string representation of the tags.
+func (a Tags) String() string {
+ var buf bytes.Buffer
+ buf.WriteByte('[')
+ for i := range a {
+ buf.WriteString(a[i].String())
+ if i < len(a)-1 {
+ buf.WriteByte(' ')
+ }
+ }
+ buf.WriteByte(']')
+ return buf.String()
+}
+
+// Size returns the number of bytes needed to store all tags. Note, this is
+// the number of bytes needed to store all keys and values and does not account
+// for data structures or delimiters, for example.
+func (a Tags) Size() int {
+ var total int
+ for i := range a {
+ total += a[i].Size()
+ }
+ return total
+}
+
+// Clone returns a copy of the slice where the elements are a result of calling `Clone` on the original elements.
+//
+// Tags associated with a Point created by ParsePointsWithPrecision will hold references to the byte slice that was parsed.
+// Use Clone to create Tags with new byte slices that do not refer to the argument to ParsePointsWithPrecision.
+func (a Tags) Clone() Tags {
+ if len(a) == 0 {
+ return nil
+ }
+
+ others := make(Tags, len(a))
+ for i := range a {
+ others[i] = a[i].Clone()
+ }
+
+ return others
+}
+
+func (a Tags) Len() int { return len(a) }
+func (a Tags) Less(i, j int) bool { return bytes.Compare(a[i].Key, a[j].Key) == -1 }
+func (a Tags) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+
+// Equal returns true if a equals other.
+func (a Tags) Equal(other Tags) bool {
+ if len(a) != len(other) {
+ return false
+ }
+ for i := range a {
+ if !bytes.Equal(a[i].Key, other[i].Key) || !bytes.Equal(a[i].Value, other[i].Value) {
+ return false
+ }
+ }
+ return true
+}
+
+// CompareTags returns -1 if a < b, 1 if a > b, and 0 if a == b.
+func CompareTags(a, b Tags) int {
+ // Compare each key & value until a mismatch.
+ for i := 0; i < len(a) && i < len(b); i++ {
+ if cmp := bytes.Compare(a[i].Key, b[i].Key); cmp != 0 {
+ return cmp
+ }
+ if cmp := bytes.Compare(a[i].Value, b[i].Value); cmp != 0 {
+ return cmp
+ }
+ }
+
+ // If all tags are equal up to this point then return shorter tagset.
+ if len(a) < len(b) {
+ return -1
+ } else if len(a) > len(b) {
+ return 1
+ }
+
+ // All tags are equal.
+ return 0
+}
+
+// Get returns the value for a key.
+func (a Tags) Get(key []byte) []byte {
+ // OPTIMIZE: Use sort.Search if tagset is large.
+
+ for _, t := range a {
+ if bytes.Equal(t.Key, key) {
+ return t.Value
+ }
+ }
+ return nil
+}
+
+// GetString returns the string value for a string key.
+func (a Tags) GetString(key string) string {
+ return string(a.Get([]byte(key)))
+}
+
+// Set sets the value for a key.
+func (a *Tags) Set(key, value []byte) {
+ for i, t := range *a {
+ if bytes.Equal(t.Key, key) {
+ (*a)[i].Value = value
+ return
+ }
+ }
+ *a = append(*a, Tag{Key: key, Value: value})
+ sort.Sort(*a)
+}
+
+// SetString sets the string value for a string key.
+func (a *Tags) SetString(key, value string) {
+ a.Set([]byte(key), []byte(value))
+}
+
+// Map returns a map representation of the tags.
+func (a Tags) Map() map[string]string {
+ m := make(map[string]string, len(a))
+ for _, t := range a {
+ m[string(t.Key)] = string(t.Value)
+ }
+ return m
+}
+
+// CopyTags returns a shallow copy of tags.
+func CopyTags(a Tags) Tags {
+ other := make(Tags, len(a))
+ copy(other, a)
+ return other
+}
+
+// DeepCopyTags returns a deep copy of tags.
+func DeepCopyTags(a Tags) Tags {
+ // Calculate size of keys/values in bytes.
+ var n int
+ for _, t := range a {
+ n += len(t.Key) + len(t.Value)
+ }
+
+ // Build single allocation for all key/values.
+ buf := make([]byte, n)
+
+ // Copy tags to new set.
+ other := make(Tags, len(a))
+ for i, t := range a {
+ copy(buf, t.Key)
+ other[i].Key, buf = buf[:len(t.Key)], buf[len(t.Key):]
+
+ copy(buf, t.Value)
+ other[i].Value, buf = buf[:len(t.Value)], buf[len(t.Value):]
+ }
+
+ return other
+}
+
+// Fields represents a mapping between a Point's field names and their
+// values.
+type Fields map[string]interface{}
+
+// FieldIterator returns a FieldIterator that can be used to traverse the
+// fields of a point without constructing the in-memory map.
+func (p *point) FieldIterator() FieldIterator {
+ p.Reset()
+ return p
+}
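+
+// Typical iteration sketch (illustrative; pt is a previously parsed Point):
+//
+//	iter := pt.FieldIterator()
+//	for iter.Next() {
+//	    if iter.Type() == Float {
+//	        v, _ := iter.FloatValue()
+//	        fmt.Println(string(iter.FieldKey()), v)
+//	    }
+//	}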
+
+type fieldIterator struct {
+ start, end int
+ key, keybuf []byte
+ valueBuf []byte
+ fieldType FieldType
+}
+
+// Next indicates whether there are any fields remaining.
+func (p *point) Next() bool {
+ p.it.start = p.it.end
+ if p.it.start >= len(p.fields) {
+ return false
+ }
+
+ p.it.end, p.it.key = scanTo(p.fields, p.it.start, '=')
+ if escape.IsEscaped(p.it.key) {
+ p.it.keybuf = escape.AppendUnescaped(p.it.keybuf[:0], p.it.key)
+ p.it.key = p.it.keybuf
+ }
+
+ p.it.end, p.it.valueBuf = scanFieldValue(p.fields, p.it.end+1)
+ p.it.end++
+
+ if len(p.it.valueBuf) == 0 {
+ p.it.fieldType = Empty
+ return true
+ }
+
+ c := p.it.valueBuf[0]
+
+ if c == '"' {
+ p.it.fieldType = String
+ return true
+ }
+
+ if strings.IndexByte(`0123456789-.nNiIu`, c) >= 0 {
+ if p.it.valueBuf[len(p.it.valueBuf)-1] == 'i' {
+ p.it.fieldType = Integer
+ p.it.valueBuf = p.it.valueBuf[:len(p.it.valueBuf)-1]
+ } else if p.it.valueBuf[len(p.it.valueBuf)-1] == 'u' {
+ p.it.fieldType = Unsigned
+ p.it.valueBuf = p.it.valueBuf[:len(p.it.valueBuf)-1]
+ } else {
+ p.it.fieldType = Float
+ }
+ return true
+ }
+
+ // to keep the same behavior that currently exists, default to boolean
+ p.it.fieldType = Boolean
+ return true
+}
+
+// FieldKey returns the key of the current field.
+func (p *point) FieldKey() []byte {
+ return p.it.key
+}
+
+// Type returns the FieldType of the current field.
+func (p *point) Type() FieldType {
+ return p.it.fieldType
+}
+
+// StringValue returns the string value of the current field.
+func (p *point) StringValue() string {
+ return unescapeStringField(string(p.it.valueBuf[1 : len(p.it.valueBuf)-1]))
+}
+
+// IntegerValue returns the integer value of the current field.
+func (p *point) IntegerValue() (int64, error) {
+ n, err := parseIntBytes(p.it.valueBuf, 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("unable to parse integer value %q: %v", p.it.valueBuf, err)
+ }
+ return n, nil
+}
+
+// UnsignedValue returns the unsigned value of the current field.
+func (p *point) UnsignedValue() (uint64, error) {
+ n, err := parseUintBytes(p.it.valueBuf, 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("unable to parse unsigned value %q: %v", p.it.valueBuf, err)
+ }
+ return n, nil
+}
+
+// BooleanValue returns the boolean value of the current field.
+func (p *point) BooleanValue() (bool, error) {
+ b, err := parseBoolBytes(p.it.valueBuf)
+ if err != nil {
+ return false, fmt.Errorf("unable to parse bool value %q: %v", p.it.valueBuf, err)
+ }
+ return b, nil
+}
+
+// FloatValue returns the float value of the current field.
+func (p *point) FloatValue() (float64, error) {
+ f, err := parseFloatBytes(p.it.valueBuf, 64)
+ if err != nil {
+ return 0, fmt.Errorf("unable to parse floating point value %q: %v", p.it.valueBuf, err)
+ }
+ return f, nil
+}
+
+// Reset resets the iterator to its initial state.
+func (p *point) Reset() {
+ p.it.fieldType = Empty
+ p.it.key = nil
+ p.it.valueBuf = nil
+ p.it.start = 0
+ p.it.end = 0
+}
+
+// MarshalBinary encodes all the fields to their proper type and returns the binary
+// representation.
+// NOTE: uint64 is specifically not supported due to potential overflow when we decode
+// again later to an int64.
+// NOTE2: uint is accepted, may be 64 bits, and is encoded as a signed integer
+// for backwards compatibility.
+func (p Fields) MarshalBinary() []byte {
+ var b []byte
+ keys := make([]string, 0, len(p))
+
+ for k := range p {
+ keys = append(keys, k)
+ }
+
+ // Sorting is not strictly required, but it makes the encoding deterministic.
+ sort.Strings(keys)
+
+ for i, k := range keys {
+ if i > 0 {
+ b = append(b, ',')
+ }
+ b = appendField(b, k, p[k])
+ }
+
+ return b
+}
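+
+// For example (illustrative), Fields{"value": 0.64, "count": int64(2)}
+// marshals to the field block `count=2i,value=0.64` (keys are sorted before
+// encoding).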
+
+func appendField(b []byte, k string, v interface{}) []byte {
+ b = append(b, []byte(escape.String(k))...)
+ b = append(b, '=')
+
+ // check popular types first
+ switch v := v.(type) {
+ case float64:
+ b = strconv.AppendFloat(b, v, 'f', -1, 64)
+ case int64:
+ b = strconv.AppendInt(b, v, 10)
+ b = append(b, 'i')
+ case string:
+ b = append(b, '"')
+ b = append(b, []byte(EscapeStringField(v))...)
+ b = append(b, '"')
+ case bool:
+ b = strconv.AppendBool(b, v)
+ case int32:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case int16:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case int8:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case int:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case uint64:
+ b = strconv.AppendUint(b, v, 10)
+ b = append(b, 'u')
+ case uint32:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case uint16:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case uint8:
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case uint:
+ // TODO: 'uint' should be converted to writing as an unsigned integer,
+ // but we cannot since that would break backwards compatibility.
+ b = strconv.AppendInt(b, int64(v), 10)
+ b = append(b, 'i')
+ case float32:
+ b = strconv.AppendFloat(b, float64(v), 'f', -1, 32)
+ case []byte:
+ b = append(b, v...)
+ case nil:
+ // skip
+ default:
+ // Can't determine the type, so convert to string
+ b = append(b, '"')
+ b = append(b, []byte(EscapeStringField(fmt.Sprintf("%v", v)))...)
+ b = append(b, '"')
+
+ }
+
+ return b
+}
+
+// ValidKeyToken returns true if the token used for measurement, tag key, or tag
+// value is a valid unicode string and only contains printable, non-replacement characters.
+func ValidKeyToken(s string) bool {
+ if !utf8.ValidString(s) {
+ return false
+ }
+ for _, r := range s {
+ if !unicode.IsPrint(r) || r == unicode.ReplacementChar {
+ return false
+ }
+ }
+ return true
+}
+
+// ValidKeyTokens returns true if the measurement name and all tags are valid.
+func ValidKeyTokens(name string, tags Tags) bool {
+ if !ValidKeyToken(name) {
+ return false
+ }
+ for _, tag := range tags {
+ if !ValidKeyToken(string(tag.Key)) || !ValidKeyToken(string(tag.Value)) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/influxdb/client/models/rows.go b/influxdb/client/models/rows.go
new file mode 100644
index 0000000..1d10292
--- /dev/null
+++ b/influxdb/client/models/rows.go
@@ -0,0 +1,75 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models
+
+import (
+ "sort"
+)
+
+// Row represents a single row returned from the execution of a statement.
+type Row struct {
+ Name string `json:"name,omitempty"`
+ Tags map[string]string `json:"tags,omitempty"`
+ Columns []string `json:"columns,omitempty"`
+ Values [][]interface{} `json:"values,omitempty"`
+ Partial bool `json:"partial,omitempty"`
+}
+
+// SameSeries returns true if r contains values for the same series as o.
+func (r *Row) SameSeries(o *Row) bool {
+ return r.tagsHash() == o.tagsHash() && r.Name == o.Name
+}
+
+// tagsHash returns a hash of tag key/value pairs.
+func (r *Row) tagsHash() uint64 {
+ h := NewInlineFNV64a()
+ keys := r.tagsKeys()
+ for _, k := range keys {
+ h.Write([]byte(k))
+ h.Write([]byte(r.Tags[k]))
+ }
+ return h.Sum64()
+}
+
+// tagsKeys returns a sorted list of tag keys.
+func (r *Row) tagsKeys() []string {
+ a := make([]string, 0, len(r.Tags))
+ for k := range r.Tags {
+ a = append(a, k)
+ }
+ sort.Strings(a)
+ return a
+}
+
+// Rows represents a collection of rows. Rows implements sort.Interface.
+type Rows []*Row
+
+// Len implements sort.Interface.
+func (p Rows) Len() int { return len(p) }
+
+// Less implements sort.Interface.
+func (p Rows) Less(i, j int) bool {
+ // Sort by name first.
+ if p[i].Name != p[j].Name {
+ return p[i].Name < p[j].Name
+ }
+
+ // Sort by tag set hash. Tags don't have a meaningful sort order so we
+ // just compute a hash and sort by that instead. This allows the tests
+ // to receive rows in a predictable order every time.
+ return p[i].tagsHash() < p[j].tagsHash()
+}
+
+// Swap implements sort.Interface.
+func (p Rows) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
diff --git a/influxdb/client/models/statistic.go b/influxdb/client/models/statistic.go
new file mode 100644
index 0000000..4c963b1
--- /dev/null
+++ b/influxdb/client/models/statistic.go
@@ -0,0 +1,55 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models
+
+// Statistic is the representation of a statistic used by the monitoring service.
+type Statistic struct {
+ Name string `json:"name"`
+ Tags map[string]string `json:"tags"`
+ Values map[string]interface{} `json:"values"`
+}
+
+// NewStatistic returns an initialized Statistic.
+func NewStatistic(name string) Statistic {
+ return Statistic{
+ Name: name,
+ Tags: make(map[string]string),
+ Values: make(map[string]interface{}),
+ }
+}
+
+// StatisticTags is a map that can be merged with others without causing
+// mutations to either map.
+type StatisticTags map[string]string
+
+// Merge creates a new map containing the merged contents of tags and t.
+// If both tags and the receiver map contain the same key, the value in tags
+// is used in the resulting map.
+//
+// Merge always returns a usable map.
+func (t StatisticTags) Merge(tags map[string]string) map[string]string {
+ // Add everything in tags to the result.
+ out := make(map[string]string, len(tags))
+ for k, v := range tags {
+ out[k] = v
+ }
+
+ // Only add values from t that don't appear in tags.
+ for k, v := range t {
+ if _, ok := tags[k]; !ok {
+ out[k] = v
+ }
+ }
+ return out
+}
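+
+// For example (illustrative):
+//
+//	t := StatisticTags{"region": "us-west", "host": "a"}
+//	out := t.Merge(map[string]string{"host": "b"})
+//	// out == map[string]string{"host": "b", "region": "us-west"}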
diff --git a/influxdb/client/models/time.go b/influxdb/client/models/time.go
new file mode 100644
index 0000000..cb46414
--- /dev/null
+++ b/influxdb/client/models/time.go
@@ -0,0 +1,87 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models
+
+// Helper time methods since parsing time can easily overflow and we only support a
+// specific time range.
+
+import (
+ "fmt"
+ "math"
+ "time"
+)
+
+const (
+ // MinNanoTime is the minimum time that can be represented.
+ //
+ // 1677-09-21 00:12:43.145224194 +0000 UTC
+ //
+ // The two lowest minimum integers are used as sentinel values. The
+ // minimum value needs to be used as a value lower than any other value for
+ // comparisons and another separate value is needed to act as a sentinel
+ // default value that is unusable by the user, but usable internally.
+ // Because these two values need to be used for a special purpose, we do
+ // not allow users to write points at these two times.
+ MinNanoTime = int64(math.MinInt64) + 2
+
+ // MaxNanoTime is the maximum time that can be represented.
+ //
+ // 2262-04-11 23:47:16.854775806 +0000 UTC
+ //
+ // The highest time represented by a nanosecond needs to be used for an
+ // exclusive range in the shard group, so the maximum time needs to be one
+ // less than the possible maximum number of nanoseconds representable by an
+ // int64 so that we don't lose a point at that one time.
+ MaxNanoTime = int64(math.MaxInt64) - 1
+)
+
+var (
+ minNanoTime = time.Unix(0, MinNanoTime).UTC()
+ maxNanoTime = time.Unix(0, MaxNanoTime).UTC()
+
+ // ErrTimeOutOfRange gets returned when time is out of the representable range using int64 nanoseconds since the epoch.
+ ErrTimeOutOfRange = fmt.Errorf("time outside range %d - %d", MinNanoTime, MaxNanoTime)
+)
+
+// SafeCalcTime safely calculates the time given. Will return error if the time is outside the
+// supported range.
+func SafeCalcTime(timestamp int64, precision string) (time.Time, error) {
+ mult := GetPrecisionMultiplier(precision)
+ if t, ok := safeSignedMult(timestamp, mult); ok {
+ tme := time.Unix(0, t).UTC()
+ return tme, CheckTime(tme)
+ }
+
+ return time.Time{}, ErrTimeOutOfRange
+}
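+
+// For example (illustrative), SafeCalcTime(1, "ms") returns
+// time.Unix(0, 1000000).UTC(), one millisecond after the epoch, while any
+// timestamp that overflows when scaled to nanoseconds yields
+// ErrTimeOutOfRange.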
+
+// CheckTime checks that a time is within the safe range.
+func CheckTime(t time.Time) error {
+ if t.Before(minNanoTime) || t.After(maxNanoTime) {
+ return ErrTimeOutOfRange
+ }
+ return nil
+}
+
+// Perform the multiplication and check to make sure it didn't overflow.
+func safeSignedMult(a, b int64) (int64, bool) {
+ if a == 0 || b == 0 || a == 1 || b == 1 {
+ return a * b, true
+ }
+ if a == MinNanoTime || b == MaxNanoTime {
+ return 0, false
+ }
+ c := a * b
+ return c, c/b == a
+}
diff --git a/influxdb/client/models/uint_support.go b/influxdb/client/models/uint_support.go
new file mode 100644
index 0000000..dcff490
--- /dev/null
+++ b/influxdb/client/models/uint_support.go
@@ -0,0 +1,21 @@
+//go:build uint || uint64
+// +build uint uint64
+
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package models
+
+func init() {
+ EnableUintSupport()
+}
diff --git a/influxdb/client/pkg/escape/bytes.go b/influxdb/client/pkg/escape/bytes.go
new file mode 100644
index 0000000..5d93c50
--- /dev/null
+++ b/influxdb/client/pkg/escape/bytes.go
@@ -0,0 +1,127 @@
+// Package escape contains utility functions for escaping parts of InfluxQL and InfluxDB line protocol.
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package escape // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/pkg/escape"
+
+import (
+ "bytes"
+ "strings"
+)
+
+// Codes is a map of bytes to be escaped.
+var Codes = map[byte][]byte{
+ ',': []byte(`\,`),
+ '"': []byte(`\"`),
+ ' ': []byte(`\ `),
+ '=': []byte(`\=`),
+}
+
+// Bytes escapes characters on the input slice, as defined by Codes.
+func Bytes(in []byte) []byte {
+ for b, esc := range Codes {
+ in = bytes.Replace(in, []byte{b}, esc, -1)
+ }
+ return in
+}
+
+const escapeChars = `," =`
+
+// IsEscaped returns whether b has any escaped characters,
+// i.e. whether b seems to have been processed by Bytes.
+func IsEscaped(b []byte) bool {
+ for len(b) > 0 {
+ i := bytes.IndexByte(b, '\\')
+ if i < 0 {
+ return false
+ }
+
+ if i+1 < len(b) && strings.IndexByte(escapeChars, b[i+1]) >= 0 {
+ return true
+ }
+ b = b[i+1:]
+ }
+ return false
+}
+
+// AppendUnescaped appends the unescaped version of src to dst
+// and returns the resulting slice.
+func AppendUnescaped(dst, src []byte) []byte {
+ var pos int
+ for len(src) > 0 {
+ next := bytes.IndexByte(src[pos:], '\\')
+ if next < 0 || pos+next+1 >= len(src) {
+ return append(dst, src...)
+ }
+
+ if pos+next+1 < len(src) && strings.IndexByte(escapeChars, src[pos+next+1]) >= 0 {
+ if pos+next > 0 {
+ dst = append(dst, src[:pos+next]...)
+ }
+ src = src[pos+next+1:]
+ pos = 0
+ } else {
+ pos += next + 1
+ }
+ }
+
+ return dst
+}
+
+// Unescape returns a new slice containing the unescaped version of in.
+func Unescape(in []byte) []byte {
+ if len(in) == 0 {
+ return nil
+ }
+
+ if bytes.IndexByte(in, '\\') == -1 {
+ return in
+ }
+
+ i := 0
+ inLen := len(in)
+
+ // The output size will be no more than inLen. Preallocating the
+ // capacity of the output is faster and uses less memory than
+ // letting append() do its own (over)allocation.
+ out := make([]byte, 0, inLen)
+
+ for {
+ if i >= inLen {
+ break
+ }
+ if in[i] == '\\' && i+1 < inLen {
+ switch in[i+1] {
+ case ',':
+ out = append(out, ',')
+ i += 2
+ continue
+ case '"':
+ out = append(out, '"')
+ i += 2
+ continue
+ case ' ':
+ out = append(out, ' ')
+ i += 2
+ continue
+ case '=':
+ out = append(out, '=')
+ i += 2
+ continue
+ }
+ }
+ out = append(out, in[i])
+ i++
+ }
+ return out
+}
diff --git a/influxdb/client/pkg/escape/strings.go b/influxdb/client/pkg/escape/strings.go
new file mode 100644
index 0000000..5de0a01
--- /dev/null
+++ b/influxdb/client/pkg/escape/strings.go
@@ -0,0 +1,34 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package escape
+
+import "strings"
+
+var (
+ escaper = strings.NewReplacer(`,`, `\,`, `"`, `\"`, ` `, `\ `, `=`, `\=`)
+ unescaper = strings.NewReplacer(`\,`, `,`, `\"`, `"`, `\ `, ` `, `\=`, `=`)
+)
+
+// UnescapeString returns unescaped version of in.
+func UnescapeString(in string) string {
+ if strings.IndexByte(in, '\\') == -1 {
+ return in
+ }
+ return unescaper.Replace(in)
+}
+
+// String returns the escaped version of in.
+func String(in string) string {
+ return escaper.Replace(in)
+}
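Both files implement the same escape table, once over []byte and once via a strings.Replacer. A small round-trip example using the import path declared in the package clause above:

```go
package main

import (
	"fmt"

	"github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/pkg/escape"
)

func main() {
	// Byte-slice API: escape the characters that are significant in line
	// protocol (comma, quote, space, equals), then undo it.
	raw := []byte(`cpu usage,host=a`)
	esc := escape.Bytes(raw)
	fmt.Printf("%s\n", esc)                  // cpu\ usage\,host\=a
	fmt.Printf("%s\n", escape.Unescape(esc)) // cpu usage,host=a

	// String API: same mapping via strings.Replacer.
	s := escape.String(`key=value with spaces`)
	fmt.Println(s) // key\=value\ with\ spaces
	fmt.Println(escape.UnescapeString(s))
}
```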
diff --git a/influxdb/client/v2/client.go b/influxdb/client/v2/client.go
new file mode 100644
index 0000000..8a729dc
--- /dev/null
+++ b/influxdb/client/v2/client.go
@@ -0,0 +1,823 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package client // import "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
+
+import (
+ "bytes"
+ "compress/gzip"
+ "crypto/tls"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime"
+ "net"
+ "net/http"
+ "net/url"
+ "path"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/models"
+)
+
+type ContentEncoding string
+
+const (
+ DefaultEncoding ContentEncoding = ""
+ GzipEncoding ContentEncoding = "gzip"
+ DefaultMaxIdleConns = 30
+)
+
+// HTTPConfig is the config data needed to create an HTTP Client.
+type HTTPConfig struct {
+ // Addr should be of the form "http://host:port"
+ // or "http://[ipv6-host%zone]:port".
+ Addr string
+
+ // Username is the influxdb username, optional.
+ Username string
+
+ // Password is the influxdb password, optional.
+ Password string
+
+ // UserAgent is the http User Agent, defaults to "InfluxDBClient".
+ UserAgent string
+
+ // Timeout for influxdb writes, defaults to no timeout.
+ Timeout time.Duration
+
+ // InsecureSkipVerify gets passed to the http client, if true, it will
+ // skip https certificate verification. Defaults to false.
+ InsecureSkipVerify bool
+
+ // TLSConfig allows the user to set their own TLS config for the HTTP
+ // Client. If set, this option overrides InsecureSkipVerify.
+ TLSConfig *tls.Config
+
+ // Proxy configures the Proxy function on the HTTP client.
+ Proxy func(req *http.Request) (*url.URL, error)
+
+ // WriteEncoding specifies the encoding of write requests.
+ WriteEncoding ContentEncoding
+
+ // MaxIdleConns controls the maximum number of idle (keep-alive)
+ // connections across all hosts.
+ MaxIdleConns int
+
+ // MaxIdleConnsPerHost, if non-zero, controls the maximum number of idle
+ // (keep-alive) connections to keep per host.
+ MaxIdleConnsPerHost int
+
+ // IdleConnTimeout is the maximum amount of time an idle (keep-alive)
+ // connection will remain idle before closing itself.
+ IdleConnTimeout time.Duration
+}
+
+// BatchPointsConfig is the config data needed to create an instance of the BatchPoints struct.
+type BatchPointsConfig struct {
+ // Precision is the write precision of the points, defaults to "ns".
+ Precision string
+
+ // Database is the database to write points to.
+ Database string
+
+ // RetentionPolicy is the retention policy of the points.
+ RetentionPolicy string
+
+ // Write consistency is the number of servers required to confirm write.
+ WriteConsistency string
+}
+
+// Client is a client interface for writing & querying the database.
+type Client interface {
+ // Ping checks the status of the cluster, and will always return 0 time and no
+ // error for UDP clients.
+ Ping(timeout time.Duration) (time.Duration, string, error)
+
+ // Write takes a BatchPoints object and writes all Points to InfluxDB.
+ Write(bp BatchPoints) error
+
+ // Query makes an InfluxDB Query on the database. This will fail if using
+ // the UDP client.
+ Query(q Query) (*Response, error)
+
+ // QueryAsChunk makes an InfluxDB Query on the database. This will fail if using
+ // the UDP client.
+ QueryAsChunk(q Query) (*ChunkedResponse, error)
+
+ // Close releases any resources a Client may be using.
+ Close() error
+}
+
+// NewHTTPClient returns a new Client from the provided config.
+// Client is safe for concurrent use by multiple goroutines.
+func NewHTTPClient(conf HTTPConfig) (Client, error) {
+ if conf.UserAgent == "" {
+ conf.UserAgent = "InfluxDBClient"
+ }
+
+ u, err := url.Parse(conf.Addr)
+ if err != nil {
+ return nil, err
+ } else if u.Scheme != "http" && u.Scheme != "https" {
+ m := fmt.Sprintf("Unsupported protocol scheme: %s, your address"+
+ " must start with http:// or https://", u.Scheme)
+ return nil, errors.New(m)
+ }
+
+ switch conf.WriteEncoding {
+ case DefaultEncoding, GzipEncoding:
+ default:
+ return nil, fmt.Errorf("unsupported encoding %s", conf.WriteEncoding)
+ }
+
+ if conf.MaxIdleConns == 0 {
+ conf.MaxIdleConns = DefaultMaxIdleConns
+ }
+ if conf.MaxIdleConnsPerHost == 0 {
+ conf.MaxIdleConnsPerHost = conf.MaxIdleConns
+ }
+
+ tr := &http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: conf.InsecureSkipVerify,
+ },
+ Proxy: conf.Proxy,
+ MaxIdleConns: conf.MaxIdleConns,
+ MaxIdleConnsPerHost: conf.MaxIdleConnsPerHost,
+ IdleConnTimeout: conf.IdleConnTimeout,
+ DialContext: (&net.Dialer{
+ KeepAlive: time.Second * 60,
+ }).DialContext,
+ }
+ if conf.TLSConfig != nil {
+ tr.TLSClientConfig = conf.TLSConfig
+ }
+ return &client{
+ url: *u,
+ username: conf.Username,
+ password: conf.Password,
+ useragent: conf.UserAgent,
+ httpClient: &http.Client{
+ Timeout: conf.Timeout,
+ Transport: tr,
+ },
+ transport: tr,
+ encoding: conf.WriteEncoding,
+ }, nil
+}
+
+// Ping will check to see if the server is up with an optional timeout on waiting for leader.
+// Ping returns how long the request took, the version of the server it connected to, and an error if one occurred.
+func (c *client) Ping(timeout time.Duration) (time.Duration, string, error) {
+ now := time.Now()
+
+ u := c.url
+ u.Path = path.Join(u.Path, "ping")
+
+ req, err := http.NewRequest("GET", u.String(), nil)
+ if err != nil {
+ return 0, "", err
+ }
+
+ req.Header.Set("User-Agent", c.useragent)
+
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+
+ if timeout > 0 {
+ params := req.URL.Query()
+ params.Set("wait_for_leader", fmt.Sprintf("%.0fs", timeout.Seconds()))
+ req.URL.RawQuery = params.Encode()
+ }
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return 0, "", err
+ }
+ defer resp.Body.Close()
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return 0, "", err
+ }
+
+ if resp.StatusCode != http.StatusNoContent {
+ var err = errors.New(string(body))
+ return 0, "", err
+ }
+
+ version := resp.Header.Get("X-Influxdb-Version")
+ return time.Since(now), version, nil
+}
+
+// Close releases the client's resources.
+func (c *client) Close() error {
+ c.transport.CloseIdleConnections()
+ return nil
+}
+
+// client is safe for concurrent use as the fields are all read-only
+// once the client is instantiated.
+type client struct {
+ // N.B - if url.UserInfo is accessed in future modifications to the
+ // methods on client, you will need to synchronize access to url.
+ url url.URL
+ username string
+ password string
+ useragent string
+ httpClient *http.Client
+ transport *http.Transport
+ encoding ContentEncoding
+}
+
+// BatchPoints is an interface into a batched grouping of points to write into
+// InfluxDB together. BatchPoints is NOT thread-safe, you must create a separate
+// batch for each goroutine.
+type BatchPoints interface {
+ // AddPoint adds the given point to the Batch of points.
+ AddPoint(p *Point)
+ // AddPoints adds the given points to the Batch of points.
+ AddPoints(ps []*Point)
+ // Points lists the points in the Batch.
+ Points() []*Point
+ // ClearPoints removes all points from the Batch.
+ ClearPoints()
+
+ // GetPointsNum returns the number of points in the Batch.
+ GetPointsNum() int
+
+ // Precision returns the currently set precision of this Batch.
+ Precision() string
+ // SetPrecision sets the precision of this batch.
+ SetPrecision(s string) error
+
+ // Database returns the currently set database of this Batch.
+ Database() string
+ // SetDatabase sets the database of this Batch.
+ SetDatabase(s string)
+
+ // WriteConsistency returns the currently set write consistency of this Batch.
+ WriteConsistency() string
+ // SetWriteConsistency sets the write consistency of this Batch.
+ SetWriteConsistency(s string)
+
+ // RetentionPolicy returns the currently set retention policy of this Batch.
+ RetentionPolicy() string
+ // SetRetentionPolicy sets the retention policy of this Batch.
+ SetRetentionPolicy(s string)
+}
+
+// NewBatchPoints returns a BatchPoints interface based on the given config.
+func NewBatchPoints(conf BatchPointsConfig) (BatchPoints, error) {
+ if conf.Precision == "" {
+ conf.Precision = "ns"
+ }
+ if _, err := time.ParseDuration("1" + conf.Precision); err != nil {
+ return nil, err
+ }
+ bp := &batchpoints{
+ database: conf.Database,
+ precision: conf.Precision,
+ retentionPolicy: conf.RetentionPolicy,
+ writeConsistency: conf.WriteConsistency,
+ }
+ return bp, nil
+}
+
+type batchpoints struct {
+ points []*Point
+ database string
+ precision string
+ retentionPolicy string
+ writeConsistency string
+}
+
+func (bp *batchpoints) AddPoint(p *Point) {
+ bp.points = append(bp.points, p)
+}
+
+func (bp *batchpoints) AddPoints(ps []*Point) {
+ bp.points = append(bp.points, ps...)
+}
+
+func (bp *batchpoints) Points() []*Point {
+ return bp.points
+}
+
+func (bp *batchpoints) ClearPoints() {
+ bp.points = bp.points[0:0]
+}
+
+func (bp *batchpoints) GetPointsNum() int {
+ return len(bp.points)
+}
+
+func (bp *batchpoints) Precision() string {
+ return bp.precision
+}
+
+func (bp *batchpoints) Database() string {
+ return bp.database
+}
+
+func (bp *batchpoints) WriteConsistency() string {
+ return bp.writeConsistency
+}
+
+func (bp *batchpoints) RetentionPolicy() string {
+ return bp.retentionPolicy
+}
+
+func (bp *batchpoints) SetPrecision(p string) error {
+ if _, err := time.ParseDuration("1" + p); err != nil {
+ return err
+ }
+ bp.precision = p
+ return nil
+}
+
+func (bp *batchpoints) SetDatabase(db string) {
+ bp.database = db
+}
+
+func (bp *batchpoints) SetWriteConsistency(wc string) {
+ bp.writeConsistency = wc
+}
+
+func (bp *batchpoints) SetRetentionPolicy(rp string) {
+ bp.retentionPolicy = rp
+}
+
+// Point represents a single data point.
+type Point struct {
+ pt models.Point
+}
+
+// NewPoint returns a point with the given timestamp. If a timestamp is not
+// given, then data is sent to the database without a timestamp, in which case
+// the server will assign local time upon reception. NOTE: it is recommended to
+// send data with a timestamp.
+func NewPoint(
+ name string,
+ tags map[string]string,
+ fields map[string]interface{},
+ t ...time.Time,
+) (*Point, error) {
+ var T time.Time
+ if len(t) > 0 {
+ T = t[0]
+ }
+
+ pt, err := models.NewPoint(name, models.NewTags(tags), fields, T)
+ if err != nil {
+ return nil, err
+ }
+ return &Point{
+ pt: pt,
+ }, nil
+}
+
+// String returns a line-protocol string of the Point.
+func (p *Point) String() string {
+ return p.pt.String()
+}
+
+// PrecisionString returns a line-protocol string of the Point,
+// with the timestamp formatted for the given precision.
+func (p *Point) PrecisionString(precision string) string {
+ return p.pt.PrecisionString(precision)
+}
+
+// Name returns the measurement name of the point.
+func (p *Point) Name() string {
+ return string(p.pt.Name())
+}
+
+// Tags returns the tags associated with the point.
+func (p *Point) Tags() map[string]string {
+ return p.pt.Tags().Map()
+}
+
+// Time return the timestamp for the point.
+func (p *Point) Time() time.Time {
+ return p.pt.Time()
+}
+
+// UnixNano returns timestamp of the point in nanoseconds since Unix epoch.
+func (p *Point) UnixNano() int64 {
+ return p.pt.UnixNano()
+}
+
+// Fields returns the fields for the point.
+func (p *Point) Fields() (map[string]interface{}, error) {
+ return p.pt.Fields()
+}
+
+// NewPointFrom returns a point from the provided models.Point.
+func NewPointFrom(pt models.Point) *Point {
+ return &Point{pt: pt}
+}
+
+func (c *client) Write(bp BatchPoints) error {
+ var b bytes.Buffer
+
+ var w io.Writer
+ if c.encoding == GzipEncoding {
+ w = gzip.NewWriter(&b)
+ } else {
+ w = &b
+ }
+
+ for _, p := range bp.Points() {
+ if p == nil {
+ continue
+ }
+ if _, err := io.WriteString(w, p.pt.PrecisionString(bp.Precision())); err != nil && err != io.EOF {
+ return err
+ }
+
+ if _, err := w.Write([]byte{'\n'}); err != nil && err != io.EOF {
+ return err
+ }
+ }
+
+ // gzip writer should be closed to flush data into underlying buffer
+ if c, ok := w.(io.Closer); ok {
+ if err := c.Close(); err != nil && err != io.EOF {
+ return err
+ }
+ }
+
+ u := c.url
+ u.Path = path.Join(u.Path, "write")
+
+ req, err := http.NewRequest("POST", u.String(), &b)
+ if err == io.EOF {
+ err = nil
+ }
+ if err != nil {
+ return err
+ }
+ if c.encoding != DefaultEncoding {
+ req.Header.Set("Content-Encoding", string(c.encoding))
+ }
+ req.Header.Set("Content-Type", "")
+ req.Header.Set("User-Agent", c.useragent)
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+
+ params := req.URL.Query()
+ params.Set("db", bp.Database())
+ params.Set("rp", bp.RetentionPolicy())
+ params.Set("precision", bp.Precision())
+ params.Set("consistency", bp.WriteConsistency())
+ req.URL.RawQuery = params.Encode()
+
+ resp, err := c.httpClient.Do(req)
+ if err == io.EOF {
+ err = nil
+ }
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err == io.EOF {
+ err = nil
+ }
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
+ var err = errors.New(string(body))
+ if err == io.EOF {
+ err = nil
+ }
+ return err
+ }
+
+ return nil
+}
+
+// Query defines a query to send to the server.
+type Query struct {
+ Command string
+ Database string
+ RetentionPolicy string
+ Precision string
+ Chunked bool
+ ChunkSize int
+ Parameters map[string]interface{}
+}
+
+// Params is a type alias to the query parameters.
+type Params map[string]interface{}
+
+// NewQuery returns a query object.
+// The database and precision arguments can be empty strings if they are not needed for the query.
+func NewQuery(command, database, precision string) Query {
+ return Query{
+ Command: command,
+ Database: database,
+ Precision: precision,
+ Parameters: make(map[string]interface{}),
+ }
+}
+
+// NewQueryWithRP returns a query object.
+// The database, retention policy, and precision arguments can be empty strings if they are not needed
+// for the query. Setting the retention policy only works on InfluxDB versions 1.6 or greater.
+func NewQueryWithRP(command, database, retentionPolicy, precision string) Query {
+ return Query{
+ Command: command,
+ Database: database,
+ RetentionPolicy: retentionPolicy,
+ Precision: precision,
+ Parameters: make(map[string]interface{}),
+ }
+}
+
+// NewQueryWithParameters returns a query object.
+// The database and precision arguments can be empty strings if they are not needed for the query.
+// parameters is a map of the parameter names used in the command to their values.
+func NewQueryWithParameters(command, database, precision string, parameters map[string]interface{}) Query {
+ return Query{
+ Command: command,
+ Database: database,
+ Precision: precision,
+ Parameters: parameters,
+ }
+}
+
+// Response represents a list of statement results.
+type Response struct {
+ Results []Result
+ Err string `json:"error,omitempty"`
+}
+
+// Error returns the first error from any statement.
+// It returns nil if no errors occurred on any statements.
+func (r *Response) Error() error {
+ if r.Err != "" {
+ return errors.New(r.Err)
+ }
+ for _, result := range r.Results {
+ if result.Err != "" {
+ return errors.New(result.Err)
+ }
+ }
+ return nil
+}
+
+// Message represents a user message.
+type Message struct {
+ Level string
+ Text string
+}
+
+// Result represents a resultset returned from a single statement.
+type Result struct {
+ StatementId int `json:"statement_id"`
+ Series []models.Row
+ Messages []*Message
+ Err string `json:"error,omitempty"`
+}
+
+// Query sends a command to the server and returns the Response.
+func (c *client) Query(q Query) (*Response, error) {
+ req, err := c.createDefaultRequest(q)
+ if err != nil {
+ return nil, err
+ }
+ params := req.URL.Query()
+ if q.Chunked {
+ params.Set("chunked", "true")
+ if q.ChunkSize > 0 {
+ params.Set("chunk_size", strconv.Itoa(q.ChunkSize))
+ }
+ req.URL.RawQuery = params.Encode()
+ }
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if err := checkResponse(resp); err != nil {
+ return nil, err
+ }
+
+ var response Response
+ if q.Chunked {
+ cr := NewChunkedResponse(resp.Body)
+ for {
+ r, err := cr.NextResponse()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ // If we got an error while decoding the response, send that back.
+ return nil, err
+ }
+
+ if r == nil {
+ break
+ }
+
+ response.Results = append(response.Results, r.Results...)
+ if r.Err != "" {
+ response.Err = r.Err
+ break
+ }
+ }
+ } else {
+ dec := json.NewDecoder(resp.Body)
+ dec.UseNumber()
+ decErr := dec.Decode(&response)
+
+ // ignore this error if we got an invalid status code
+ if decErr != nil && decErr.Error() == "EOF" && resp.StatusCode != http.StatusOK {
+ decErr = nil
+ }
+ // If we got a valid decode error, send that back
+ if decErr != nil {
+ return nil, fmt.Errorf("unable to decode json: received status code %d err: %s", resp.StatusCode, decErr)
+ }
+ }
+
+ // If we don't have an error in our json response, and didn't get statusOK
+ // then send back an error
+ if resp.StatusCode != http.StatusOK && response.Error() == nil {
+ return &response, fmt.Errorf("received status code %d from server", resp.StatusCode)
+ }
+ return &response, nil
+}
+
+// QueryAsChunk sends a command to the server and returns the ChunkedResponse.
+func (c *client) QueryAsChunk(q Query) (*ChunkedResponse, error) {
+ req, err := c.createDefaultRequest(q)
+ if err != nil {
+ return nil, err
+ }
+ params := req.URL.Query()
+ params.Set("chunked", "true")
+ if q.ChunkSize > 0 {
+ params.Set("chunk_size", strconv.Itoa(q.ChunkSize))
+ }
+ req.URL.RawQuery = params.Encode()
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := checkResponse(resp); err != nil {
+ return nil, err
+ }
+ return NewChunkedResponse(resp.Body), nil
+}
+
+func checkResponse(resp *http.Response) error {
+ // If we lack an X-Influxdb-Version header, then we didn't get a response from influxdb
+ // but instead some other service. If the error code is also a 500+ code, then some
+ // downstream loadbalancer/proxy/etc had an issue and we should report that.
+ if resp.Header.Get("X-Influxdb-Version") == "" && resp.StatusCode >= http.StatusInternalServerError {
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil || len(body) == 0 {
+ return fmt.Errorf("received status code %d from downstream server", resp.StatusCode)
+ }
+
+ return fmt.Errorf("received status code %d from downstream server, with response body: %q", resp.StatusCode, body)
+ }
+
+ // If we get an unexpected content type, then it is also not from influx direct and therefore
+ // we want to know what we received and what status code was returned for debugging purposes.
+ if cType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")); cType != "application/json" {
+ // Read up to 1kb of the body to help identify downstream errors and limit the impact of things
+ // like downstream serving a large file
+ body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1024))
+ if err != nil || len(body) == 0 {
+ return fmt.Errorf("expected json response, got empty body, with status: %v", resp.StatusCode)
+ }
+
+ return fmt.Errorf("expected json response, got %q, with status: %v and response body: %q", cType, resp.StatusCode, body)
+ }
+ return nil
+}
+
+func (c *client) createDefaultRequest(q Query) (*http.Request, error) {
+ u := c.url
+ u.Path = path.Join(u.Path, "query")
+
+ jsonParameters, err := json.Marshal(q.Parameters)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", u.String(), nil)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "")
+ req.Header.Set("User-Agent", c.useragent)
+
+ if c.username != "" {
+ req.SetBasicAuth(c.username, c.password)
+ }
+
+ params := req.URL.Query()
+ params.Set("q", q.Command)
+ params.Set("db", q.Database)
+ if q.RetentionPolicy != "" {
+ params.Set("rp", q.RetentionPolicy)
+ }
+ params.Set("params", string(jsonParameters))
+
+ if q.Precision != "" {
+ params.Set("epoch", q.Precision)
+ }
+ req.URL.RawQuery = params.Encode()
+
+ return req, nil
+}
+
+// duplexReader reads responses and copies them to another writer while
+// satisfying the io.Reader interface.
+type duplexReader struct {
+ r io.ReadCloser
+ w io.Writer
+}
+
+func (r *duplexReader) Read(p []byte) (n int, err error) {
+ n, err = r.r.Read(p)
+ if err == nil {
+ r.w.Write(p[:n])
+ }
+ return n, err
+}
+
+// Close closes the response.
+func (r *duplexReader) Close() error {
+ return r.r.Close()
+}
+
+// ChunkedResponse represents a response from the server that
+// uses chunking to stream the output.
+type ChunkedResponse struct {
+ dec *json.Decoder
+ duplex *duplexReader
+ buf bytes.Buffer
+}
+
+// NewChunkedResponse reads a stream and produces responses from the stream.
+func NewChunkedResponse(r io.Reader) *ChunkedResponse {
+ rc, ok := r.(io.ReadCloser)
+ if !ok {
+ rc = ioutil.NopCloser(r)
+ }
+ resp := &ChunkedResponse{}
+ resp.duplex = &duplexReader{r: rc, w: &resp.buf}
+ resp.dec = json.NewDecoder(resp.duplex)
+ resp.dec.UseNumber()
+ return resp
+}
+
+// NextResponse reads the next line of the stream and returns a response.
+func (r *ChunkedResponse) NextResponse() (*Response, error) {
+ var response Response
+ if err := r.dec.Decode(&response); err != nil {
+ if err == io.EOF {
+ return nil, err
+ }
+ // A decoding error happened. This probably means the server crashed
+ // and sent a last-ditch error message to us. Ensure we have read the
+ // entirety of the connection to get any remaining error text.
+ io.Copy(ioutil.Discard, r.duplex)
+ return nil, errors.New(strings.TrimSpace(r.buf.String()))
+ }
+
+ r.buf.Reset()
+ return &response, nil
+}
+
+// Close closes the response.
+func (r *ChunkedResponse) Close() error {
+ return r.duplex.Close()
+}
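Taken together, the client composes like the upstream InfluxDB 1.x Go client: build a batch, add points, write, then query. A hedged end-to-end sketch (the address, database, and measurement names are placeholders):

```go
package main

import (
	"fmt"
	"log"
	"time"

	client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
)

func main() {
	c, err := client.NewHTTPClient(client.HTTPConfig{
		Addr:    "http://localhost:8086", // placeholder address
		Timeout: 5 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	bp, err := client.NewBatchPoints(client.BatchPointsConfig{
		Database:  "mydb", // placeholder database
		Precision: "s",
	})
	if err != nil {
		log.Fatal(err)
	}

	pt, err := client.NewPoint(
		"cpu",
		map[string]string{"host": "a"},
		map[string]interface{}{"usage": 0.64},
		time.Now(),
	)
	if err != nil {
		log.Fatal(err)
	}
	bp.AddPoint(pt)

	if err := c.Write(bp); err != nil {
		log.Fatal(err)
	}

	resp, err := c.Query(client.NewQuery(`SELECT * FROM "cpu"`, "mydb", "s"))
	if err == nil && resp.Error() == nil {
		fmt.Println(resp.Results)
	}
}
```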
diff --git a/influxdb/client/v2/params.go b/influxdb/client/v2/params.go
new file mode 100644
index 0000000..238ef66
--- /dev/null
+++ b/influxdb/client/v2/params.go
@@ -0,0 +1,86 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package client
+
+import (
+ "encoding/json"
+ "time"
+)
+
+type (
+ // Identifier is an identifier value.
+ Identifier string
+
+ // StringValue is a string literal.
+ StringValue string
+
+ // RegexValue is a regexp literal.
+ RegexValue string
+
+ // NumberValue is a number literal.
+ NumberValue float64
+
+ // IntegerValue is an integer literal.
+ IntegerValue int64
+
+ // BooleanValue is a boolean literal.
+ BooleanValue bool
+
+ // TimeValue is a time literal.
+ TimeValue time.Time
+
+ // DurationValue is a duration literal.
+ DurationValue time.Duration
+)
+
+func (v Identifier) MarshalJSON() ([]byte, error) {
+ m := map[string]string{"identifier": string(v)}
+ return json.Marshal(m)
+}
+
+func (v StringValue) MarshalJSON() ([]byte, error) {
+ m := map[string]string{"string": string(v)}
+ return json.Marshal(m)
+}
+
+func (v RegexValue) MarshalJSON() ([]byte, error) {
+ m := map[string]string{"regex": string(v)}
+ return json.Marshal(m)
+}
+
+func (v NumberValue) MarshalJSON() ([]byte, error) {
+ m := map[string]float64{"number": float64(v)}
+ return json.Marshal(m)
+}
+
+func (v IntegerValue) MarshalJSON() ([]byte, error) {
+ m := map[string]int64{"integer": int64(v)}
+ return json.Marshal(m)
+}
+
+func (v BooleanValue) MarshalJSON() ([]byte, error) {
+ m := map[string]bool{"boolean": bool(v)}
+ return json.Marshal(m)
+}
+
+func (v TimeValue) MarshalJSON() ([]byte, error) {
+ t := time.Time(v)
+ m := map[string]string{"string": t.Format(time.RFC3339Nano)}
+ return json.Marshal(m)
+}
+
+func (v DurationValue) MarshalJSON() ([]byte, error) {
+ m := map[string]int64{"duration": int64(v)}
+ return json.Marshal(m)
+}
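These literal wrappers exist so bound parameters marshal into the tagged JSON shape the /query endpoint expects. A sketch of what the server receives (query text and database are placeholders):

```go
package main

import (
	"encoding/json"
	"fmt"

	client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
)

func main() {
	q := client.NewQueryWithParameters(
		"SELECT * FROM cpu WHERE host = $host AND usage > $min",
		"mydb", "", // placeholder database, default precision
		client.Params{
			"host": client.StringValue("server01"),
			"min":  client.NumberValue(0.5),
		},
	)

	// This is the shape sent in the "params" query argument.
	b, _ := json.Marshal(q.Parameters)
	fmt.Println(string(b))
	// {"host":{"string":"server01"},"min":{"number":0.5}}
}
```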
diff --git a/influxdb/client/v2/udp.go b/influxdb/client/v2/udp.go
new file mode 100644
index 0000000..c548a0e
--- /dev/null
+++ b/influxdb/client/v2/udp.go
@@ -0,0 +1,129 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package client
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+const (
+ // UDPPayloadSize is a reasonable default payload size for UDP packets that
+ // could be travelling over the internet.
+ UDPPayloadSize = 512
+)
+
+// UDPConfig is the config data needed to create a UDP Client.
+type UDPConfig struct {
+ // Addr should be of the form "host:port"
+ // or "[ipv6-host%zone]:port".
+ Addr string
+
+ // PayloadSize is the maximum size of a UDP client message, optional
+ // Tune this based on your network. Defaults to UDPPayloadSize.
+ PayloadSize int
+}
+
+// NewUDPClient returns a client interface for writing to an InfluxDB UDP
+// service from the given config.
+func NewUDPClient(conf UDPConfig) (Client, error) {
+ var udpAddr *net.UDPAddr
+ udpAddr, err := net.ResolveUDPAddr("udp", conf.Addr)
+ if err != nil {
+ return nil, err
+ }
+
+ conn, err := net.DialUDP("udp", nil, udpAddr)
+ if err != nil {
+ return nil, err
+ }
+
+ payloadSize := conf.PayloadSize
+ if payloadSize == 0 {
+ payloadSize = UDPPayloadSize
+ }
+
+ return &udpclient{
+ conn: conn,
+ payloadSize: payloadSize,
+ }, nil
+}
+
+// Close releases the udpclient's resources.
+func (uc *udpclient) Close() error {
+ return uc.conn.Close()
+}
+
+type udpclient struct {
+ conn io.WriteCloser
+ payloadSize int
+}
+
+func (uc *udpclient) Write(bp BatchPoints) error {
+ var b = make([]byte, 0, uc.payloadSize) // initial buffer size, it will grow as needed
+ var d, _ = time.ParseDuration("1" + bp.Precision())
+
+ var delayedError error
+
+ var checkBuffer = func(n int) {
+ if len(b) > 0 && len(b)+n > uc.payloadSize {
+ if _, err := uc.conn.Write(b); err != nil {
+ delayedError = err
+ }
+ b = b[:0]
+ }
+ }
+
+ for _, p := range bp.Points() {
+ p.pt.Round(d)
+ pointSize := p.pt.StringSize() + 1 // include newline in size
+
+ checkBuffer(pointSize)
+
+ if p.Time().IsZero() || pointSize <= uc.payloadSize {
+ b = p.pt.AppendString(b)
+ b = append(b, '\n')
+ continue
+ }
+
+ points := p.pt.Split(uc.payloadSize - 1) // account for newline character
+ for _, sp := range points {
+ checkBuffer(sp.StringSize() + 1)
+ b = sp.AppendString(b)
+ b = append(b, '\n')
+ }
+ }
+
+ if len(b) > 0 {
+ if _, err := uc.conn.Write(b); err != nil {
+ return err
+ }
+ }
+ return delayedError
+}
+
+func (uc *udpclient) Query(q Query) (*Response, error) {
+ return nil, fmt.Errorf("querying via UDP is not supported")
+}
+
+func (uc *udpclient) QueryAsChunk(q Query) (*ChunkedResponse, error) {
+ return nil, fmt.Errorf("querying via UDP is not supported")
+}
+
+func (uc *udpclient) Ping(timeout time.Duration) (time.Duration, string, error) {
+ return 0, "", nil
+}
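A minimal usage sketch for the UDP path (the endpoint is a placeholder; errors on batch construction are elided for brevity):

```go
package main

import (
	"log"
	"time"

	client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
)

func main() {
	c, err := client.NewUDPClient(client.UDPConfig{
		Addr:        "localhost:8089", // placeholder UDP endpoint
		PayloadSize: 512,              // tune to your network MTU
	})
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	bp, _ := client.NewBatchPoints(client.BatchPointsConfig{Precision: "s"})
	pt, _ := client.NewPoint("cpu", nil, map[string]interface{}{"usage": 0.5}, time.Now())
	bp.AddPoint(pt)

	// Oversized points are split across datagrams; Query and Ping are
	// not supported over UDP.
	if err := c.Write(bp); err != nil {
		log.Fatal(err)
	}
}
```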
diff --git a/influxdb/config.go b/influxdb/config.go
new file mode 100644
index 0000000..ea763ad
--- /dev/null
+++ b/influxdb/config.go
@@ -0,0 +1,48 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package influxdb
+
+// Config configuration
+type Config struct {
+ Enable bool `yaml:"enable"` // Service switch
+ Address string `yaml:"address"`
+ Port int `yaml:"port"`
+ UDPAddress string `yaml:"udp_address"` // InfluxDB UDP endpoint, "ip:port"
+ Database string `yaml:"database"` // Database name
+ Precision string `yaml:"precision"` // Precision: n, u, ms, s, m or h
+ UserName string `yaml:"username"`
+ Password string `yaml:"password"`
+ MaxIdleConns int `yaml:"max-idle-conns"`
+ MaxIdleConnsPerHost int `yaml:"max-idle-conns-per-host"`
+ IdleConnTimeout int `yaml:"idle-conn-timeout"`
+}
+
+// CustomConfig Custom configuration
+type CustomConfig struct {
+ Enabled bool `yaml:"enabled"` // Service switch
+ Address string `yaml:"address"`
+ Port int `yaml:"port"`
+ UDPAddress string `yaml:"udp_address"` // InfluxDB UDP endpoint, "ip:port"
+ Database string `yaml:"database"` // Database name
+ Precision string `yaml:"precision"` // Precision: n, u, ms, s, m or h
+ UserName string `yaml:"username"`
+ Password string `yaml:"password"`
+ ReadUserName string `yaml:"read-username"`
+ ReadPassword string `yaml:"read-password"`
+ MaxIdleConns int `yaml:"max-idle-conns"`
+ MaxIdleConnsPerHost int `yaml:"max-idle-conns-per-host"`
+ IdleConnTimeout int `yaml:"idle-conn-timeout"`
+ FlushSize int `yaml:"flush-size"`
+ FlushTime int `yaml:"flush-time"`
+}
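A YAML document matching the tags above, decoded with gopkg.in/yaml.v2 (assumed here; any YAML library with struct-tag support works). A subset of the fields is redeclared locally so the sketch stands alone:

```go
package main

import (
	"fmt"
	"log"

	"gopkg.in/yaml.v2"
)

// config mirrors a subset of influxdb.Config above; redeclared locally so
// the sketch is self-contained.
type config struct {
	Enable       bool   `yaml:"enable"`
	Address      string `yaml:"address"`
	Port         int    `yaml:"port"`
	UDPAddress   string `yaml:"udp_address"`
	Database     string `yaml:"database"`
	Precision    string `yaml:"precision"`
	UserName     string `yaml:"username"`
	Password     string `yaml:"password"`
	MaxIdleConns int    `yaml:"max-idle-conns"`
}

const doc = `
enable: true
address: localhost
port: 8086
udp_address: localhost:8089
database: mydb
precision: s
username: admin
password: secret
max-idle-conns: 30
`

func main() {
	var c config
	if err := yaml.Unmarshal([]byte(doc), &c); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%+v\n", c)
}
```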
diff --git a/influxdb/metrics.go b/influxdb/metrics.go
new file mode 100644
index 0000000..c5b4152
--- /dev/null
+++ b/influxdb/metrics.go
@@ -0,0 +1,141 @@
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package influxdb
+
+import (
+ "errors"
+ "fmt"
+ client "github.com/ztdbp/ZACA/pkg/influxdb/influxdb-client/v2"
+ "github.com/ztdbp/ZACA/pkg/logger"
+ "io"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Metrics buffers points and flushes them to InfluxDB in batches.
+type Metrics struct {
+ mu sync.Mutex
+ conf *CustomConfig
+ batchPoints client.BatchPoints
+ point chan *client.Point
+ flushTimer *time.Ticker
+ InfluxDBHttpClient *HTTPClient
+ counter uint64
+}
+
+// MetricsData is a single measurement with its fields and tags.
+type MetricsData struct {
+ Measurement string `json:"measurement"`
+ Fields map[string]interface{} `json:"fields"`
+ Tags map[string]string `json:"tags"`
+}
+
+// Response is a generic status/message envelope.
+type Response struct {
+ State int `json:"state"`
+ Data struct{} `json:"data"`
+ Msg string `json:"msg"`
+}
+
+// NewMetrics creates a Metrics instance and starts its background worker.
+func NewMetrics(influxDBHttpClient *HTTPClient, conf *CustomConfig) (metrics *Metrics) {
+ bp, err := client.NewBatchPoints(influxDBHttpClient.BatchPointsConfig)
+ if err != nil {
+ logger.Named("metrics").Errorf("custom-influxdb client.NewBatchPoints err: %v", err)
+ return
+ }
+ metrics = &Metrics{
+ conf: conf,
+ batchPoints: bp,
+ point: make(chan *client.Point, 16),
+ flushTimer: time.NewTicker(time.Duration(conf.FlushTime) * time.Second),
+ InfluxDBHttpClient: influxDBHttpClient,
+ }
+ go metrics.worker()
+ return
+}
+
+// AddPoint builds a client.Point from the MetricsData and queues it for batching.
+func (mt *Metrics) AddPoint(metricsData *MetricsData) {
+ if mt == nil {
+ return
+ }
+ //atomic.AddUint64(&mt.counter, 1)
+ pt, err := client.NewPoint(metricsData.Measurement, metricsData.Tags, metricsData.Fields, time.Now())
+ if err != nil {
+ logger.Named("metrics").Errorf("custom-influxdb client.NewPoint err: %s", err)
+ return
+ }
+ mt.point <- pt
+}
+
+func (mt *Metrics) worker() {
+ for {
+ select {
+ case p, ok := <-mt.point:
+ if !ok {
+ mt.flush()
+ return
+ }
+ mt.batchPoints.AddPoint(p)
+ // When the number of buffered points reaches FlushSize, send the batch
+ if mt.batchPoints.GetPointsNum() >= mt.conf.FlushSize {
+ mt.flush()
+ }
+ case <-mt.flushTimer.C:
+ mt.flush()
+ }
+ }
+}
+
+func (mt *Metrics) flush() {
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
+ if mt.batchPoints.GetPointsNum() == 0 {
+ return
+ }
+ err := mt.Write()
+ if err != nil {
+ if strings.Contains(err.Error(), io.EOF.Error()) {
+ err = nil
+ } else {
+ logger.Named("metric").Errorf("custom-influxdb client.Write err: %s", err)
+ }
+ }
+ defer mt.InfluxDBHttpClient.FluxDBHttpClose()
+ // Clear all points
+ mt.batchPoints.ClearPoints()
+}
+
+// Write sends the batch, giving up after an 800ms timeout.
+func (mt *Metrics) Write() error {
+ ch := make(chan error, 1)
+ go func() {
+ ch <- mt.InfluxDBHttpClient.FluxDBHttpWrite(mt.batchPoints)
+ }()
+ select {
+ case err := <-ch:
+ return err
+ case <-time.After(800 * time.Millisecond):
+ return errors.New("write timeout")
+ }
+}
+
+func (mt *Metrics) count() {
+ for {
+ time.Sleep(time.Second)
+ fmt.Println("Counter:", atomic.LoadUint64(&mt.counter))
+ }
+}
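Write wraps the blocking flush in the classic buffered-channel-plus-select timeout pattern. The same idea, standalone (writeWithTimeout and the 2s sleep are illustrative; the 800ms budget matches the code above):

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// writeWithTimeout runs send in a goroutine and gives up after the budget
// elapses. The goroutine still finishes in the background; the buffered
// channel keeps it from leaking on timeout.
func writeWithTimeout(send func() error, budget time.Duration) error {
	ch := make(chan error, 1)
	go func() { ch <- send() }()
	select {
	case err := <-ch:
		return err
	case <-time.After(budget):
		return errors.New("write timeout")
	}
}

func main() {
	slow := func() error { time.Sleep(2 * time.Second); return nil }
	fmt.Println(writeWithTimeout(slow, 800*time.Millisecond)) // write timeout
}
```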
diff --git a/logger/hook/hook.go b/logger/hook/hook.go
new file mode 100644
index 0000000..087e598
--- /dev/null
+++ b/logger/hook/hook.go
@@ -0,0 +1,165 @@
+// Copyright 2022-present The ZTDBP Authors.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hook
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/LyricTian/queue"
+ "github.com/sirupsen/logrus"
+)
+
+var defaultOptions = options{
+ maxQueues: 512,
+ maxWorkers: 1,
+ levels: []logrus.Level{
+ logrus.PanicLevel,
+ logrus.FatalLevel,
+ logrus.ErrorLevel,
+ logrus.WarnLevel,
+ logrus.InfoLevel,
+ logrus.DebugLevel,
+ logrus.TraceLevel,
+ },
+}
+
+// ExecCloser writes a logrus entry to a store and closes the store
+type ExecCloser interface {
+ Exec(entry *logrus.Entry) error
+ Close() error
+}
+
+// FilterHandle is a filter function applied to entries before they are written
+type FilterHandle func(*logrus.Entry) *logrus.Entry
+
+type options struct {
+ maxQueues int
+ maxWorkers int
+ extra map[string]interface{}
+ filter FilterHandle
+ levels []logrus.Level
+}
+
+// SetMaxQueues sets the size of the job queue
+func SetMaxQueues(maxQueues int) Option {
+ return func(o *options) {
+ o.maxQueues = maxQueues
+ }
+}
+
+// SetMaxWorkers sets the number of worker goroutines
+func SetMaxWorkers(maxWorkers int) Option {
+ return func(o *options) {
+ o.maxWorkers = maxWorkers
+ }
+}
+
+// SetExtra sets extra fields added to every entry
+func SetExtra(extra map[string]interface{}) Option {
+ return func(o *options) {
+ o.extra = extra
+ }
+}
+
+// SetFilter sets the entry filter
+func SetFilter(filter FilterHandle) Option {
+ return func(o *options) {
+ o.filter = filter
+ }
+}
+
+// SetLevels sets the log levels the hook fires on
+func SetLevels(levels ...logrus.Level) Option {
+ return func(o *options) {
+ if len(levels) == 0 {
+ return
+ }
+ o.levels = levels
+ }
+}
+
+// Option is a configuration option for the hook
+type Option func(*options)
+
+// New creates a hook to be added to an instance of logger
+func New(exec ExecCloser, opt ...Option) *Hook {
+ opts := defaultOptions
+ for _, o := range opt {
+ o(&opts)
+ }
+
+ q := queue.NewQueue(opts.maxQueues, opts.maxWorkers)
+ q.Run()
+
+ return &Hook{
+ opts: opts,
+ q: q,
+ e: exec,
+ }
+}
+
+// Hook asynchronously forwards log entries to an ExecCloser via a worker queue
+type Hook struct {
+ opts options
+ q *queue.Queue
+ e ExecCloser
+}
+
+// Levels returns the available logging levels
+func (h *Hook) Levels() []logrus.Level {
+ return h.opts.levels
+}
+
+// Fire is called when a log event is fired
+func (h *Hook) Fire(entry *logrus.Entry) error {
+ entry = h.copyEntry(entry)
+ h.q.Push(queue.NewJob(entry, func(v interface{}) {
+ h.exec(v.(*logrus.Entry))
+ }))
+ return nil
+}
+
+func (h *Hook) copyEntry(e *logrus.Entry) *logrus.Entry {
+ entry := logrus.NewEntry(e.Logger)
+ entry.Data = make(logrus.Fields)
+ entry.Time = e.Time
+ entry.Level = e.Level
+ entry.Message = e.Message
+ for k, v := range e.Data {
+ entry.Data[k] = v
+ }
+ return entry
+}
+
+func (h *Hook) exec(entry *logrus.Entry) {
+ for k, v := range h.opts.extra {
+ if _, ok := entry.Data[k]; !ok {
+ entry.Data[k] = v
+ }
+ }
+
+ if filter := h.opts.filter; filter != nil {
+ entry = filter(entry)
+ }
+
+ err := h.e.Exec(entry)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "[logrus-hook] execution error: %s", err.Error())
+ }
+}
+
+// Flush waits for the log queue to be empty
+func (h *Hook) Flush() {
+ h.q.Terminate()
+ h.e.Close()
+}
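Any ExecCloser can back the hook. A minimal sketch wiring a stdout writer into logrus (the zta-tools import path is an assumption; adjust it to the module's real path):

```go
package main

import (
	"fmt"

	"github.com/sirupsen/logrus"
	hook "github.com/ztdbp/zta-tools/logger/hook" // assumed module path
)

// stdoutExec is a minimal ExecCloser that prints each entry.
type stdoutExec struct{}

func (stdoutExec) Exec(e *logrus.Entry) error {
	fmt.Printf("[%s] %s %v\n", e.Level, e.Message, e.Data)
	return nil
}

func (stdoutExec) Close() error { return nil }

func main() {
	h := hook.New(stdoutExec{},
		hook.SetLevels(logrus.WarnLevel, logrus.ErrorLevel),
		hook.SetExtra(map[string]interface{}{"app": "demo"}),
	)
	logrus.AddHook(h)

	logrus.Warn("disk almost full") // delivered to stdoutExec asynchronously
	h.Flush()                       // drain the queue before exit
}
```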
diff --git a/logger/hook/redis/redis.go b/logger/hook/redis/redis.go
new file mode 100644
index 0000000..5142f02
--- /dev/null
+++ b/logger/hook/redis/redis.go
@@ -0,0 +1,64 @@
+// Copyright 2022-present The ZTDBP Authors.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package redis
+
+import (
+ "fmt"
+ "github.com/ztdbp/ZASentinel/pkg/util/json"
+
+ "github.com/go-redis/redis"
+ "github.com/sirupsen/logrus"
+)
+
+type Config struct {
+ Addr string
+ Password string
+ Key string
+}
+
+func New(c *Config) *Hook {
+ redisdb := redis.NewClient(&redis.Options{
+ Addr: c.Addr, // redis server address
+ Password: c.Password, // "" means no password
+ DB: 0, // use default DB
+ })
+
+ _, err := redisdb.Ping().Result()
+ if err != nil {
+ fmt.Println("error creating message for REDIS:", err)
+ panic(err)
+ }
+ return &Hook{
+ cli: redisdb,
+ key: c.Key,
+ }
+}
+
+type Hook struct {
+ cli *redis.Client
+ key string
+}
+
+func (h *Hook) Exec(entry *logrus.Entry) error {
+ fields := make(map[string]interface{})
+ for k, v := range entry.Data {
+ fields[k] = v
+ }
+ fields["level"] = entry.Level.String()
+ fields["message"] = entry.Message
+ b, err := json.Marshal(fields)
+ if err != nil {
+ return err
+ }
+ return h.cli.RPush(h.key, string(b)).Err()
+}
+
+func (h *Hook) Close() error {
+ return h.cli.Close()
+}
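Wiring this store into the async hook from the parent package might look as follows (both import paths are assumptions; New panics if the initial PING fails, so point it at a reachable redis):

```go
package main

import (
	"github.com/sirupsen/logrus"

	hook "github.com/ztdbp/zta-tools/logger/hook"             // assumed module path
	redishook "github.com/ztdbp/zta-tools/logger/hook/redis" // assumed module path
)

func main() {
	exec := redishook.New(&redishook.Config{
		Addr:     "localhost:6379", // placeholder
		Password: "",
		Key:      "app:logs",
	})

	logrus.AddHook(hook.New(exec, hook.SetMaxWorkers(2)))
	logrus.WithField("user_id", 42).Info("pushed to the app:logs list as JSON")
}
```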
diff --git a/logger/logger.go b/logger/logger.go
new file mode 100644
index 0000000..bed90c1
--- /dev/null
+++ b/logger/logger.go
@@ -0,0 +1,155 @@
+// Copyright 2022-present The ZTDBP Authors.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logger
+
+import (
+ "context"
+ "io"
+
+ "github.com/sirupsen/logrus"
+)
+
+// Context keys for log fields
+const (
+ TraceIDKey = "trace_id"
+ UserIDKey = "user_id"
+ TagKey = "tag"
+ VersionKey = "version"
+ StackKey = "stack"
+)
+
+var version string
+
+type Logger = logrus.Logger
+
+type Entry = logrus.Entry
+
+type Hook = logrus.Hook
+
+func StandardLogger() *Logger {
+ return logrus.StandardLogger()
+}
+
+func SetLevel(level int) {
+ logrus.SetLevel(logrus.Level(level))
+}
+
+func SetFormatter(format string) {
+ switch format {
+ case "json":
+ logrus.SetFormatter(new(logrus.JSONFormatter))
+ default:
+ logrus.SetFormatter(new(logrus.TextFormatter))
+ }
+}
+
+func SetOutput(out io.Writer) {
+ logrus.SetOutput(out)
+}
+
+func SetVersion(v string) {
+ version = v
+}
+
+func AddHook(hook Hook) {
+ logrus.AddHook(hook)
+}
+
+type (
+ traceIDKey struct{}
+ userIDKey struct{}
+ tagKey struct{}
+ stackKey struct{}
+)
+
+func NewTraceIDContext(ctx context.Context, traceID string) context.Context {
+ return context.WithValue(ctx, traceIDKey{}, traceID)
+}
+
+func FromTraceIDContext(ctx context.Context) string {
+ v := ctx.Value(traceIDKey{})
+ if v != nil {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func NewUserIDContext(ctx context.Context, userID string) context.Context {
+ return context.WithValue(ctx, userIDKey{}, userID)
+}
+
+func FromUserIDContext(ctx context.Context) string {
+ v := ctx.Value(userIDKey{})
+ if v != nil {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func NewTagContext(ctx context.Context, tag string) context.Context {
+ return context.WithValue(ctx, tagKey{}, tag)
+}
+
+func FromTagContext(ctx context.Context) string {
+ v := ctx.Value(tagKey{})
+ if v != nil {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func NewStackContext(ctx context.Context, stack error) context.Context {
+ return context.WithValue(ctx, stackKey{}, stack)
+}
+
+func FromStackContext(ctx context.Context) error {
+ v := ctx.Value(stackKey{})
+ if v != nil {
+ if s, ok := v.(error); ok {
+ return s
+ }
+ }
+ return nil
+}
+
+func WithErrorStack(ctx context.Context, err error) *Entry {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return WithContext(NewStackContext(ctx, err))
+}
+
+func WithContext(ctx context.Context) *Entry {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ return logrus.WithContext(ctx)
+}
+
+// Logrus function aliases
+var (
+ Tracef = logrus.Tracef
+ Debugf = logrus.Debugf
+ Infof = logrus.Infof
+ Warnf = logrus.Warnf
+ Errorf = logrus.Errorf
+ Fatalf = logrus.Fatalf
+ Panicf = logrus.Panicf
+ Printf = logrus.Printf
+)
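The context helpers just thread IDs through a context.Context; a hook or formatter can later lift them into log fields under the keys defined above. A small sketch (the import path is assumed):

```go
package main

import (
	"context"
	"fmt"

	logger "github.com/ztdbp/zta-tools/logger" // assumed module path
)

func main() {
	ctx := logger.NewTraceIDContext(context.Background(), "req-123")
	ctx = logger.NewUserIDContext(ctx, "u-42")

	// The values ride along on the context and can be recovered anywhere
	// downstream, e.g. by a hook or formatter that builds log fields.
	fmt.Println(logger.FromTraceIDContext(ctx)) // req-123
	fmt.Println(logger.FromUserIDContext(ctx))  // u-42

	logger.WithContext(ctx).Info("handling request")
}
```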
diff --git a/memorycacher/README.md b/memorycacher/README.md
new file mode 100644
index 0000000..7411ca1
--- /dev/null
+++ b/memorycacher/README.md
@@ -0,0 +1,83 @@
+# Based on go-cache
+
+go-cache is an in-memory key:value store/cache similar to memcached that is
+suitable for applications running on a single machine. Its major advantage is
+that, being essentially a thread-safe `map[string]interface{}` with expiration
+times, it doesn't need to serialize or transmit its contents over the network.
+
+Any object can be stored, for a given duration or forever, and the cache can be
+safely used by multiple goroutines.
+
+Although go-cache isn't meant to be used as a persistent datastore, the entire
+cache can be saved to and loaded from a file (using `c.Items()` to retrieve the
+items map to serialize, and `NewFrom()` to create a cache from a deserialized
+one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.)
+
+### Usage
+
+```go
+import (
+	"fmt"
+	"time"
+
+	"github.com/ztdbp/zta-tools/memorycacher" // adjust to your module path
+)
+
+func main() {
+	// Create a cache with a default expiration time of 5 minutes, and which
+	// purges expired items every 10 minutes
+	c := memorycacher.New(5*time.Minute, 10*time.Minute)
+
+	// Set the value of the key "foo" to "bar", with the default expiration time
+	c.Set("foo", "bar", memorycacher.DefaultExpiration)
+
+	// Set the value of the key "baz" to 42, with no expiration time
+	// (the item won't be removed until it is re-set, or removed using
+	// c.Delete("baz"))
+	c.Set("baz", 42, memorycacher.NoExpiration)
+
+	// Get the string associated with the key "foo" from the cache
+	foo, found := c.Get("foo")
+	if found {
+		fmt.Println(foo)
+	}
+
+	// Since Go is statically typed, and cache values can be anything, type
+	// assertion is needed when values are being passed to functions that don't
+	// take arbitrary types (i.e. interface{}). The simplest way to do this for
+	// values which will only be used once--e.g. for passing to another
+	// function--is:
+	foo, found = c.Get("foo")
+	if found {
+		MyFunction(foo.(string))
+	}
+
+	// This gets tedious if the value is used several times in the same function.
+	// You might do either of the following instead:
+	if x, found := c.Get("foo"); found {
+		foo := x.(string)
+		// ...
+	}
+	// or
+	var foo string
+	if x, found := c.Get("foo"); found {
+		foo = x.(string)
+	}
+	// ...
+	// foo can then be passed around freely as a string
+
+	// Want performance? Store pointers!
+	c.Set("foo", &MyStruct{}, memorycacher.DefaultExpiration)
+	if x, found := c.Get("foo"); found {
+		foo := x.(*MyStruct)
+		// ...
+	}
+}
+```
\ No newline at end of file
diff --git a/memorycacher/cache.go b/memorycacher/cache.go
new file mode 100644
index 0000000..56f3939
--- /dev/null
+++ b/memorycacher/cache.go
@@ -0,0 +1,1234 @@
+/*
+ * @Author: patrickmn,gitsrc
+ * @Date: 2020-07-09 13:17:30
+ * @LastEditors: gitsrc
+ * @LastEditTime: 2020-07-09 13:22:16
+ * @FilePath: /ServiceCar/utils/memorycache/cache.go
+ */
+
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package memorycacher
+
+import (
+ "encoding/gob"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "runtime"
+ "sync"
+ "time"
+)
+
+type Item struct {
+ Object interface{}
+ Expiration int64
+}
+
+
+var (
+ maxItemsCountErr = errors.New("reached max items count")
+)
+
+// Expired returns true if the item has expired.
+func (item Item) Expired() bool {
+ if item.Expiration == 0 {
+ return false
+ }
+ return time.Now().UnixNano() > item.Expiration
+}
+
+const (
+ // For use with functions that take an expiration time.
+ NoExpiration time.Duration = -1
+ // For use with functions that take an expiration time. Equivalent to
+ // passing in the same expiration duration as was given to New() or
+ // NewFrom() when the cache was created (e.g. 5 minutes.)
+ DefaultExpiration time.Duration = 0
+)
+
+type Cache struct {
+ *cache
+ // If this is confusing, see the comment at the bottom of New()
+}
+
+type cache struct {
+ defaultExpiration time.Duration
+ maxItemsCount int // Maximum number of items
+ items map[string]Item
+ mu sync.RWMutex
+ onEvicted func(string, interface{})
+ lastCleanTime time.Time // Last cleaning time
+ janitor *janitor
+}
+
+// IsReachMaxItemsCount reports whether the number of items in the cache has reached the configured maximum.
+func (c *cache) IsReachMaxItemsCount() bool {
+ return c.ItemCount() >= c.maxItemsCount
+}
+
+// Add an item to the cache, replacing any existing item. If the duration is 0
+// (DefaultExpiration), the cache's default expiration time is used. If it is -1
+// (NoExpiration), the item never expires. If the cache has reached its maximum
+// item count, the value is dropped and a cleanup pass is signalled instead.
+func (c *cache) Set(k string, x interface{}, d time.Duration) {
+ if c.IsReachMaxItemsCount() {
+ c.ShoudClean()
+ return
+ }
+
+ // "Inlining" of set
+ var e int64
+ if d == DefaultExpiration {
+ d = c.defaultExpiration
+ }
+ if d > 0 {
+ e = time.Now().Add(d).UnixNano()
+ }
+ c.mu.Lock()
+ c.items[k] = Item{
+ Object: x,
+ Expiration: e,
+ }
+ // TODO: Calls to mu.Unlock are currently not deferred because defer
+ // adds ~200 ns (as of go1.)
+ c.mu.Unlock()
+}
+
+func (c *cache) set(k string, x interface{}, d time.Duration) {
+ var e int64
+ if d == DefaultExpiration {
+ d = c.defaultExpiration
+ }
+ if d > 0 {
+ e = time.Now().Add(d).UnixNano()
+ }
+ c.items[k] = Item{
+ Object: x,
+ Expiration: e,
+ }
+}
+
+// Add an item to the cache, replacing any existing item, using the default
+// expiration.
+func (c *cache) SetDefault(k string, x interface{}) {
+ c.Set(k, x, DefaultExpiration)
+}
+
+// Add an item to the cache only if an item doesn't already exist for the given
+// key, or if the existing item has expired. Returns an error otherwise.
+func (c *cache) Add(k string, x interface{}, d time.Duration) error {
+ if c.IsReachMaxItemsCount() {
+ c.ShoudClean()
+ return maxItemsCountErr
+ }
+
+ c.mu.Lock()
+ _, found := c.get(k)
+ if found {
+ c.mu.Unlock()
+ return fmt.Errorf("Item %s already exists", k)
+ }
+ c.set(k, x, d)
+ c.mu.Unlock()
+ return nil
+}
+
+// Set a new value for the cache key only if it already exists, and the existing
+// item hasn't expired. Returns an error otherwise.
+func (c *cache) Replace(k string, x interface{}, d time.Duration) error {
+ c.mu.Lock()
+ _, found := c.get(k)
+ if !found {
+ c.mu.Unlock()
+ return fmt.Errorf("Item %s doesn't exist", k)
+ }
+ c.set(k, x, d)
+ c.mu.Unlock()
+ return nil
+}
+
+// Get an item from the cache. Returns the item or nil, and a bool indicating
+// whether the key was found.
+func (c *cache) Get(k string) (interface{}, bool) {
+ c.mu.RLock()
+ // "Inlining" of get and Expired
+ item, found := c.items[k]
+ if !found {
+ c.mu.RUnlock()
+ return nil, false
+ }
+ if item.Expiration > 0 {
+ if time.Now().UnixNano() > item.Expiration {
+ c.mu.RUnlock()
+ return nil, false
+ }
+ }
+ c.mu.RUnlock()
+ return item.Object, true
+}
+
+// GetWithExpiration returns an item and its expiration time from the cache.
+// It returns the item or nil, the expiration time if one is set (if the item
+// never expires a zero value for time.Time is returned), and a bool indicating
+// whether the key was found.
+func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) {
+ c.mu.RLock()
+ // "Inlining" of get and Expired
+ item, found := c.items[k]
+ if !found {
+ c.mu.RUnlock()
+ return nil, time.Time{}, false
+ }
+
+ if item.Expiration > 0 {
+ if time.Now().UnixNano() > item.Expiration {
+ c.mu.RUnlock()
+ return nil, time.Time{}, false
+ }
+
+ // Return the item and the expiration time
+ c.mu.RUnlock()
+ return item.Object, time.Unix(0, item.Expiration), true
+ }
+
+ // If expiration <= 0 (i.e. no expiration time set) then return the item
+ // and a zeroed time.Time
+ c.mu.RUnlock()
+ return item.Object, time.Time{}, true
+}
+
+func (c *cache) get(k string) (interface{}, bool) {
+ item, found := c.items[k]
+ if !found {
+ return nil, false
+ }
+ // "Inlining" of Expired
+ if item.Expiration > 0 {
+ if time.Now().UnixNano() > item.Expiration {
+ return nil, false
+ }
+ }
+ return item.Object, true
+}
+
+// ShoudClean signals the janitor (non-blocking) that a cleanup pass is needed.
+func (c *cache) ShoudClean() {
+ if c.janitor.shoudClean == nil {
+ return
+ }
+ select {
+ case c.janitor.shoudClean <- true:
+ default:
+ }
+}
+
+// Increment an item of type int, int8, int16, int32, int64, uintptr, uint,
+// uint8, uint16, uint32, or uint64, float32 or float64 by n. Returns an error if the
+// item's value is not an integer, if it was not found, or if it is not
+// possible to increment it by n. To retrieve the incremented value, use one
+// of the specialized methods, e.g. IncrementInt64.
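+//
+// Counter sketch (illustrative key):
+//
+//	c.Set("hits", 0, NoExpiration)
+//	if err := c.Increment("hits", 1); err != nil {
+//		// key missing, expired, or not numeric
+//	}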
+func (c *cache) Increment(k string, n int64) error {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return fmt.Errorf("Item %s not found", k)
+ }
+ switch v.Object.(type) {
+ case int:
+ v.Object = v.Object.(int) + int(n)
+ case int8:
+ v.Object = v.Object.(int8) + int8(n)
+ case int16:
+ v.Object = v.Object.(int16) + int16(n)
+ case int32:
+ v.Object = v.Object.(int32) + int32(n)
+ case int64:
+ v.Object = v.Object.(int64) + n
+ case uint:
+ v.Object = v.Object.(uint) + uint(n)
+ case uintptr:
+ v.Object = v.Object.(uintptr) + uintptr(n)
+ case uint8:
+ v.Object = v.Object.(uint8) + uint8(n)
+ case uint16:
+ v.Object = v.Object.(uint16) + uint16(n)
+ case uint32:
+ v.Object = v.Object.(uint32) + uint32(n)
+ case uint64:
+ v.Object = v.Object.(uint64) + uint64(n)
+ case float32:
+ v.Object = v.Object.(float32) + float32(n)
+ case float64:
+ v.Object = v.Object.(float64) + float64(n)
+ default:
+ c.mu.Unlock()
+		return fmt.Errorf("The value for %s is not a number", k)
+ }
+ c.items[k] = v
+ c.mu.Unlock()
+ return nil
+}
+
+// Increment an item of type float32 or float64 by n. Returns an error if the
+// item's value is not floating point, if it was not found, or if it is not
+// possible to increment it by n. Pass a negative number to decrement the
+// value. To retrieve the incremented value, use one of the specialized methods,
+// e.g. IncrementFloat64.
+func (c *cache) IncrementFloat(k string, n float64) error {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return fmt.Errorf("Item %s not found", k)
+ }
+ switch v.Object.(type) {
+ case float32:
+ v.Object = v.Object.(float32) + float32(n)
+ case float64:
+ v.Object = v.Object.(float64) + n
+ default:
+ c.mu.Unlock()
+ return fmt.Errorf("The value for %s does not have type float32 or float64", k)
+ }
+ c.items[k] = v
+ c.mu.Unlock()
+ return nil
+}
+
+// Increment an item of type int by n. Returns an error if the item's value is
+// not an int, or if it was not found. If there is no error, the incremented
+// value is returned.
+func (c *cache) IncrementInt(k string, n int) (int, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type int8 by n. Returns an error if the item's value is
+// not an int8, or if it was not found. If there is no error, the incremented
+// value is returned.
+func (c *cache) IncrementInt8(k string, n int8) (int8, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int8)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int8", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type int16 by n. Returns an error if the item's value is
+// not an int16, or if it was not found. If there is no error, the incremented
+// value is returned.
+func (c *cache) IncrementInt16(k string, n int16) (int16, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int16)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int16", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type int32 by n. Returns an error if the item's value is
+// not an int32, or if it was not found. If there is no error, the incremented
+// value is returned.
+func (c *cache) IncrementInt32(k string, n int32) (int32, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int32)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int32", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type int64 by n. Returns an error if the item's value is
+// not an int64, or if it was not found. If there is no error, the incremented
+// value is returned.
+func (c *cache) IncrementInt64(k string, n int64) (int64, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int64)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int64", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type uint by n. Returns an error if the item's value
+// is not a uint, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementUint(k string, n uint) (uint, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type uintptr by n. Returns an error if the item's
+// value is not a uintptr, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uintptr)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uintptr", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type uint8 by n. Returns an error if the item's value
+// is not a uint8, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint8)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint8", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type uint16 by n. Returns an error if the item's
+// value is not a uint16, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint16)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint16", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type uint32 by n. Returns an error if the item's
+// value is not a uint32, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint32)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint32", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type uint64 by n. Returns an error if the item's
+// value is not a uint64, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint64)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint64", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type float32 by n. Returns an error if the item's
+// value is not a float32, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementFloat32(k string, n float32) (float32, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(float32)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a float32", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Increment an item of type float64 by n. Returns an error if the item's
+// value is not a float64, or if it was not found. If there is no error, the
+// incremented value is returned.
+func (c *cache) IncrementFloat64(k string, n float64) (float64, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(float64)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a float64", k)
+ }
+ nv := rv + n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type int, int8, int16, int32, int64, uintptr, uint,
+// uint8, uint16, uint32, uint64, float32, or float64 by n. Returns an error
+// if the item's value is not numeric, if it was not found, or if it is not
+// possible to decrement it by n. To retrieve the decremented value, use one
+// of the specialized methods, e.g. DecrementInt64.
+func (c *cache) Decrement(k string, n int64) error {
+ // TODO: Implement Increment and Decrement more cleanly.
+ // (Cannot do Increment(k, n*-1) for uints.)
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+		return fmt.Errorf("Item %s not found", k)
+ }
+ switch v.Object.(type) {
+ case int:
+ v.Object = v.Object.(int) - int(n)
+ case int8:
+ v.Object = v.Object.(int8) - int8(n)
+ case int16:
+ v.Object = v.Object.(int16) - int16(n)
+ case int32:
+ v.Object = v.Object.(int32) - int32(n)
+ case int64:
+ v.Object = v.Object.(int64) - n
+ case uint:
+ v.Object = v.Object.(uint) - uint(n)
+ case uintptr:
+ v.Object = v.Object.(uintptr) - uintptr(n)
+ case uint8:
+ v.Object = v.Object.(uint8) - uint8(n)
+ case uint16:
+ v.Object = v.Object.(uint16) - uint16(n)
+ case uint32:
+ v.Object = v.Object.(uint32) - uint32(n)
+ case uint64:
+ v.Object = v.Object.(uint64) - uint64(n)
+ case float32:
+ v.Object = v.Object.(float32) - float32(n)
+ case float64:
+ v.Object = v.Object.(float64) - float64(n)
+ default:
+ c.mu.Unlock()
+		return fmt.Errorf("The value for %s is not a number", k)
+ }
+ c.items[k] = v
+ c.mu.Unlock()
+ return nil
+}
+
+// Decrement an item of type float32 or float64 by n. Returns an error if the
+// item's value is not floating point, if it was not found, or if it is not
+// possible to decrement it by n. Pass a negative number to increment the
+// value. To retrieve the decremented value, use one of the specialized
+// methods, e.g. DecrementFloat64.
+func (c *cache) DecrementFloat(k string, n float64) error {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return fmt.Errorf("Item %s not found", k)
+ }
+ switch v.Object.(type) {
+ case float32:
+ v.Object = v.Object.(float32) - float32(n)
+ case float64:
+ v.Object = v.Object.(float64) - n
+ default:
+ c.mu.Unlock()
+ return fmt.Errorf("The value for %s does not have type float32 or float64", k)
+ }
+ c.items[k] = v
+ c.mu.Unlock()
+ return nil
+}
+
+// Decrement an item of type int by n. Returns an error if the item's value is
+// not an int, or if it was not found. If there is no error, the decremented
+// value is returned.
+func (c *cache) DecrementInt(k string, n int) (int, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type int8 by n. Returns an error if the item's value is
+// not an int8, or if it was not found. If there is no error, the decremented
+// value is returned.
+func (c *cache) DecrementInt8(k string, n int8) (int8, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int8)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int8", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type int16 by n. Returns an error if the item's value is
+// not an int16, or if it was not found. If there is no error, the decremented
+// value is returned.
+func (c *cache) DecrementInt16(k string, n int16) (int16, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int16)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int16", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type int32 by n. Returns an error if the item's value is
+// not an int32, or if it was not found. If there is no error, the decremented
+// value is returned.
+func (c *cache) DecrementInt32(k string, n int32) (int32, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int32)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int32", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type int64 by n. Returns an error if the item's value is
+// not an int64, or if it was not found. If there is no error, the decremented
+// value is returned.
+func (c *cache) DecrementInt64(k string, n int64) (int64, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(int64)
+ if !ok {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("The value for %s is not an int64", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type uint by n. Returns an error if the item's value
+// is not a uint, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementUint(k string, n uint) (uint, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type uintptr by n. Returns an error if the item's
+// value is not a uintptr, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uintptr)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uintptr", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type uint8 by n. Returns an error if the item's value
+// is not a uint8, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint8)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint8", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type uint16 by n. Returns an error if the item's
+// value is not a uint16, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint16)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint16", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type uint32 by n. Returns an error if the item's
+// value is not a uint32, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint32)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint32", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type uint64 by n. Returns an error if the item's
+// value is not a uint64, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(uint64)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a uint64", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type float32 by n. Returns an error if the item's
+// value is not a float32, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementFloat32(k string, n float32) (float32, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(float32)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a float32", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Decrement an item of type float64 by n. Returns an error if the item's
+// value is not a float64, or if it was not found. If there is no error, the
+// decremented value is returned.
+func (c *cache) DecrementFloat64(k string, n float64) (float64, error) {
+ c.mu.Lock()
+ v, found := c.items[k]
+ if !found || v.Expired() {
+ c.mu.Unlock()
+ return 0, fmt.Errorf("Item %s not found", k)
+ }
+ rv, ok := v.Object.(float64)
+ if !ok {
+ c.mu.Unlock()
+		return 0, fmt.Errorf("The value for %s is not a float64", k)
+ }
+ nv := rv - n
+ v.Object = nv
+ c.items[k] = v
+ c.mu.Unlock()
+ return nv, nil
+}
+
+// Delete an item from the cache. Does nothing if the key is not in the cache.
+func (c *cache) Delete(k string) {
+ c.mu.Lock()
+ v, evicted := c.delete(k)
+ c.mu.Unlock()
+ if evicted {
+ c.onEvicted(k, v)
+ }
+}
+
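+// delete removes k from the map; callers must hold c.mu. The removed value
+// and true are returned only when an onEvicted handler is set, so the caller
+// knows to invoke the handler after releasing the lock.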
+func (c *cache) delete(k string) (interface{}, bool) {
+ if c.onEvicted != nil {
+ if v, found := c.items[k]; found {
+ delete(c.items, k)
+ return v.Object, true
+ }
+ }
+ delete(c.items, k)
+ return nil, false
+}
+
+type keyAndValue struct {
+ key string
+ value interface{}
+}
+
+// Delete all expired items from the cache.
+func (c *cache) DeleteExpired() {
+ var evictedItems []keyAndValue
+ nowTime := time.Now()
+ now := nowTime.UnixNano()
+ c.mu.Lock()
+ for k, v := range c.items {
+ // "Inlining" of expired
+ if v.Expiration > 0 && now > v.Expiration {
+ ov, evicted := c.delete(k)
+ if evicted {
+ evictedItems = append(evictedItems, keyAndValue{k, ov})
+ }
+ }
+ }
+ c.lastCleanTime = nowTime
+ c.mu.Unlock()
+ for _, v := range evictedItems {
+ c.onEvicted(v.key, v.value)
+ }
+}
+
+// Sets an (optional) function that is called with the key and value when an
+// item is evicted from the cache. (Including when it is deleted manually, but
+// not when it is overwritten.) Set to nil to disable.
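+//
+// For example (sketch; the logging is illustrative):
+//
+//	c.OnEvicted(func(k string, v interface{}) {
+//		log.Printf("evicted %s", k)
+//	})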
+func (c *cache) OnEvicted(f func(string, interface{})) {
+ c.mu.Lock()
+ c.onEvicted = f
+ c.mu.Unlock()
+}
+
+// Write the cache's items (using Gob) to an io.Writer.
+//
+// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the
+// documentation for NewFrom().)
+func (c *cache) Save(w io.Writer) (err error) {
+ enc := gob.NewEncoder(w)
+ defer func() {
+ if x := recover(); x != nil {
+ err = fmt.Errorf("Error registering item types with Gob library")
+ }
+ }()
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ for _, v := range c.items {
+ gob.Register(v.Object)
+ }
+ err = enc.Encode(&c.items)
+ return
+}
+
+// Save the cache's items to the given filename, creating the file if it
+// doesn't exist, and overwriting it if it does.
+//
+// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the
+// documentation for NewFrom().)
+func (c *cache) SaveFile(fname string) error {
+ fp, err := os.Create(fname)
+ if err != nil {
+ return err
+ }
+ err = c.Save(fp)
+ if err != nil {
+ fp.Close()
+ return err
+ }
+ return fp.Close()
+}
+
+// Add (Gob-serialized) cache items from an io.Reader, excluding any items with
+// keys that already exist (and haven't expired) in the current cache.
+//
+// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the
+// documentation for NewFrom().)
+func (c *cache) Load(r io.Reader) error {
+ dec := gob.NewDecoder(r)
+ items := map[string]Item{}
+ err := dec.Decode(&items)
+ if err == nil {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for k, v := range items {
+ ov, found := c.items[k]
+ if !found || ov.Expired() {
+ c.items[k] = v
+ }
+ }
+ }
+ return err
+}
+
+// Load and add cache items from the given filename, excluding any items with
+// keys that already exist in the current cache.
+//
+// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the
+// documentation for NewFrom().)
+func (c *cache) LoadFile(fname string) error {
+ fp, err := os.Open(fname)
+ if err != nil {
+ return err
+ }
+ err = c.Load(fp)
+ if err != nil {
+ fp.Close()
+ return err
+ }
+ return fp.Close()
+}
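+
+// Persistence round-trip sketch (both helpers are deprecated in favor of
+// c.Items() and NewFrom(); the path is illustrative and errors are ignored):
+//
+//	_ = c.SaveFile("/tmp/cache.gob")
+//	_ = c.LoadFile("/tmp/cache.gob")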
+
+// Copies all unexpired items in the cache into a new map and returns it.
+func (c *cache) Items() map[string]Item {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ m := make(map[string]Item, len(c.items))
+ now := time.Now().UnixNano()
+ for k, v := range c.items {
+ // "Inlining" of Expired
+ if v.Expiration > 0 {
+ if now > v.Expiration {
+ continue
+ }
+ }
+ m[k] = v
+ }
+ return m
+}
+
+// Returns the number of items in the cache. This may include items that have
+// expired, but have not yet been cleaned up.
+func (c *cache) ItemCount() int {
+ c.mu.RLock()
+ n := len(c.items)
+ c.mu.RUnlock()
+ return n
+}
+
+// Delete all items from the cache.
+func (c *cache) Flush() {
+ c.mu.Lock()
+ c.items = map[string]Item{}
+ c.mu.Unlock()
+}
+
+type janitor struct {
+ Interval time.Duration
+ stop chan bool
+ shoudClean chan bool
+}
+
+func (j *janitor) Run(c *cache) {
+ ticker := time.NewTicker(j.Interval)
+ for {
+ select {
+ case <-ticker.C:
+ c.DeleteExpired()
+ case <-j.shoudClean:
+ c.mu.RLock()
+ lastCleanTime := c.lastCleanTime
+ c.mu.RUnlock()
+
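+			// Rate-limit demand-driven cleanups: skip the request
+			// if a full sweep ran within the last second.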
+ if lastCleanTime.Add(time.Second * 1).Before(time.Now()) {
+ c.DeleteExpired()
+ }
+ case <-j.stop:
+ ticker.Stop()
+ return
+ }
+ }
+}
+
+func stopJanitor(c *Cache) {
+ c.janitor.stop <- true
+}
+
+func runJanitor(c *cache, ci time.Duration) {
+ j := &janitor{
+ Interval: ci,
+ stop: make(chan bool),
+ shoudClean: make(chan bool),
+ }
+ c.janitor = j
+ go j.Run(c)
+}
+
+func newCache(de time.Duration, maxItemsCount int, m map[string]Item) *cache {
+ if de == 0 {
+ de = -1
+ }
+ c := &cache{
+ defaultExpiration: de,
+ maxItemsCount: maxItemsCount,
+ items: m,
+ lastCleanTime: time.Now(),
+ }
+ return c
+}
+
+func newCacheWithJanitor(de time.Duration, ci time.Duration, maxItemsCount int, m map[string]Item) *Cache {
+ c := newCache(de, maxItemsCount, m)
+	// This trick ensures that the janitor goroutine (which, if enabled,
+	// runs DeleteExpired on c forever) does not keep the returned C
+	// object from being garbage collected. When C is garbage collected,
+	// the finalizer stops the janitor goroutine, after which c can be
+	// collected.
+ C := &Cache{c}
+ if ci > 0 {
+ runJanitor(c, ci)
+ runtime.SetFinalizer(C, stopJanitor)
+ }
+ return C
+}
+
+// Return a new cache with a given default expiration duration and cleanup
+// interval. If the expiration duration is less than one (or NoExpiration),
+// the items in the cache never expire (by default), and must be deleted
+// manually. If the cleanup interval is less than one, expired items are not
+// deleted from the cache before calling c.DeleteExpired().
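+//
+// Construction sketch (the durations and item cap are illustrative):
+//
+//	c := New(5*time.Minute, 10*time.Minute, 100000)
+//	c.Set("k", "v", DefaultExpiration)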
+func New(defaultExpiration, cleanupInterval time.Duration, maxItemsCount int) *Cache {
+ items := make(map[string]Item)
+ return newCacheWithJanitor(defaultExpiration, cleanupInterval, maxItemsCount, items)
+}
+
+// Return a new cache with a given default expiration duration and cleanup
+// interval. If the expiration duration is less than one (or NoExpiration),
+// the items in the cache never expire (by default), and must be deleted
+// manually. If the cleanup interval is less than one, expired items are not
+// deleted from the cache before calling c.DeleteExpired().
+//
+// NewFrom() also accepts an items map which will serve as the underlying map
+// for the cache. This is useful for starting from a deserialized cache
+// (serialized using e.g. gob.Encode() on c.Items()), or passing in e.g.
+// make(map[string]Item, 500) to improve startup performance when the cache
+// is expected to reach a certain minimum size.
+//
+// Only the cache's methods synchronize access to this map, so it is not
+// recommended to keep any references to the map around after creating a cache.
+// If need be, the map can be accessed at a later point using c.Items() (subject
+// to the same caveat.)
+//
+// Note regarding serialization: When using e.g. gob, make sure to
+// gob.Register() the individual types stored in the cache before encoding a
+// map retrieved with c.Items(), and to register those same types before
+// decoding a blob containing an items map.
+func NewFrom(defaultExpiration, cleanupInterval time.Duration, maxItemsCount int, items map[string]Item) *Cache {
+ return newCacheWithJanitor(defaultExpiration, cleanupInterval, maxItemsCount, items)
+}
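+
+// Round-trip sketch for NewFrom (illustrative; assumes an existing *Cache c
+// whose concrete value types were registered with gob.Register, and omits
+// error handling):
+//
+//	var buf bytes.Buffer
+//	_ = gob.NewEncoder(&buf).Encode(c.Items())
+//	items := map[string]Item{}
+//	_ = gob.NewDecoder(&buf).Decode(&items)
+//	c2 := NewFrom(DefaultExpiration, 10*time.Minute, 100000, items)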
diff --git a/memorycacher/cache_test.go b/memorycacher/cache_test.go
new file mode 100644
index 0000000..209a237
--- /dev/null
+++ b/memorycacher/cache_test.go
@@ -0,0 +1,1796 @@
+/*
+ * @Author: patrickmn,gitsrc
+ * @Date: 2020-07-09 13:17:30
+ * @LastEditors: gitsrc
+ * @LastEditTime: 2020-07-10 10:06:28
+ * @FilePath: /ServiceCar/utils/memorycacher/cache_test.go
+ */
+
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package memorycacher
+
+import (
+ "bytes"
+ "io/ioutil"
+ "runtime"
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+)
+
+type TestStruct struct {
+ Num int
+ Children []*TestStruct
+}
+
+const (
+ maxItemsCount = 10000
+)
+
+func TestCache(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+
+ a, found := tc.Get("a")
+ if found || a != nil {
+ t.Error("Getting A found value that shouldn't exist:", a)
+ }
+
+ b, found := tc.Get("b")
+ if found || b != nil {
+ t.Error("Getting B found value that shouldn't exist:", b)
+ }
+
+ c, found := tc.Get("c")
+ if found || c != nil {
+ t.Error("Getting C found value that shouldn't exist:", c)
+ }
+
+ tc.Set("a", 1, DefaultExpiration)
+ tc.Set("b", "b", DefaultExpiration)
+ tc.Set("c", 3.5, DefaultExpiration)
+
+ x, found := tc.Get("a")
+ if !found {
+ t.Error("a was not found while getting a2")
+ }
+ if x == nil {
+ t.Error("x for a is nil")
+ } else if a2 := x.(int); a2+2 != 3 {
+ t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2)
+ }
+
+ x, found = tc.Get("b")
+ if !found {
+ t.Error("b was not found while getting b2")
+ }
+ if x == nil {
+ t.Error("x for b is nil")
+ } else if b2 := x.(string); b2+"B" != "bB" {
+ t.Error("b2 (which should be b) plus B does not equal bB; value:", b2)
+ }
+
+ x, found = tc.Get("c")
+ if !found {
+ t.Error("c was not found while getting c2")
+ }
+ if x == nil {
+ t.Error("x for c is nil")
+ } else if c2 := x.(float64); c2+1.2 != 4.7 {
+ t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2)
+ }
+}
+
+func TestCacheTimes(t *testing.T) {
+ var found bool
+
+ tc := New(50*time.Millisecond, 1*time.Millisecond, maxItemsCount)
+ tc.Set("a", 1, DefaultExpiration)
+ tc.Set("b", 2, NoExpiration)
+ tc.Set("c", 3, 20*time.Millisecond)
+ tc.Set("d", 4, 70*time.Millisecond)
+
+ <-time.After(25 * time.Millisecond)
+ _, found = tc.Get("c")
+ if found {
+ t.Error("Found c when it should have been automatically deleted")
+ }
+
+ <-time.After(30 * time.Millisecond)
+ _, found = tc.Get("a")
+ if found {
+ t.Error("Found a when it should have been automatically deleted")
+ }
+
+ _, found = tc.Get("b")
+ if !found {
+ t.Error("Did not find b even though it was set to never expire")
+ }
+
+ _, found = tc.Get("d")
+ if !found {
+ t.Error("Did not find d even though it was set to expire later than the default")
+ }
+
+ <-time.After(20 * time.Millisecond)
+ _, found = tc.Get("d")
+ if found {
+ t.Error("Found d when it should have been automatically deleted (later than the default)")
+ }
+}
+
+func TestNewFrom(t *testing.T) {
+	m := map[string]Item{
+		"a": {
+			Object:     1,
+			Expiration: 0,
+		},
+		"b": {
+			Object:     2,
+			Expiration: 0,
+		},
+	}
+ tc := NewFrom(DefaultExpiration, 0, maxItemsCount, m)
+ a, found := tc.Get("a")
+ if !found {
+ t.Fatal("Did not find a")
+ }
+ if a.(int) != 1 {
+ t.Fatal("a is not 1")
+ }
+ b, found := tc.Get("b")
+ if !found {
+ t.Fatal("Did not find b")
+ }
+ if b.(int) != 2 {
+ t.Fatal("b is not 2")
+ }
+}
+
+func TestStorePointerToStruct(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("foo", &TestStruct{Num: 1}, DefaultExpiration)
+ x, found := tc.Get("foo")
+ if !found {
+ t.Fatal("*TestStruct was not found for foo")
+ }
+ foo := x.(*TestStruct)
+ foo.Num++
+
+ y, found := tc.Get("foo")
+ if !found {
+ t.Fatal("*TestStruct was not found for foo (second time)")
+ }
+ bar := y.(*TestStruct)
+ if bar.Num != 2 {
+ t.Fatal("TestStruct.Num is not 2")
+ }
+}
+
+func TestIncrementWithInt(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint", 1, DefaultExpiration)
+ err := tc.Increment("tint", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tint")
+ if !found {
+ t.Error("tint was not found")
+ }
+ if x.(int) != 3 {
+ t.Error("tint is not 3:", x)
+ }
+}
+
+func TestIncrementWithInt8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint8", int8(1), DefaultExpiration)
+ err := tc.Increment("tint8", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tint8")
+ if !found {
+ t.Error("tint8 was not found")
+ }
+ if x.(int8) != 3 {
+ t.Error("tint8 is not 3:", x)
+ }
+}
+
+func TestIncrementWithInt16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint16", int16(1), DefaultExpiration)
+ err := tc.Increment("tint16", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tint16")
+ if !found {
+ t.Error("tint16 was not found")
+ }
+ if x.(int16) != 3 {
+ t.Error("tint16 is not 3:", x)
+ }
+}
+
+func TestIncrementWithInt32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint32", int32(1), DefaultExpiration)
+ err := tc.Increment("tint32", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tint32")
+ if !found {
+ t.Error("tint32 was not found")
+ }
+ if x.(int32) != 3 {
+ t.Error("tint32 is not 3:", x)
+ }
+}
+
+func TestIncrementWithInt64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint64", int64(1), DefaultExpiration)
+ err := tc.Increment("tint64", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tint64")
+ if !found {
+ t.Error("tint64 was not found")
+ }
+ if x.(int64) != 3 {
+ t.Error("tint64 is not 3:", x)
+ }
+}
+
+func TestIncrementWithUint(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint", uint(1), DefaultExpiration)
+ err := tc.Increment("tuint", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tuint")
+ if !found {
+ t.Error("tuint was not found")
+ }
+ if x.(uint) != 3 {
+ t.Error("tuint is not 3:", x)
+ }
+}
+
+func TestIncrementWithUintptr(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuintptr", uintptr(1), DefaultExpiration)
+ err := tc.Increment("tuintptr", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+
+ x, found := tc.Get("tuintptr")
+ if !found {
+ t.Error("tuintptr was not found")
+ }
+ if x.(uintptr) != 3 {
+ t.Error("tuintptr is not 3:", x)
+ }
+}
+
+func TestIncrementWithUint8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint8", uint8(1), DefaultExpiration)
+ err := tc.Increment("tuint8", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tuint8")
+ if !found {
+ t.Error("tuint8 was not found")
+ }
+ if x.(uint8) != 3 {
+ t.Error("tuint8 is not 3:", x)
+ }
+}
+
+func TestIncrementWithUint16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint16", uint16(1), DefaultExpiration)
+ err := tc.Increment("tuint16", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+
+ x, found := tc.Get("tuint16")
+ if !found {
+ t.Error("tuint16 was not found")
+ }
+ if x.(uint16) != 3 {
+ t.Error("tuint16 is not 3:", x)
+ }
+}
+
+func TestIncrementWithUint32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint32", uint32(1), DefaultExpiration)
+ err := tc.Increment("tuint32", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("tuint32")
+ if !found {
+ t.Error("tuint32 was not found")
+ }
+ if x.(uint32) != 3 {
+ t.Error("tuint32 is not 3:", x)
+ }
+}
+
+func TestIncrementWithUint64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint64", uint64(1), DefaultExpiration)
+ err := tc.Increment("tuint64", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+
+ x, found := tc.Get("tuint64")
+ if !found {
+ t.Error("tuint64 was not found")
+ }
+ if x.(uint64) != 3 {
+ t.Error("tuint64 is not 3:", x)
+ }
+}
+
+func TestIncrementWithFloat32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float32", float32(1.5), DefaultExpiration)
+ err := tc.Increment("float32", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("float32")
+ if !found {
+ t.Error("float32 was not found")
+ }
+ if x.(float32) != 3.5 {
+ t.Error("float32 is not 3.5:", x)
+ }
+}
+
+func TestIncrementWithFloat64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float64", float64(1.5), DefaultExpiration)
+ err := tc.Increment("float64", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ x, found := tc.Get("float64")
+ if !found {
+ t.Error("float64 was not found")
+ }
+ if x.(float64) != 3.5 {
+ t.Error("float64 is not 3.5:", x)
+ }
+}
+
+func TestIncrementFloatWithFloat32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float32", float32(1.5), DefaultExpiration)
+ err := tc.IncrementFloat("float32", 2)
+ if err != nil {
+		t.Error("Error incrementing with IncrementFloat:", err)
+ }
+ x, found := tc.Get("float32")
+ if !found {
+ t.Error("float32 was not found")
+ }
+ if x.(float32) != 3.5 {
+ t.Error("float32 is not 3.5:", x)
+ }
+}
+
+func TestIncrementFloatWithFloat64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float64", float64(1.5), DefaultExpiration)
+ err := tc.IncrementFloat("float64", 2)
+ if err != nil {
+		t.Error("Error incrementing with IncrementFloat:", err)
+ }
+ x, found := tc.Get("float64")
+ if !found {
+ t.Error("float64 was not found")
+ }
+ if x.(float64) != 3.5 {
+ t.Error("float64 is not 3.5:", x)
+ }
+}
+
+func TestDecrementWithInt(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int", int(5), DefaultExpiration)
+ err := tc.Decrement("int", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("int")
+ if !found {
+ t.Error("int was not found")
+ }
+ if x.(int) != 3 {
+ t.Error("int is not 3:", x)
+ }
+}
+
+func TestDecrementWithInt8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int8", int8(5), DefaultExpiration)
+ err := tc.Decrement("int8", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("int8")
+ if !found {
+ t.Error("int8 was not found")
+ }
+ if x.(int8) != 3 {
+ t.Error("int8 is not 3:", x)
+ }
+}
+
+func TestDecrementWithInt16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int16", int16(5), DefaultExpiration)
+ err := tc.Decrement("int16", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("int16")
+ if !found {
+ t.Error("int16 was not found")
+ }
+ if x.(int16) != 3 {
+ t.Error("int16 is not 3:", x)
+ }
+}
+
+func TestDecrementWithInt32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int32", int32(5), DefaultExpiration)
+ err := tc.Decrement("int32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("int32")
+ if !found {
+ t.Error("int32 was not found")
+ }
+ if x.(int32) != 3 {
+ t.Error("int32 is not 3:", x)
+ }
+}
+
+func TestDecrementWithInt64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int64", int64(5), DefaultExpiration)
+ err := tc.Decrement("int64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("int64")
+ if !found {
+ t.Error("int64 was not found")
+ }
+ if x.(int64) != 3 {
+ t.Error("int64 is not 3:", x)
+ }
+}
+
+func TestDecrementWithUint(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint", uint(5), DefaultExpiration)
+ err := tc.Decrement("uint", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("uint")
+ if !found {
+ t.Error("uint was not found")
+ }
+ if x.(uint) != 3 {
+ t.Error("uint is not 3:", x)
+ }
+}
+
+func TestDecrementWithUintptr(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uintptr", uintptr(5), DefaultExpiration)
+ err := tc.Decrement("uintptr", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("uintptr")
+ if !found {
+ t.Error("uintptr was not found")
+ }
+ if x.(uintptr) != 3 {
+ t.Error("uintptr is not 3:", x)
+ }
+}
+
+func TestDecrementWithUint8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint8", uint8(5), DefaultExpiration)
+ err := tc.Decrement("uint8", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("uint8")
+ if !found {
+ t.Error("uint8 was not found")
+ }
+ if x.(uint8) != 3 {
+ t.Error("uint8 is not 3:", x)
+ }
+}
+
+func TestDecrementWithUint16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint16", uint16(5), DefaultExpiration)
+ err := tc.Decrement("uint16", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("uint16")
+ if !found {
+ t.Error("uint16 was not found")
+ }
+ if x.(uint16) != 3 {
+ t.Error("uint16 is not 3:", x)
+ }
+}
+
+func TestDecrementWithUint32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint32", uint32(5), DefaultExpiration)
+ err := tc.Decrement("uint32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("uint32")
+ if !found {
+ t.Error("uint32 was not found")
+ }
+ if x.(uint32) != 3 {
+ t.Error("uint32 is not 3:", x)
+ }
+}
+
+func TestDecrementWithUint64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint64", uint64(5), DefaultExpiration)
+ err := tc.Decrement("uint64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("uint64")
+ if !found {
+ t.Error("uint64 was not found")
+ }
+ if x.(uint64) != 3 {
+ t.Error("uint64 is not 3:", x)
+ }
+}
+
+func TestDecrementWithFloat32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float32", float32(5.5), DefaultExpiration)
+ err := tc.Decrement("float32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("float32")
+ if !found {
+ t.Error("float32 was not found")
+ }
+ if x.(float32) != 3.5 {
+		t.Error("float32 is not 3.5:", x)
+ }
+}
+
+func TestDecrementWithFloat64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float64", float64(5.5), DefaultExpiration)
+ err := tc.Decrement("float64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("float64")
+ if !found {
+ t.Error("float64 was not found")
+ }
+ if x.(float64) != 3.5 {
+		t.Error("float64 is not 3.5:", x)
+ }
+}
+
+func TestDecrementFloatWithFloat32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float32", float32(5.5), DefaultExpiration)
+ err := tc.DecrementFloat("float32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("float32")
+ if !found {
+ t.Error("float32 was not found")
+ }
+ if x.(float32) != 3.5 {
+		t.Error("float32 is not 3.5:", x)
+ }
+}
+
+func TestDecrementFloatWithFloat64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float64", float64(5.5), DefaultExpiration)
+ err := tc.DecrementFloat("float64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ x, found := tc.Get("float64")
+ if !found {
+ t.Error("float64 was not found")
+ }
+ if x.(float64) != 3.5 {
+		t.Error("float64 is not 3.5:", x)
+ }
+}
+
+func TestIncrementInt(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint", 1, DefaultExpiration)
+ n, err := tc.IncrementInt("tint", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tint")
+ if !found {
+ t.Error("tint was not found")
+ }
+ if x.(int) != 3 {
+ t.Error("tint is not 3:", x)
+ }
+}
+
+func TestIncrementInt8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint8", int8(1), DefaultExpiration)
+ n, err := tc.IncrementInt8("tint8", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tint8")
+ if !found {
+ t.Error("tint8 was not found")
+ }
+ if x.(int8) != 3 {
+ t.Error("tint8 is not 3:", x)
+ }
+}
+
+func TestIncrementInt16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint16", int16(1), DefaultExpiration)
+ n, err := tc.IncrementInt16("tint16", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tint16")
+ if !found {
+ t.Error("tint16 was not found")
+ }
+ if x.(int16) != 3 {
+ t.Error("tint16 is not 3:", x)
+ }
+}
+
+func TestIncrementInt32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint32", int32(1), DefaultExpiration)
+ n, err := tc.IncrementInt32("tint32", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tint32")
+ if !found {
+ t.Error("tint32 was not found")
+ }
+ if x.(int32) != 3 {
+ t.Error("tint32 is not 3:", x)
+ }
+}
+
+func TestIncrementInt64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tint64", int64(1), DefaultExpiration)
+ n, err := tc.IncrementInt64("tint64", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tint64")
+ if !found {
+ t.Error("tint64 was not found")
+ }
+ if x.(int64) != 3 {
+ t.Error("tint64 is not 3:", x)
+ }
+}
+
+func TestIncrementUint(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint", uint(1), DefaultExpiration)
+ n, err := tc.IncrementUint("tuint", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tuint")
+ if !found {
+ t.Error("tuint was not found")
+ }
+ if x.(uint) != 3 {
+ t.Error("tuint is not 3:", x)
+ }
+}
+
+func TestIncrementUintptr(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuintptr", uintptr(1), DefaultExpiration)
+ n, err := tc.IncrementUintptr("tuintptr", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tuintptr")
+ if !found {
+ t.Error("tuintptr was not found")
+ }
+ if x.(uintptr) != 3 {
+ t.Error("tuintptr is not 3:", x)
+ }
+}
+
+func TestIncrementUint8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint8", uint8(1), DefaultExpiration)
+ n, err := tc.IncrementUint8("tuint8", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tuint8")
+ if !found {
+ t.Error("tuint8 was not found")
+ }
+ if x.(uint8) != 3 {
+ t.Error("tuint8 is not 3:", x)
+ }
+}
+
+func TestIncrementUint16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint16", uint16(1), DefaultExpiration)
+ n, err := tc.IncrementUint16("tuint16", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tuint16")
+ if !found {
+ t.Error("tuint16 was not found")
+ }
+ if x.(uint16) != 3 {
+ t.Error("tuint16 is not 3:", x)
+ }
+}
+
+func TestIncrementUint32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint32", uint32(1), DefaultExpiration)
+ n, err := tc.IncrementUint32("tuint32", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tuint32")
+ if !found {
+ t.Error("tuint32 was not found")
+ }
+ if x.(uint32) != 3 {
+ t.Error("tuint32 is not 3:", x)
+ }
+}
+
+func TestIncrementUint64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("tuint64", uint64(1), DefaultExpiration)
+ n, err := tc.IncrementUint64("tuint64", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("tuint64")
+ if !found {
+ t.Error("tuint64 was not found")
+ }
+ if x.(uint64) != 3 {
+ t.Error("tuint64 is not 3:", x)
+ }
+}
+
+func TestIncrementFloat32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float32", float32(1.5), DefaultExpiration)
+ n, err := tc.IncrementFloat32("float32", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3.5 {
+ t.Error("Returned number is not 3.5:", n)
+ }
+ x, found := tc.Get("float32")
+ if !found {
+ t.Error("float32 was not found")
+ }
+ if x.(float32) != 3.5 {
+ t.Error("float32 is not 3.5:", x)
+ }
+}
+
+func TestIncrementFloat64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float64", float64(1.5), DefaultExpiration)
+ n, err := tc.IncrementFloat64("float64", 2)
+ if err != nil {
+ t.Error("Error incrementing:", err)
+ }
+ if n != 3.5 {
+ t.Error("Returned number is not 3.5:", n)
+ }
+ x, found := tc.Get("float64")
+ if !found {
+ t.Error("float64 was not found")
+ }
+ if x.(float64) != 3.5 {
+ t.Error("float64 is not 3.5:", x)
+ }
+}
+
+func TestDecrementInt8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int8", int8(5), DefaultExpiration)
+ n, err := tc.DecrementInt8("int8", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("int8")
+ if !found {
+ t.Error("int8 was not found")
+ }
+ if x.(int8) != 3 {
+ t.Error("int8 is not 3:", x)
+ }
+}
+
+func TestDecrementInt16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int16", int16(5), DefaultExpiration)
+ n, err := tc.DecrementInt16("int16", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("int16")
+ if !found {
+ t.Error("int16 was not found")
+ }
+ if x.(int16) != 3 {
+ t.Error("int16 is not 3:", x)
+ }
+}
+
+func TestDecrementInt32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int32", int32(5), DefaultExpiration)
+ n, err := tc.DecrementInt32("int32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("int32")
+ if !found {
+ t.Error("int32 was not found")
+ }
+ if x.(int32) != 3 {
+ t.Error("int32 is not 3:", x)
+ }
+}
+
+func TestDecrementInt64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int64", int64(5), DefaultExpiration)
+ n, err := tc.DecrementInt64("int64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("int64")
+ if !found {
+ t.Error("int64 was not found")
+ }
+ if x.(int64) != 3 {
+ t.Error("int64 is not 3:", x)
+ }
+}
+
+func TestDecrementUint(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint", uint(5), DefaultExpiration)
+ n, err := tc.DecrementUint("uint", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("uint")
+ if !found {
+ t.Error("uint was not found")
+ }
+ if x.(uint) != 3 {
+ t.Error("uint is not 3:", x)
+ }
+}
+
+func TestDecrementUintptr(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uintptr", uintptr(5), DefaultExpiration)
+ n, err := tc.DecrementUintptr("uintptr", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("uintptr")
+ if !found {
+ t.Error("uintptr was not found")
+ }
+ if x.(uintptr) != 3 {
+ t.Error("uintptr is not 3:", x)
+ }
+}
+
+func TestDecrementUint8(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint8", uint8(5), DefaultExpiration)
+ n, err := tc.DecrementUint8("uint8", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("uint8")
+ if !found {
+ t.Error("uint8 was not found")
+ }
+ if x.(uint8) != 3 {
+ t.Error("uint8 is not 3:", x)
+ }
+}
+
+func TestDecrementUint16(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint16", uint16(5), DefaultExpiration)
+ n, err := tc.DecrementUint16("uint16", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("uint16")
+ if !found {
+ t.Error("uint16 was not found")
+ }
+ if x.(uint16) != 3 {
+ t.Error("uint16 is not 3:", x)
+ }
+}
+
+func TestDecrementUint32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint32", uint32(5), DefaultExpiration)
+ n, err := tc.DecrementUint32("uint32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("uint32")
+ if !found {
+ t.Error("uint32 was not found")
+ }
+ if x.(uint32) != 3 {
+ t.Error("uint32 is not 3:", x)
+ }
+}
+
+func TestDecrementUint64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint64", uint64(5), DefaultExpiration)
+ n, err := tc.DecrementUint64("uint64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("uint64")
+ if !found {
+ t.Error("uint64 was not found")
+ }
+ if x.(uint64) != 3 {
+ t.Error("uint64 is not 3:", x)
+ }
+}
+
+func TestDecrementFloat32(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float32", float32(5), DefaultExpiration)
+ n, err := tc.DecrementFloat32("float32", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("float32")
+ if !found {
+ t.Error("float32 was not found")
+ }
+ if x.(float32) != 3 {
+ t.Error("float32 is not 3:", x)
+ }
+}
+
+func TestDecrementFloat64(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("float64", float64(5), DefaultExpiration)
+ n, err := tc.DecrementFloat64("float64", 2)
+ if err != nil {
+ t.Error("Error decrementing:", err)
+ }
+ if n != 3 {
+ t.Error("Returned number is not 3:", n)
+ }
+ x, found := tc.Get("float64")
+ if !found {
+ t.Error("float64 was not found")
+ }
+ if x.(float64) != 3 {
+ t.Error("float64 is not 3:", x)
+ }
+}
+
+func TestAdd(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ err := tc.Add("foo", "bar", DefaultExpiration)
+ if err != nil {
+ t.Error("Couldn't add foo even though it shouldn't exist")
+ }
+ err = tc.Add("foo", "baz", DefaultExpiration)
+ if err == nil {
+ t.Error("Successfully added another foo when it should have returned an error")
+ }
+}
+
+func TestReplace(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ err := tc.Replace("foo", "bar", DefaultExpiration)
+ if err == nil {
+ t.Error("Replaced foo when it shouldn't exist")
+ }
+ tc.Set("foo", "bar", DefaultExpiration)
+ err = tc.Replace("foo", "bar", DefaultExpiration)
+ if err != nil {
+ t.Error("Couldn't replace existing key foo")
+ }
+}
+
+func TestDelete(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("foo", "bar", DefaultExpiration)
+ tc.Delete("foo")
+ x, found := tc.Get("foo")
+ if found {
+ t.Error("foo was found, but it should have been deleted")
+ }
+ if x != nil {
+ t.Error("x is not nil:", x)
+ }
+}
+
+func TestItemCount(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("foo", "1", DefaultExpiration)
+ tc.Set("bar", "2", DefaultExpiration)
+ tc.Set("baz", "3", DefaultExpiration)
+ if n := tc.ItemCount(); n != 3 {
+ t.Errorf("Item count is not 3: %d", n)
+ }
+}
+
+func TestFlush(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("foo", "bar", DefaultExpiration)
+ tc.Set("baz", "yes", DefaultExpiration)
+ tc.Flush()
+ x, found := tc.Get("foo")
+ if found {
+ t.Error("foo was found, but it should have been deleted")
+ }
+ if x != nil {
+ t.Error("x is not nil:", x)
+ }
+ x, found = tc.Get("baz")
+ if found {
+ t.Error("baz was found, but it should have been deleted")
+ }
+ if x != nil {
+ t.Error("x is not nil:", x)
+ }
+}
+
+func TestIncrementOverflowInt(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("int8", int8(127), DefaultExpiration)
+ err := tc.Increment("int8", 1)
+ if err != nil {
+ t.Error("Error incrementing int8:", err)
+ }
+ x, _ := tc.Get("int8")
+	n8 := x.(int8)
+	if n8 != -128 {
+		t.Error("int8 did not overflow as expected; value:", n8)
+	}
+}
+
+func TestIncrementOverflowUint(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint8", uint8(255), DefaultExpiration)
+ err := tc.Increment("uint8", 1)
+ if err != nil {
+		t.Error("Error incrementing uint8:", err)
+	}
+	x, _ := tc.Get("uint8")
+	n8 := x.(uint8)
+	if n8 != 0 {
+		t.Error("uint8 did not overflow as expected; value:", n8)
+ }
+}
+
+func TestDecrementUnderflowUint(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("uint8", uint8(0), DefaultExpiration)
+ err := tc.Decrement("uint8", 1)
+ if err != nil {
+		t.Error("Error decrementing uint8:", err)
+	}
+	x, _ := tc.Get("uint8")
+	n8 := x.(uint8)
+	if n8 != 255 {
+		t.Error("uint8 did not underflow as expected; value:", n8)
+ }
+}
+
+func TestOnEvicted(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("foo", 3, DefaultExpiration)
+ if tc.onEvicted != nil {
+ t.Fatal("tc.onEvicted is not nil")
+ }
+ works := false
+ tc.OnEvicted(func(k string, v interface{}) {
+ if k == "foo" && v.(int) == 3 {
+ works = true
+ }
+ tc.Set("bar", 4, DefaultExpiration)
+ })
+ tc.Delete("foo")
+ x, _ := tc.Get("bar")
+ if !works {
+ t.Error("works bool not true")
+ }
+ if x.(int) != 4 {
+ t.Error("bar was not 4")
+ }
+}
+
+func TestCacheSerialization(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ testFillAndSerialize(t, tc)
+
+	// Check that serialization still behaves properly even after multiple
+	// gob.Register calls on c.Items (many of which will be for the same type)
+ testFillAndSerialize(t, tc)
+}
+
+func testFillAndSerialize(t *testing.T, tc *Cache) {
+ tc.Set("a", "a", DefaultExpiration)
+ tc.Set("b", "b", DefaultExpiration)
+ tc.Set("c", "c", DefaultExpiration)
+ tc.Set("expired", "foo", 1*time.Millisecond)
+ tc.Set("*struct", &TestStruct{Num: 1}, DefaultExpiration)
+ tc.Set("[]struct", []TestStruct{
+ {Num: 2},
+ {Num: 3},
+ }, DefaultExpiration)
+ tc.Set("[]*struct", []*TestStruct{
+		{Num: 4},
+		{Num: 5},
+ }, DefaultExpiration)
+ tc.Set("structception", &TestStruct{
+ Num: 42,
+ Children: []*TestStruct{
+			{Num: 6174},
+			{Num: 4716},
+ },
+ }, DefaultExpiration)
+
+ fp := &bytes.Buffer{}
+ err := tc.Save(fp)
+ if err != nil {
+ t.Fatal("Couldn't save cache to fp:", err)
+ }
+
+ oc := New(DefaultExpiration, 0, maxItemsCount)
+ err = oc.Load(fp)
+ if err != nil {
+ t.Fatal("Couldn't load cache from fp:", err)
+ }
+
+ a, found := oc.Get("a")
+ if !found {
+ t.Error("a was not found")
+ }
+ if a.(string) != "a" {
+ t.Error("a is not a")
+ }
+
+ b, found := oc.Get("b")
+ if !found {
+ t.Error("b was not found")
+ }
+ if b.(string) != "b" {
+ t.Error("b is not b")
+ }
+
+ c, found := oc.Get("c")
+ if !found {
+ t.Error("c was not found")
+ }
+ if c.(string) != "c" {
+ t.Error("c is not c")
+ }
+
+ <-time.After(5 * time.Millisecond)
+ _, found = oc.Get("expired")
+ if found {
+ t.Error("expired was found")
+ }
+
+ s1, found := oc.Get("*struct")
+ if !found {
+ t.Error("*struct was not found")
+ }
+ if s1.(*TestStruct).Num != 1 {
+ t.Error("*struct.Num is not 1")
+ }
+
+ s2, found := oc.Get("[]struct")
+ if !found {
+ t.Error("[]struct was not found")
+ }
+ s2r := s2.([]TestStruct)
+ if len(s2r) != 2 {
+ t.Error("Length of s2r is not 2")
+ }
+ if s2r[0].Num != 2 {
+ t.Error("s2r[0].Num is not 2")
+ }
+ if s2r[1].Num != 3 {
+ t.Error("s2r[1].Num is not 3")
+ }
+
+	s3, found := oc.Get("[]*struct")
+ if !found {
+ t.Error("[]*struct was not found")
+ }
+ s3r := s3.([]*TestStruct)
+ if len(s3r) != 2 {
+ t.Error("Length of s3r is not 2")
+ }
+ if s3r[0].Num != 4 {
+ t.Error("s3r[0].Num is not 4")
+ }
+ if s3r[1].Num != 5 {
+ t.Error("s3r[1].Num is not 5")
+ }
+
+	s4, found := oc.Get("structception")
+ if !found {
+ t.Error("structception was not found")
+ }
+ s4r := s4.(*TestStruct)
+ if len(s4r.Children) != 2 {
+ t.Error("Length of s4r.Children is not 2")
+ }
+ if s4r.Children[0].Num != 6174 {
+ t.Error("s4r.Children[0].Num is not 6174")
+ }
+ if s4r.Children[1].Num != 4716 {
+ t.Error("s4r.Children[1].Num is not 4716")
+ }
+}
+
+func TestFileSerialization(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Add("a", "a", DefaultExpiration)
+ tc.Add("b", "b", DefaultExpiration)
+ f, err := ioutil.TempFile("", "go-cache-cache.dat")
+ if err != nil {
+ t.Fatal("Couldn't create cache file:", err)
+ }
+ fname := f.Name()
+ f.Close()
+	err = tc.SaveFile(fname)
+	if err != nil {
+		t.Fatal("Couldn't save cache to file:", err)
+	}
+
+ oc := New(DefaultExpiration, 0, maxItemsCount)
+ oc.Add("a", "aa", 0) // this should not be overwritten
+ err = oc.LoadFile(fname)
+ if err != nil {
+ t.Error(err)
+ }
+ a, found := oc.Get("a")
+ if !found {
+ t.Error("a was not found")
+ }
+ astr := a.(string)
+ if astr != "aa" {
+ if astr == "a" {
+ t.Error("a was overwritten")
+ } else {
+ t.Error("a is not aa")
+ }
+ }
+ b, found := oc.Get("b")
+ if !found {
+ t.Error("b was not found")
+ }
+ if b.(string) != "b" {
+ t.Error("b is not b")
+ }
+}
+
+func TestSerializeUnserializable(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ ch := make(chan bool, 1)
+ ch <- true
+ tc.Set("chan", ch, DefaultExpiration)
+ fp := &bytes.Buffer{}
+ err := tc.Save(fp) // this should fail gracefully
+	if err == nil || err.Error() != "gob NewTypeObject can't handle type: chan bool" {
+ t.Error("Error from Save was not gob NewTypeObject can't handle type chan bool:", err)
+ }
+}
+
+func BenchmarkCacheGetExpiring(b *testing.B) {
+ benchmarkCacheGet(b, 5*time.Minute)
+}
+
+func BenchmarkCacheGetNotExpiring(b *testing.B) {
+ benchmarkCacheGet(b, NoExpiration)
+}
+
+func benchmarkCacheGet(b *testing.B, exp time.Duration) {
+ b.StopTimer()
+ tc := New(exp, 0, maxItemsCount)
+ tc.Set("foo", "bar", DefaultExpiration)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.Get("foo")
+ }
+}
+
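+// The RWMutex-prefixed benchmarks below are plain-map baselines: the same
+// get/set patterns run against a map guarded by a sync.RWMutex, for
+// comparison with the cache's internally locked paths.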
+func BenchmarkRWMutexMapGet(b *testing.B) {
+ b.StopTimer()
+ m := map[string]string{
+ "foo": "bar",
+ }
+ mu := sync.RWMutex{}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ mu.RLock()
+ _, _ = m["foo"]
+ mu.RUnlock()
+ }
+}
+
+func BenchmarkRWMutexInterfaceMapGetStruct(b *testing.B) {
+ b.StopTimer()
+ s := struct{ name string }{name: "foo"}
+ m := map[interface{}]string{
+ s: "bar",
+ }
+ mu := sync.RWMutex{}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ mu.RLock()
+ _, _ = m[s]
+ mu.RUnlock()
+ }
+}
+
+func BenchmarkRWMutexInterfaceMapGetString(b *testing.B) {
+ b.StopTimer()
+ m := map[interface{}]string{
+ "foo": "bar",
+ }
+ mu := sync.RWMutex{}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ mu.RLock()
+ _, _ = m["foo"]
+ mu.RUnlock()
+ }
+}
+
+func BenchmarkCacheGetConcurrentExpiring(b *testing.B) {
+ benchmarkCacheGetConcurrent(b, 5*time.Minute)
+}
+
+func BenchmarkCacheGetConcurrentNotExpiring(b *testing.B) {
+ benchmarkCacheGetConcurrent(b, NoExpiration)
+}
+
+func benchmarkCacheGetConcurrent(b *testing.B, exp time.Duration) {
+ b.StopTimer()
+ tc := New(exp, 0, maxItemsCount)
+ tc.Set("foo", "bar", DefaultExpiration)
+ wg := new(sync.WaitGroup)
+ workers := runtime.NumCPU()
+ each := b.N / workers
+ wg.Add(workers)
+ b.StartTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ for j := 0; j < each; j++ {
+ tc.Get("foo")
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkRWMutexMapGetConcurrent(b *testing.B) {
+ b.StopTimer()
+ m := map[string]string{
+ "foo": "bar",
+ }
+ mu := sync.RWMutex{}
+ wg := new(sync.WaitGroup)
+ workers := runtime.NumCPU()
+ each := b.N / workers
+ wg.Add(workers)
+ b.StartTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ for j := 0; j < each; j++ {
+ mu.RLock()
+ _, _ = m["foo"]
+ mu.RUnlock()
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkCacheGetManyConcurrentExpiring(b *testing.B) {
+ benchmarkCacheGetManyConcurrent(b, 5*time.Minute)
+}
+
+func BenchmarkCacheGetManyConcurrentNotExpiring(b *testing.B) {
+ benchmarkCacheGetManyConcurrent(b, NoExpiration)
+}
+
+func benchmarkCacheGetManyConcurrent(b *testing.B, exp time.Duration) {
+ // This is the same as BenchmarkCacheGetConcurrent, but its result
+ // can be compared against BenchmarkShardedCacheGetManyConcurrent
+ // in sharded_test.go.
+ b.StopTimer()
+ n := 10000
+ tc := New(exp, 0, maxItemsCount)
+ keys := make([]string, n)
+ for i := 0; i < n; i++ {
+ k := "foo" + strconv.Itoa(i)
+ keys[i] = k
+ tc.Set(k, "bar", DefaultExpiration)
+ }
+ each := b.N / n
+ wg := new(sync.WaitGroup)
+ wg.Add(n)
+ for _, v := range keys {
+ go func(k string) {
+ for j := 0; j < each; j++ {
+ tc.Get(k)
+ }
+ wg.Done()
+ }(v)
+ }
+ b.StartTimer()
+ wg.Wait()
+}
+
+func BenchmarkCacheSetExpiring(b *testing.B) {
+ benchmarkCacheSet(b, 5*time.Minute)
+}
+
+func BenchmarkCacheSetNotExpiring(b *testing.B) {
+ benchmarkCacheSet(b, NoExpiration)
+}
+
+func benchmarkCacheSet(b *testing.B, exp time.Duration) {
+ b.StopTimer()
+ tc := New(exp, 0, maxItemsCount)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.Set("foo", "bar", DefaultExpiration)
+ }
+}
+
+func BenchmarkRWMutexMapSet(b *testing.B) {
+ b.StopTimer()
+ m := map[string]string{}
+ mu := sync.RWMutex{}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m["foo"] = "bar"
+ mu.Unlock()
+ }
+}
+
+func BenchmarkCacheSetDelete(b *testing.B) {
+ b.StopTimer()
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.Set("foo", "bar", DefaultExpiration)
+ tc.Delete("foo")
+ }
+}
+
+func BenchmarkRWMutexMapSetDelete(b *testing.B) {
+ b.StopTimer()
+ m := map[string]string{}
+ mu := sync.RWMutex{}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m["foo"] = "bar"
+ mu.Unlock()
+ mu.Lock()
+ delete(m, "foo")
+ mu.Unlock()
+ }
+}
+
+func BenchmarkCacheSetDeleteSingleLock(b *testing.B) {
+ b.StopTimer()
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.mu.Lock()
+ tc.set("foo", "bar", DefaultExpiration)
+ tc.delete("foo")
+ tc.mu.Unlock()
+ }
+}
+
+func BenchmarkRWMutexMapSetDeleteSingleLock(b *testing.B) {
+ b.StopTimer()
+ m := map[string]string{}
+ mu := sync.RWMutex{}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m["foo"] = "bar"
+ delete(m, "foo")
+ mu.Unlock()
+ }
+}
+
+func BenchmarkIncrementInt(b *testing.B) {
+ b.StopTimer()
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+ tc.Set("foo", 0, DefaultExpiration)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.IncrementInt("foo", 1)
+ }
+}
+
+func BenchmarkDeleteExpiredLoop(b *testing.B) {
+ b.StopTimer()
+ tc := New(5*time.Minute, 0, maxItemsCount)
+ tc.mu.Lock()
+ for i := 0; i < 100000; i++ {
+ tc.set(strconv.Itoa(i), "bar", DefaultExpiration)
+ }
+ tc.mu.Unlock()
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.DeleteExpired()
+ }
+}
+
+func TestGetWithExpiration(t *testing.T) {
+ tc := New(DefaultExpiration, 0, maxItemsCount)
+
+ a, expiration, found := tc.GetWithExpiration("a")
+ if found || a != nil || !expiration.IsZero() {
+ t.Error("Getting A found value that shouldn't exist:", a)
+ }
+
+ b, expiration, found := tc.GetWithExpiration("b")
+ if found || b != nil || !expiration.IsZero() {
+ t.Error("Getting B found value that shouldn't exist:", b)
+ }
+
+ c, expiration, found := tc.GetWithExpiration("c")
+ if found || c != nil || !expiration.IsZero() {
+ t.Error("Getting C found value that shouldn't exist:", c)
+ }
+
+ tc.Set("a", 1, DefaultExpiration)
+ tc.Set("b", "b", DefaultExpiration)
+ tc.Set("c", 3.5, DefaultExpiration)
+ tc.Set("d", 1, NoExpiration)
+ tc.Set("e", 1, 50*time.Millisecond)
+
+ x, expiration, found := tc.GetWithExpiration("a")
+ if !found {
+ t.Error("a was not found while getting a2")
+ }
+ if x == nil {
+ t.Error("x for a is nil")
+ } else if a2 := x.(int); a2+2 != 3 {
+ t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2)
+ }
+ if !expiration.IsZero() {
+ t.Error("expiration for a is not a zeroed time")
+ }
+
+ x, expiration, found = tc.GetWithExpiration("b")
+ if !found {
+ t.Error("b was not found while getting b2")
+ }
+ if x == nil {
+ t.Error("x for b is nil")
+ } else if b2 := x.(string); b2+"B" != "bB" {
+ t.Error("b2 (which should be b) plus B does not equal bB; value:", b2)
+ }
+ if !expiration.IsZero() {
+ t.Error("expiration for b is not a zeroed time")
+ }
+
+ x, expiration, found = tc.GetWithExpiration("c")
+ if !found {
+ t.Error("c was not found while getting c2")
+ }
+ if x == nil {
+ t.Error("x for c is nil")
+ } else if c2 := x.(float64); c2+1.2 != 4.7 {
+ t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2)
+ }
+ if !expiration.IsZero() {
+ t.Error("expiration for c is not a zeroed time")
+ }
+
+ x, expiration, found = tc.GetWithExpiration("d")
+ if !found {
+ t.Error("d was not found while getting d2")
+ }
+ if x == nil {
+ t.Error("x for d is nil")
+ } else if d2 := x.(int); d2+2 != 3 {
+ t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2)
+ }
+ if !expiration.IsZero() {
+ t.Error("expiration for d is not a zeroed time")
+ }
+
+ x, expiration, found = tc.GetWithExpiration("e")
+ if !found {
+ t.Error("e was not found while getting e2")
+ }
+ if x == nil {
+ t.Error("x for e is nil")
+ } else if e2 := x.(int); e2+2 != 3 {
+ t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2)
+ }
+ if expiration.UnixNano() != tc.items["e"].Expiration {
+ t.Error("expiration for e is not the correct time")
+ }
+ if expiration.UnixNano() < time.Now().UnixNano() {
+ t.Error("expiration for e is in the past")
+ }
+}
diff --git a/memorycacher/sharded.go b/memorycacher/sharded.go
new file mode 100644
index 0000000..e507068
--- /dev/null
+++ b/memorycacher/sharded.go
@@ -0,0 +1,221 @@
+/*
+ * @Author: patrickmn,gitsrc
+ * @Date: 2020-07-09 13:17:30
+ * @LastEditors: gitsrc
+ * @LastEditTime: 2020-07-09 13:22:41
+ * @FilePath: /ServiceCar/utils/memorycache/sharded.go
+ */
+
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package memorycacher
+
+import (
+ "crypto/rand"
+ "math"
+ "math/big"
+ insecurerand "math/rand"
+ "os"
+ "runtime"
+ "time"
+)
+
+// This is an experimental and unexported (for now) attempt at making a cache
+// with better algorithmic complexity than the standard one, namely by
+// preventing write locks of the entire cache when an item is added. As of the
+// time of writing, the overhead of selecting buckets results in cache
+// operations being about twice as slow as for the standard cache with small
+// total cache sizes, and faster for larger ones.
+//
+// See cache_test.go for a few benchmarks.
+
+type unexportedShardedCache struct {
+ *shardedCache
+}
+
+type shardedCache struct {
+ seed uint32
+ m uint32
+ cs []*cache
+	lastCleanTime time.Time // time of the last cleanup pass
+ janitor *shardedJanitor
+}
+
+// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead.
+func djb33(seed uint32, k string) uint32 {
+ var (
+ l = uint32(len(k))
+ d = 5381 + seed + l
+ i = uint32(0)
+ )
+ // Why is all this 5x faster than a for loop?
+ if l >= 4 {
+ for i < l-4 {
+ d = (d * 33) ^ uint32(k[i])
+ d = (d * 33) ^ uint32(k[i+1])
+ d = (d * 33) ^ uint32(k[i+2])
+ d = (d * 33) ^ uint32(k[i+3])
+ i += 4
+ }
+ }
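+	// Note: the tail switch mixes in (l-i)-1 bytes, so the final byte of the
+	// key is never hashed; this mirrors the upstream go-cache implementation.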
+ switch l - i {
+ case 1:
+ case 2:
+ d = (d * 33) ^ uint32(k[i])
+ case 3:
+ d = (d * 33) ^ uint32(k[i])
+ d = (d * 33) ^ uint32(k[i+1])
+ case 4:
+ d = (d * 33) ^ uint32(k[i])
+ d = (d * 33) ^ uint32(k[i+1])
+ d = (d * 33) ^ uint32(k[i+2])
+ }
+ return d ^ (d >> 16)
+}
+
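+// bucket maps a key to one of the m shards: the key is hashed together with
+// the per-cache seed and reduced modulo the shard count, so a given key
+// always lands in the same shard for the lifetime of the cache.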
+func (sc *shardedCache) bucket(k string) *cache {
+ return sc.cs[djb33(sc.seed, k)%sc.m]
+}
+
+func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) {
+ sc.bucket(k).Set(k, x, d)
+}
+
+func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error {
+ return sc.bucket(k).Add(k, x, d)
+}
+
+func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error {
+ return sc.bucket(k).Replace(k, x, d)
+}
+
+func (sc *shardedCache) Get(k string) (interface{}, bool) {
+ return sc.bucket(k).Get(k)
+}
+
+func (sc *shardedCache) Increment(k string, n int64) error {
+ return sc.bucket(k).Increment(k, n)
+}
+
+func (sc *shardedCache) IncrementFloat(k string, n float64) error {
+ return sc.bucket(k).IncrementFloat(k, n)
+}
+
+func (sc *shardedCache) Decrement(k string, n int64) error {
+ return sc.bucket(k).Decrement(k, n)
+}
+
+func (sc *shardedCache) Delete(k string) {
+ sc.bucket(k).Delete(k)
+}
+
+func (sc *shardedCache) DeleteExpired() {
+ for _, v := range sc.cs {
+ v.DeleteExpired()
+ }
+}
+
+// Returns the items in the cache. This may include items that have expired,
+// but have not yet been cleaned up. If this is significant, the Expiration
+// fields of the items should be checked. Note that explicit synchronization
+// is needed to use a cache and its corresponding Items() return values at
+// the same time, as the maps are shared.
+func (sc *shardedCache) Items() []map[string]Item {
+ res := make([]map[string]Item, len(sc.cs))
+ for i, v := range sc.cs {
+ res[i] = v.Items()
+ }
+ return res
+}
+
+func (sc *shardedCache) Flush() {
+ for _, v := range sc.cs {
+ v.Flush()
+ }
+}
+
+type shardedJanitor struct {
+ Interval time.Duration
+	shouldClean chan bool // signals that an out-of-band cleanup should run
+ stop chan bool
+}
+
+func (j *shardedJanitor) Run(sc *shardedCache) {
+	j.stop = make(chan bool)
+	ticker := time.NewTicker(j.Interval)
+	for {
+		select {
+		case <-ticker.C:
+			sc.DeleteExpired()
+		case <-j.shouldClean: // an out-of-band cleanup was requested
+			sc.DeleteExpired()
+		case <-j.stop:
+			ticker.Stop()
+			return
+		}
+	}
+}
+
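+// stopShardedJanitor is installed by unexportedNewSharded as a runtime
+// finalizer on the exported wrapper, so the janitor goroutine is stopped
+// once the wrapper becomes unreachable.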
+func stopShardedJanitor(sc *unexportedShardedCache) {
+ sc.janitor.stop <- true
+}
+
+func runShardedJanitor(sc *shardedCache, ci time.Duration) {
+ j := &shardedJanitor{
+ Interval: ci,
+		shouldClean: make(chan bool),
+ }
+ sc.janitor = j
+ go j.Run(sc)
+}
+
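+// newShardedCache allocates n buckets that share a single random hash seed.
+// The seed is drawn from crypto/rand so that key-to-shard placement is not
+// predictable across processes; on failure it falls back to math/rand with
+// a warning.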
+func newShardedCache(n int, de time.Duration, maxItemsCount int) *shardedCache {
+ max := big.NewInt(0).SetUint64(uint64(math.MaxUint32))
+ rnd, err := rand.Int(rand.Reader, max)
+ var seed uint32
+ if err != nil {
+		os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent). Your system's security may be compromised. Continuing with an insecure seed.\n"))
+ seed = insecurerand.Uint32()
+ } else {
+ seed = uint32(rnd.Uint64())
+ }
+ sc := &shardedCache{
+ seed: seed,
+ m: uint32(n),
+ cs: make([]*cache, n),
+ }
+ for i := 0; i < n; i++ {
+ c := &cache{
+ defaultExpiration: de,
+ items: map[string]Item{},
+ maxItemsCount: maxItemsCount,
+ lastCleanTime: time.Now(),
+ }
+ sc.cs[i] = c
+ }
+ return sc
+}
+
+func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int, maxItemsCount int) *unexportedShardedCache {
+ if defaultExpiration == 0 {
+ defaultExpiration = -1
+ }
+ sc := newShardedCache(shards, defaultExpiration, maxItemsCount)
+ SC := &unexportedShardedCache{sc}
+ if cleanupInterval > 0 {
+ runShardedJanitor(sc, cleanupInterval)
+ runtime.SetFinalizer(SC, stopShardedJanitor)
+ }
+ return SC
+}
diff --git a/memorycacher/sharded_test.go b/memorycacher/sharded_test.go
new file mode 100644
index 0000000..19fe6cc
--- /dev/null
+++ b/memorycacher/sharded_test.go
@@ -0,0 +1,105 @@
+/*
+ * @Author: patrickmn,gitsrc
+ * @Date: 2020-07-09 13:17:30
+ * @LastEditors: gitsrc
+ * @LastEditTime: 2020-07-10 10:06:42
+ * @FilePath: /ServiceCar/utils/memorycacher/sharded_test.go
+ */
+/*
+Copyright 2022-present The ZTDBP Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package memorycacher
+
+import (
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+)
+
+// func TestDjb33(t *testing.T) {
+// }
+
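+// A minimal sanity check of our own (not from upstream go-cache): djb33
+// must be deterministic for a fixed seed and key, since bucket selection
+// relies on the same key always hashing to the same shard.
+func TestDjb33Deterministic(t *testing.T) {
+	if djb33(42, "foobarbazquux") != djb33(42, "foobarbazquux") {
+		t.Error("djb33 returned different hashes for the same seed and key")
+	}
+}
+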
+var shardedKeys = []string{
+ "f",
+ "fo",
+ "foo",
+ "barf",
+ "barfo",
+ "foobar",
+ "bazbarf",
+ "bazbarfo",
+ "bazbarfoo",
+ "foobarbazq",
+ "foobarbazqu",
+ "foobarbazquu",
+ "foobarbazquux",
+}
+
+func TestShardedCache(t *testing.T) {
+ tc := unexportedNewSharded(DefaultExpiration, time.Minute*2, 13, 100)
+ for _, v := range shardedKeys {
+ tc.Set(v, "value", DefaultExpiration)
+ }
+}
+
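+// A hedged sketch of our own (not upstream): check that the sample keys
+// spread across more than one shard. The seed is random per cache, so this
+// only asserts the very weak property that at least two distinct buckets
+// are used for the thirteen keys above.
+func TestShardedCacheKeySpread(t *testing.T) {
+	tc := unexportedNewSharded(DefaultExpiration, 0, 13, 100)
+	used := make(map[*cache]bool)
+	for _, k := range shardedKeys {
+		used[tc.bucket(k)] = true
+	}
+	if len(used) < 2 {
+		t.Error("expected the sample keys to hash into at least two shards")
+	}
+}
+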
+func BenchmarkShardedCacheGetExpiring(b *testing.B) {
+ benchmarkShardedCacheGet(b, 5*time.Minute)
+}
+
+func BenchmarkShardedCacheGetNotExpiring(b *testing.B) {
+ benchmarkShardedCacheGet(b, NoExpiration)
+}
+
+func benchmarkShardedCacheGet(b *testing.B, exp time.Duration) {
+ b.StopTimer()
+ tc := unexportedNewSharded(exp, 0, 10, 0)
+ tc.Set("foobarba", "zquux", DefaultExpiration)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ tc.Get("foobarba")
+ }
+}
+
+func BenchmarkShardedCacheGetManyConcurrentExpiring(b *testing.B) {
+ benchmarkShardedCacheGetManyConcurrent(b, 5*time.Minute)
+}
+
+func BenchmarkShardedCacheGetManyConcurrentNotExpiring(b *testing.B) {
+ benchmarkShardedCacheGetManyConcurrent(b, NoExpiration)
+}
+
+func benchmarkShardedCacheGetManyConcurrent(b *testing.B, exp time.Duration) {
+ b.StopTimer()
+ n := 10000
+ tsc := unexportedNewSharded(exp, 0, 20, 100000)
+ keys := make([]string, n)
+ for i := 0; i < n; i++ {
+ k := "foo" + strconv.Itoa(i)
+ keys[i] = k
+ tsc.Set(k, "bar", DefaultExpiration)
+ }
+ each := b.N / n
+ wg := new(sync.WaitGroup)
+ wg.Add(n)
+ for _, v := range keys {
+ go func(k string) {
+ for j := 0; j < each; j++ {
+ tsc.Get(k)
+ }
+ wg.Done()
+ }(v)
+ }
+ b.StartTimer()
+ wg.Wait()
+}