From 9654e7a8b4a09a483ff6f80f11f45621d7de3da8 Mon Sep 17 00:00:00 2001 From: Kevin Pierce Date: Fri, 22 Apr 2016 15:11:37 -0700 Subject: [PATCH] Invert zero state for null fields. --- field/time.go | 38 +++++++++++++++----------------------- field/time_date.go | 34 +++++++++++++++------------------- field/time_date_test.go | 21 +++++++++++++++++++++ field/time_test.go | 21 +++++++++++++++++++++ field/time_time_test.go | 7 +++++++ 5 files changed, 79 insertions(+), 42 deletions(-) diff --git a/field/time.go b/field/time.go index 1645fed..93bf858 100644 --- a/field/time.go +++ b/field/time.go @@ -103,9 +103,9 @@ type nullTime null.Time // NullTime time that can be nil type NullTime struct { nullTime - validNull bool - shadow null.Time - shadowValidNull bool + invalidNull bool + shadow null.Time + shadowInvalidNull bool ShadowInit } @@ -116,41 +116,38 @@ func (nt *NullTime) Scan(value interface{}) error { case time.Time: if v.IsZero() { nt.Valid = false - nt.validNull = true + nt.invalidNull = true } else { nt.Time, nt.Valid = v, true - nt.validNull = false + nt.invalidNull = true } - break case []byte: nt.Time, err = parseDateTime(string(v), time.UTC) if nt.Time.IsZero() == true { nt.Valid = false - nt.validNull = false + nt.invalidNull = true return ErrorCouldNotScan("NullTime", value) } nt.Valid = (err == nil) if err == nil { - nt.validNull = false + nt.invalidNull = true } - break case string: nt.Time, err = parseDateTime(v, time.UTC) if nt.Time.IsZero() == true { nt.Valid = false - nt.validNull = false + nt.invalidNull = true return ErrorCouldNotScan("NullTime", value) } nt.Valid = (err == nil) if err == nil { - nt.validNull = false + nt.invalidNull = true } - break default: if value == nil { nt.Valid = false - nt.validNull = true + nt.invalidNull = false } else { err = ErrorCouldNotScan("NullTime", value) } @@ -159,19 +156,14 @@ func (nt *NullTime) Scan(value interface{}) error { // load shadow on first scan only nt.DoInit(func() { _ = nt.shadow.Scan(nt.Time) - if value == nil { - nt.shadowValidNull = true - } + nt.shadowInvalidNull = (value != nil) }) return err } // Value return the value of this field func (nt NullTime) Value() (driver.Value, error) { - if nt.validNull { - return nil, nil - } - if nt.Time.IsZero() { + if !nt.invalidNull { return nil, nil } return nt.Time, nil @@ -179,9 +171,9 @@ func (nt NullTime) Value() (driver.Value, error) { // IsDirty if the shadow value does not match the field value func (nt *NullTime) IsDirty() bool { - if nt.validNull && nt.shadowValidNull { + if !nt.invalidNull && !nt.shadowInvalidNull { return false - } else if nt.validNull == false && nt.shadowValidNull == false { + } else if nt.invalidNull && nt.shadowInvalidNull { return !nt.Time.Equal(nt.shadow.Time) } return true @@ -195,7 +187,7 @@ func (nt NullTime) IsSet() bool { // ShadowValue return the initial value of this field func (nt NullTime) ShadowValue() (driver.Value, error) { if nt.InitDone() { - if nt.shadowValidNull { + if !nt.shadowInvalidNull { return nil, nil } return nt.shadow.Value() diff --git a/field/time_date.go b/field/time_date.go index ad15c04..4a18e7c 100644 --- a/field/time_date.go +++ b/field/time_date.go @@ -96,9 +96,9 @@ func (t *TimeDate) UnmarshalJSON(data []byte) error { // NullTime time that can be nil type NullTimeDate struct { nullTime - validNull bool - shadow null.Time - shadowValidNull bool + invalidNull bool + shadow null.Time + shadowInvalidNull bool ShadowInit } @@ -109,41 +109,39 @@ func (nt *NullTimeDate) Scan(value interface{}) error { case time.Time: if v.IsZero() { nt.Valid = false - nt.validNull = true + nt.invalidNull = true } else { nt.Time, nt.Valid = v, true - nt.validNull = false + nt.invalidNull = true } - break case []byte: nt.Time, err = parseTimeDate(string(v)) if nt.Time.IsZero() == true { nt.Valid = false - nt.validNull = false + nt.invalidNull = true return ErrorCouldNotScan("NullTimeDate", value) } nt.Valid = (err == nil) if err == nil { - nt.validNull = false + nt.invalidNull = true } - break case string: nt.Time, err = parseTimeDate(v) if nt.Time.IsZero() == true { nt.Valid = false - nt.validNull = false + nt.invalidNull = true return ErrorCouldNotScan("NullTimeDate", value) } nt.Valid = (err == nil) if err == nil { - nt.validNull = false + nt.invalidNull = true } break default: if value == nil { nt.Valid = false - nt.validNull = true + nt.invalidNull = false } else { err = ErrorCouldNotScan("NullTimeDate", value) } @@ -152,16 +150,14 @@ func (nt *NullTimeDate) Scan(value interface{}) error { // load shadow on first scan only nt.DoInit(func() { _ = nt.shadow.Scan(nt.Time) - if value == nil { - nt.shadowValidNull = true - } + nt.shadowInvalidNull = (value != nil) }) return err } // Value return the value of this field func (nt NullTimeDate) Value() (driver.Value, error) { - if nt.validNull { + if !nt.invalidNull { return nil, nil } return nt.Time, nil @@ -169,9 +165,9 @@ func (nt NullTimeDate) Value() (driver.Value, error) { // IsDirty if the shadow value does not match the field value func (nt *NullTimeDate) IsDirty() bool { - if nt.validNull && nt.shadowValidNull { + if !nt.invalidNull && !nt.shadowInvalidNull { return false - } else if nt.validNull == false && nt.shadowValidNull == false { + } else if nt.invalidNull && nt.shadowInvalidNull { return !nt.Time.Equal(nt.shadow.Time) } return true @@ -185,7 +181,7 @@ func (nt NullTimeDate) IsSet() bool { // ShadowValue return the initial value of this field func (nt NullTimeDate) ShadowValue() (driver.Value, error) { if nt.InitDone() { - if nt.shadowValidNull { + if !nt.shadowInvalidNull { return nil, nil } return nt.shadow.Value() diff --git a/field/time_date_test.go b/field/time_date_test.go index 6b0780a..7cca641 100644 --- a/field/time_date_test.go +++ b/field/time_date_test.go @@ -181,6 +181,27 @@ func TestTimeDate(t *testing.T) { } func TestNullTimeDate(t *testing.T) { + Convey("Unscanned", t, func() { + Convey("Value should be Null", func() { + ns := NullTimeDate{} + t, err := ns.Value() + So(err, ShouldBeNil) + So(t, ShouldBeNil) + }) + + Convey("IsDirty should be false", func() { + ns := NullTimeDate{} + So(ns.IsDirty(), ShouldBeFalse) + }) + + Convey("Marshal should provide json null", func() { + ns := NullTimeDate{} + v, err := ns.MarshalJSON() + So(err, ShouldBeNil) + So(string(v), ShouldEqual, "null") + }) + }) + Convey("Scan", t, func() { Convey("Scan should load Time and Shadow field", func() { ns := NullTimeDate{} diff --git a/field/time_test.go b/field/time_test.go index 805a150..158c02a 100644 --- a/field/time_test.go +++ b/field/time_test.go @@ -181,6 +181,27 @@ func TestTime(t *testing.T) { } func TestNullTime(t *testing.T) { + Convey("Unscanned", t, func() { + Convey("Value should be Null", func() { + ns := NullTime{} + t, err := ns.Value() + So(err, ShouldBeNil) + So(t, ShouldBeNil) + }) + + Convey("IsDirty should be false", func() { + ns := NullTime{} + So(ns.IsDirty(), ShouldBeFalse) + }) + + Convey("Marshal should provide json null", func() { + ns := NullTime{} + v, err := ns.MarshalJSON() + So(err, ShouldBeNil) + So(string(v), ShouldEqual, "null") + }) + }) + Convey("Scan", t, func() { Convey("Scan should load Time and Shadow field", func() { ns := NullTime{} diff --git a/field/time_time_test.go b/field/time_time_test.go index 5c0469b..460dd4d 100644 --- a/field/time_time_test.go +++ b/field/time_time_test.go @@ -194,6 +194,13 @@ func TestNullTimeTime(t *testing.T) { ns := NullTimeTime{} So(ns.IsDirty(), ShouldBeFalse) }) + + Convey("Marshal should provide json null", func() { + ns := NullTimeTime{} + v, err := ns.MarshalJSON() + So(err, ShouldBeNil) + So(string(v), ShouldEqual, "null") + }) }) Convey("Scan", t, func() {