diff --git a/sqltypes_test.go b/sqltypes_test.go index 09c08b1..411292d 100644 --- a/sqltypes_test.go +++ b/sqltypes_test.go @@ -48,6 +48,7 @@ type TypesSuite struct { skipJSON bool skipJSONB bool skipPostGIS bool + skipTSRange bool } var _ = Suite(&TypesSuite{}) @@ -78,6 +79,8 @@ func (s *TypesSuite) SetUpSuite(c *C) { if minor <= 1 { log.Print("json not available") s.skipJSON = true + log.Print("tsrange not available") + s.skipTSRange = true } if minor <= 3 { log.Print("jsonb not available") @@ -103,6 +106,11 @@ func (s *TypesSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) } + if !s.skipTSRange { + _, err = s.db.Exec(`ALTER TABLE pq_types ADD COLUMN tsrange tsrange`) + c.Assert(err, IsNil) + } + // check PostGIS db.Exec("CREATE EXTENSION postgis") row = db.QueryRow("SELECT PostGIS_full_version()") diff --git a/tsrange.go b/tsrange.go new file mode 100644 index 0000000..50dcf43 --- /dev/null +++ b/tsrange.go @@ -0,0 +1,107 @@ +package pq_types + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "fmt" + "time" +) + +// TimeBound represents Upper and Lower bound for TSRange. +// Time may be nil, so this will be infinity. +// If Time is nil and bound is Inclusive, it will be converted to exclusive by postgresql. +type TimeBound struct { + Inclusive bool + Time *time.Time +} + +// TSRange is a wrapper for postresql tsrange type. +type TSRange struct { + LowerBound TimeBound + UpperBound TimeBound +} + +const ( + timeFormat = "2006-01-02 15:04:05" +) + +// Value implements database/sql/driver Valuer interface. +func (t TSRange) Value() (driver.Value, error) { + res := []byte{} + if t.LowerBound.Inclusive { + res = append(res, '[') + } else { + res = append(res, '(') + } + if t.LowerBound.Time != nil { + tstr := t.LowerBound.Time.UTC().Truncate(time.Second).Format(timeFormat) + res = append(res, []byte(tstr)...) + } + res = append(res, ',') + if t.UpperBound.Time != nil { + tstr := t.UpperBound.Time.UTC().Truncate(time.Second).Format(timeFormat) + res = append(res, []byte(tstr)...) + } + if t.UpperBound.Inclusive { + res = append(res, ']') + } else { + res = append(res, ')') + } + return res, nil +} + +// Scan implements database/sql Scanner interface. +func (t *TSRange) Scan(value interface{}) error { + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("TSRange.Scan: expected []byte, got %T (%q)", value, value) + } + if len(v) < 3 { + return fmt.Errorf("TSRange.Scan: unexpected data %q", v) + } + if v[0] != '(' && v[0] != '[' { + return fmt.Errorf("TSRange.Scan: unexpected data %q", v) + } + if v[len(v)-1] != ')' && v[len(v)-1] != ']' { + return fmt.Errorf("TSRange.Scan: unexpected data %q", v) + } + if v[0] == '[' { + t.LowerBound.Inclusive = true + } else { + t.LowerBound.Inclusive = false + } + commaIdx := bytes.IndexByte(v, ',') + if commaIdx == -1 { + return fmt.Errorf("TSRange.Scan: no comma %q", v) + } + lt := v[1:commaIdx] + if len(lt) > 0 { + lt = lt[1 : len(lt)-1] + time, err := time.Parse(timeFormat, string(lt)) + if err != nil { + return fmt.Errorf("TSRange.Scan: error parsing lower bound time %s: %s", lt, err) + } + t.LowerBound.Time = &time + } + ut := v[commaIdx+1 : len(v)-1] + if len(ut) > 0 { + ut = ut[1 : len(ut)-1] + time, err := time.Parse(timeFormat, string(ut)) + if err != nil { + return fmt.Errorf("TSRange.Scan: error parsing upper bound time %s: %s", ut, err) + } + t.UpperBound.Time = &time + } + if v[len(v)-1] == ']' { + t.UpperBound.Inclusive = true + } else { + t.UpperBound.Inclusive = false + } + return nil +} + +var ( + _ driver.Valuer = TSRange{} + _ sql.Scanner = &TSRange{} +) diff --git a/tsrange_test.go b/tsrange_test.go new file mode 100644 index 0000000..2736966 --- /dev/null +++ b/tsrange_test.go @@ -0,0 +1,44 @@ +package pq_types + +import ( + "fmt" + "time" + + . "gopkg.in/check.v1" +) + +func (s *TypesSuite) TestTSRange(c *C) { + if s.skipTSRange { + c.Skip("TSRange not available") + } + type testData struct { + ts TSRange + s string + } + upperTime := time.Now().UTC().Truncate(time.Second) + lowerTime := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second) + utStr := upperTime.Format(timeFormat) + ltStr := lowerTime.Format(timeFormat) + for _, d := range []testData{ + {TSRange{TimeBound{true, &lowerTime}, TimeBound{true, &upperTime}}, fmt.Sprintf(`["%s","%s"]`, ltStr, utStr)}, + {TSRange{TimeBound{false, &lowerTime}, TimeBound{false, &upperTime}}, fmt.Sprintf(`("%s","%s")`, ltStr, utStr)}, + {TSRange{TimeBound{false, &lowerTime}, TimeBound{true, &upperTime}}, fmt.Sprintf(`("%s","%s"]`, ltStr, utStr)}, + {TSRange{TimeBound{true, &lowerTime}, TimeBound{false, &upperTime}}, fmt.Sprintf(`["%s","%s")`, ltStr, utStr)}, + {TSRange{TimeBound{false, nil}, TimeBound{true, &upperTime}}, fmt.Sprintf(`(,"%s"]`, utStr)}, + {TSRange{TimeBound{true, &lowerTime}, TimeBound{false, nil}}, fmt.Sprintf(`["%s",)`, ltStr)}, + {TSRange{TimeBound{false, nil}, TimeBound{false, &upperTime}}, fmt.Sprintf(`(,"%s")`, utStr)}, + {TSRange{TimeBound{false, &lowerTime}, TimeBound{false, nil}}, fmt.Sprintf(`("%s",)`, ltStr)}, + {TSRange{TimeBound{false, nil}, TimeBound{false, nil}}, "(,)"}, + } { + s.SetUpTest(c) + _, err := s.db.Exec("INSERT INTO pq_types (tsrange) VALUES($1)", d.ts) + c.Assert(err, IsNil) + + var el TSRange + var els string + err = s.db.QueryRow("SELECT tsrange, tsrange FROM pq_types").Scan(&el, &els) + c.Check(err, IsNil) + c.Check(d.ts, DeepEquals, el) + c.Check(d.s, Equals, els) + } +}