From 632e90fff3f07f84e67be9cd76d955f28241af90 Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 11:42:27 -0800 Subject: [PATCH 1/7] Pre change refactor --- pquery/getter.go | 61 +++++++++------ pquery/getter_test.go | 168 ++++++++++++++++++++++++++++++++++++++++++ pquery/lister.go | 4 +- pquery/query.go | 8 +- psm/state_query.go | 6 +- 5 files changed, 214 insertions(+), 33 deletions(-) create mode 100644 pquery/getter_test.go diff --git a/pquery/getter.go b/pquery/getter.go index cdb9d09..e637ae1 100644 --- a/pquery/getter.go +++ b/pquery/getter.go @@ -34,13 +34,16 @@ type GetSpec[ TableName string DataColumn string Auth AuthProvider - AuthJoin []*LeftJoin + AuthJoin []*KeyJoin PrimaryKey func(REQ) (map[string]interface{}, error) StateResponseField protoreflect.Name - Join *GetJoinSpec + ArrayJoin *ArrayJoinSpec + + // ResponseDescriptor must describe the RES message, defaults to new(RES).ProtoReflect().Descriptor() + ResponseDescriptor protoreflect.MessageDescriptor } // JoinConstraint defines a @@ -77,14 +80,14 @@ func (jc JoinFields) SQL(rootAlias string, joinAlias string) string { return strings.Join(conditions, " AND ") } -type GetJoinSpec struct { +type ArrayJoinSpec struct { TableName string DataColumn string On JoinFields FieldInParent protoreflect.Name } -func (gc GetJoinSpec) validate() error { +func (gc ArrayJoinSpec) validate() error { if gc.TableName == "" { return fmt.Errorf("missing TableName") } @@ -108,7 +111,7 @@ type Getter[ tableName string primaryKey func(REQ) (map[string]interface{}, error) auth AuthProvider - authJoin []*LeftJoin + authJoin []*KeyJoin queryLogger QueryLogger @@ -118,18 +121,19 @@ type Getter[ } type getJoin struct { - dataColumn string - tableName string + ArrayJoinSpec fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type - on JoinFields } func NewGetter[ REQ GetRequest, RES GetResponse, ](spec GetSpec[REQ, RES]) (*Getter[REQ, RES], error) { - descriptors := newMethodDescriptor[REQ, RES]() - resDesc := descriptors.response + resDesc := spec.ResponseDescriptor + if resDesc == nil { + descriptors := newMethodDescriptor[REQ, RES]() + resDesc = descriptors.response + } sc := &Getter[REQ, RES]{ dataColumn: spec.DataColumn, @@ -148,35 +152,40 @@ func NewGetter[ sc.stateField = resDesc.Fields().ByName(spec.StateResponseField) if sc.stateField == nil { if defaultState { - return nil, fmt.Errorf("no 'state' field in proto message - did you mean to override StateResponseField?") + return nil, fmt.Errorf("no 'state' field in proto message, StateResponseField is left blank") } return nil, fmt.Errorf("no '%s' field in proto message", spec.StateResponseField) } + if spec.DataColumn == "" { + return nil, fmt.Errorf("GetSpec missing DataColumn") + } + if spec.PrimaryKey == nil { - return nil, fmt.Errorf("missing PrimaryKey func") + return nil, fmt.Errorf("GetSpec missing PrimaryKey function") } - if spec.Join != nil { + if spec.TableName == "" { + return nil, fmt.Errorf("GetSpec missing TableName") + } - if err := spec.Join.validate(); err != nil { + if spec.ArrayJoin != nil { + if err := spec.ArrayJoin.validate(); err != nil { return nil, fmt.Errorf("invalid join spec: %w", err) } - joinField := resDesc.Fields().ByName(protoreflect.Name(spec.Join.FieldInParent)) + joinField := resDesc.Fields().ByName(protoreflect.Name(spec.ArrayJoin.FieldInParent)) if joinField == nil { - return nil, fmt.Errorf("field %s not found in response message", spec.Join.FieldInParent) + return nil, fmt.Errorf("field %s not found in response message", spec.ArrayJoin.FieldInParent) } if !joinField.IsList() { - return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.Join.FieldInParent) + return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.ArrayJoin.FieldInParent) } sc.join = &getJoin{ - tableName: spec.Join.TableName, - dataColumn: spec.Join.DataColumn, + ArrayJoinSpec: *spec.ArrayJoin, fieldInParent: joinField, - on: spec.Join.On, } } @@ -209,6 +218,10 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, return err } + if len(primaryKeyFields) == 0 { + return fmt.Errorf("PrimaryKey() returned no fields") + } + rootFilter, err := dbconvert.FieldsToEqMap(rootAlias, primaryKeyFields) if err != nil { return err @@ -252,15 +265,15 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, } if gc.join != nil { - joinAlias := as.Next(gc.join.tableName) + joinAlias := as.Next(gc.join.TableName) selectQuery. - Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, gc.join.dataColumn)). + Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, gc.join.DataColumn)). LeftJoin(fmt.Sprintf( "%s AS %s ON %s", - gc.join.tableName, + gc.join.TableName, joinAlias, - gc.join.on.SQL(rootAlias, joinAlias), + gc.join.On.SQL(rootAlias, joinAlias), )) } diff --git a/pquery/getter_test.go b/pquery/getter_test.go new file mode 100644 index 0000000..e4c5fef --- /dev/null +++ b/pquery/getter_test.go @@ -0,0 +1,168 @@ +package pquery + +import ( + "context" + "database/sql" + "strings" + "testing" + + "github.com/pentops/pgtest.go/pgtest" + "github.com/pentops/sqrlx.go/sqrlx" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" +) + +func TestGetter(t *testing.T) { + + fileDescriptor := &descriptorpb.FileDescriptorProto{ + Name: proto.String("test.proto"), + Package: proto.String("test"), + MessageType: []*descriptorpb.DescriptorProto{{ + Name: proto.String("TestRequest"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("foo_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + }}, + }, { + Name: proto.String("TestResponse"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("state"), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".test.Foo"), + Number: proto.Int32(1), + }}, + }, { + Name: proto.String("Foo"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("foo_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + }, { + Name: proto.String("name"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(2), + }}, + }}, + } + + file, err := protodesc.NewFile(fileDescriptor, protoregistry.GlobalFiles) + if err != nil { + t.Fatal(err.Error()) + } + + reqDesc := file.Messages().ByName("TestRequest") + if reqDesc == nil { + t.Fatal("reqDesc is nil") + } + + resDesc := file.Messages().ByName("TestResponse") + if resDesc == nil { + t.Fatal("resDesc is nil") + } + + spec := GetSpec[*tDynamicMessage, *tDynamicMessage]{ + ResponseDescriptor: resDesc, + DataColumn: "state", + TableName: "foo", + PrimaryKey: func(req *tDynamicMessage) (map[string]interface{}, error) { + return map[string]interface{}{ + "foo_id": req.GetField(t, "foo_id").String(), + }, nil + }, + } + + gg, err := NewGetter(spec) + if err != nil { + t.Fatal(err.Error()) + } + + ctx := context.Background() + conn := pgtest.GetTestDB(t, pgtest.WithSchemaName("pquery")) + + db := sqrlx.NewPostgres(conn) + + execAll(t, conn, + "CREATE TABLE foo (foo_id text PRIMARY KEY, state jsonb)", + `INSERT INTO foo (foo_id, state) VALUES + ( 'id0', '{"foo_id": "id0", "name": "foo0"}'), + ( 'id1', '{"foo_id": "id1", "name": "foo1"}') + `, + ) + + reqMsg := newDynamicMessage(reqDesc) + reqMsg.SetField(t, "foo_id", protoreflect.ValueOf("id0")) + + resMsg := newDynamicMessage(resDesc) + err = gg.Get(ctx, db, reqMsg, resMsg) + if err != nil { + t.Fatal(err.Error()) + } + + resMsg.AssertField(t, "state.foo_id", protoreflect.ValueOf("id0")) + resMsg.AssertField(t, "state.name", protoreflect.ValueOf("foo0")) +} + +func newDynamicMessage(md protoreflect.MessageDescriptor) *tDynamicMessage { + return &tDynamicMessage{dynamicpb.NewMessage(md)} +} + +type tDynamicMessage struct { + *dynamicpb.Message +} + +func (dm *tDynamicMessage) GetField(t *testing.T, name string) protoreflect.Value { + path := strings.Split(name, ".") + pathTo, last := path[:len(path)-1], path[len(path)-1] + var msg protoreflect.Message = dm.Message + for _, p := range pathTo { + field := msg.Get(msg.Descriptor().Fields().ByName(protoreflect.Name(p))) + if !field.IsValid() { + t.Errorf("field %s is not valid", p) + } + msg = field.Message() + } + field := msg.Get(msg.Descriptor().Fields().ByName(protoreflect.Name(last))) + if !field.IsValid() { + t.Errorf("field %s is not valid", last) + } + return field +} + +func (dm *tDynamicMessage) SetField(t *testing.T, name string, value protoreflect.Value) { + path := strings.Split(name, ".") + pathTo, last := path[:len(path)-1], path[len(path)-1] + var msg protoreflect.Message = dm.Message + for _, p := range pathTo { + field := msg.Get(msg.Descriptor().Fields().ByName(protoreflect.Name(p))) + if field.IsValid() { + msg = field.Message() + continue + } + field = protoreflect.ValueOf(msg.NewField(msg.Descriptor().Fields().ByName(protoreflect.Name(p)))) + msg.Set(msg.Descriptor().Fields().ByName(protoreflect.Name(p)), field) + t.Errorf("field %s is not valid", p) + } + + msg.Set(msg.Descriptor().Fields().ByName(protoreflect.Name(last)), value) +} + +func (dm *tDynamicMessage) AssertField(t *testing.T, name string, value protoreflect.Value) { + field := dm.GetField(t, name) + if !field.Equal(value) { + t.Errorf("expected %v, got %v", value, field) + } +} + +func execAll(t *testing.T, conn *sql.DB, queries ...string) { + for _, query := range queries { + _, err := conn.Exec(query) + if err != nil { + t.Fatal(err.Error()) + } + } +} diff --git a/pquery/lister.go b/pquery/lister.go index f04e144..c1b8d68 100644 --- a/pquery/lister.go +++ b/pquery/lister.go @@ -43,7 +43,7 @@ type TableSpec struct { TableName string Auth AuthProvider - AuthJoin []*LeftJoin + AuthJoin []*KeyJoin DataColumn string // TODO: Replace with array Columns []Column @@ -215,7 +215,7 @@ type Lister[REQ ListRequest, RES ListResponse] struct { queryLogger QueryLogger auth AuthProvider - authJoin []*LeftJoin + authJoin []*KeyJoin requestFilter func(REQ) (map[string]interface{}, error) diff --git a/pquery/query.go b/pquery/query.go index 7306609..ebad360 100644 --- a/pquery/query.go +++ b/pquery/query.go @@ -34,10 +34,10 @@ func (f AuthProviderFunc) AuthFilter(ctx context.Context) (map[string]string, er return f(ctx) } -// LeftJoin is a specification for joining in the form -// ON . =
. -// Main is defined in the outer struct holding this LeftJoin -type LeftJoin struct { +// KeyJoin is a join on a primary (or unique) key in the RHS table. +// LEFT JOIN ON . =
. +// Main is defined in the outer struct holding this KeyJoin +type KeyJoin struct { TableName string On JoinFields } diff --git a/psm/state_query.go b/psm/state_query.go index 75ede22..586c8e2 100644 --- a/psm/state_query.go +++ b/psm/state_query.go @@ -89,7 +89,7 @@ func (gc *StateQuerySet[ type StateQueryOptions struct { Auth pquery.AuthProvider - AuthJoin *pquery.LeftJoin + AuthJoin *pquery.KeyJoin SkipEvents bool } @@ -156,7 +156,7 @@ func BuildStateQuerySet[ } if options.AuthJoin != nil { - getSpec.AuthJoin = []*pquery.LeftJoin{options.AuthJoin} + getSpec.AuthJoin = []*pquery.KeyJoin{options.AuthJoin} } if options.Auth != nil { @@ -196,7 +196,7 @@ func BuildStateQuerySet[ if smSpec.Event.Root == nil { return nil, fmt.Errorf("missing EventDataColumn in state spec for %s", smSpec.State.TableName) } - getSpec.Join = &pquery.GetJoinSpec{ + getSpec.ArrayJoin = &pquery.ArrayJoinSpec{ TableName: smSpec.Event.TableName, DataColumn: smSpec.Event.Root.ColumnName, FieldInParent: eventsInGet, From 383eac0b3cc27b596a3f72a74db7dfe0e31e4aab Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 12:09:51 -0800 Subject: [PATCH 2/7] Dynamic list of joins --- pquery/getter.go | 87 ++++++++++++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/pquery/getter.go b/pquery/getter.go index e637ae1..b5b860d 100644 --- a/pquery/getter.go +++ b/pquery/getter.go @@ -117,7 +117,7 @@ type Getter[ validator *protovalidate.Validator - join *getJoin + joins []*getJoin } type getJoin struct { @@ -125,6 +125,44 @@ type getJoin struct { fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type } +func (join *getJoin) apply(query *sq.SelectBuilder, rootAlias, joinAlias string) { + query.Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, join.DataColumn)). + LeftJoin(fmt.Sprintf( + "%s AS %s ON %s", + join.TableName, + joinAlias, + join.On.SQL(rootAlias, joinAlias), + )) +} + +func (join *getJoin) scanDest() interface{} { + v := pq.StringArray{} + return &v +} + +func (join *getJoin) unmarshal(rawData interface{}, resReflect protoreflect.Message) error { + data, ok := rawData.(*pq.StringArray) + + if !ok { + return fmt.Errorf("expected []string, got %T", rawData) + } + + elementList := resReflect.Mutable(join.fieldInParent).List() + for _, eventBytes := range *data { + if eventBytes == "" { + continue + } + + rowMessage := elementList.NewElement().Message() + if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { + return fmt.Errorf("joined unmarshal: %w", err) + } + elementList.Append(protoreflect.ValueOf(rowMessage)) + } + + return nil +} + func NewGetter[ REQ GetRequest, RES GetResponse, @@ -183,10 +221,10 @@ func NewGetter[ return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.ArrayJoin.FieldInParent) } - sc.join = &getJoin{ + sc.joins = append(sc.joins, &getJoin{ ArrayJoinSpec: *spec.ArrayJoin, fieldInParent: joinField, - } + }) } var err error @@ -264,21 +302,18 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, } } - if gc.join != nil { - joinAlias := as.Next(gc.join.TableName) + for _, join := range gc.joins { + joinAlias := as.Next(join.TableName) + join.apply(selectQuery, rootAlias, joinAlias) - selectQuery. - Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, gc.join.DataColumn)). - LeftJoin(fmt.Sprintf( - "%s AS %s ON %s", - gc.join.TableName, - joinAlias, - gc.join.On.SQL(rootAlias, joinAlias), - )) } var foundJSON []byte - var joinedJSON pq.StringArray + cols := make([]interface{}, 0) + cols = append(cols, &foundJSON) + for _, join := range gc.joins { + cols = append(cols, join.scanDest()) + } if gc.queryLogger != nil { gc.queryLogger(selectQuery) @@ -291,12 +326,7 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, }, func(ctx context.Context, tx sqrlx.Transaction) error { row := tx.SelectRow(ctx, selectQuery) - var err error - if gc.join != nil { - err = row.Scan(&foundJSON, &joinedJSON) - } else { - err = row.Scan(&foundJSON) - } + err := row.Scan(cols...) if err != nil { if errors.Is(err, sql.ErrNoRows) { var pkDescription string @@ -333,20 +363,11 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, } resReflect.Set(gc.stateField, stateMsg) - if gc.join != nil { - elementList := resReflect.Mutable(gc.join.fieldInParent).List() - for _, eventBytes := range joinedJSON { - if eventBytes == "" { - continue - } - - rowMessage := elementList.NewElement().Message() - if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { - return fmt.Errorf("joined unmarshal: %w", err) - } - elementList.Append(protoreflect.ValueOf(rowMessage)) + for i, join := range gc.joins { + iData := cols[i+1] + if err := join.unmarshal(iData, resReflect); err != nil { + return err } - } return nil From d105c12c0f033efd428f88bf7ecdba7aebab6544 Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 15:39:50 -0800 Subject: [PATCH 3/7] Generic columns --- pquery/getter.go | 191 ++++++++++++++++++++++++------------------ pquery/getter_test.go | 68 ++++++++++++++- pquery/lister.go | 8 -- pquery/query.go | 63 ++++++++++++++ 4 files changed, 236 insertions(+), 94 deletions(-) diff --git a/pquery/getter.go b/pquery/getter.go index b5b860d..1ad45ae 100644 --- a/pquery/getter.go +++ b/pquery/getter.go @@ -9,7 +9,6 @@ import ( "github.com/bufbuild/protovalidate-go" sq "github.com/elgris/sqrl" - "github.com/lib/pq" "github.com/pentops/protostate/internal/dbconvert" "github.com/pentops/sqrlx.go/sqrlx" "google.golang.org/grpc/codes" @@ -101,13 +100,34 @@ func (gc ArrayJoinSpec) validate() error { return nil } +// jsonFieldRow is a jsonb SQL field mapped to a proto field. +type jsonFieldRow struct { + field protoreflect.FieldDescriptor + data []byte +} + +func (jc *jsonFieldRow) ScanTo() interface{} { + return &jc.data +} + +func (jc *jsonFieldRow) Unmarshal(resReflect protoreflect.Message) error { + + if jc.data == nil { + return status.Error(codes.NotFound, "not found") + } + + stateMsg := resReflect.NewField(jc.field) + if err := protojson.Unmarshal(jc.data, stateMsg.Message().Interface()); err != nil { + return err + } + resReflect.Set(jc.field, stateMsg) + return nil +} + type Getter[ REQ GetRequest, RES proto.Message, ] struct { - stateField protoreflect.FieldDescriptor - - dataColumn string tableName string primaryKey func(REQ) (map[string]interface{}, error) auth AuthProvider @@ -117,50 +137,29 @@ type Getter[ validator *protovalidate.Validator - joins []*getJoin + columns []ColumnSpec } -type getJoin struct { - ArrayJoinSpec - fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type +type jsonColumn struct { + sqlColumn string + field protoreflect.FieldDescriptor } -func (join *getJoin) apply(query *sq.SelectBuilder, rootAlias, joinAlias string) { - query.Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, join.DataColumn)). - LeftJoin(fmt.Sprintf( - "%s AS %s ON %s", - join.TableName, - joinAlias, - join.On.SQL(rootAlias, joinAlias), - )) +func newJsonColumn(sqlColumn string, protoField protoreflect.FieldDescriptor) jsonColumn { + return jsonColumn{ + sqlColumn: sqlColumn, + field: protoField, + } } -func (join *getJoin) scanDest() interface{} { - v := pq.StringArray{} - return &v +func (jc jsonColumn) ApplyQuery(tableAlias string, sb SelectBuilder) { + sb.Column(jc, fmt.Sprintf("%s.%s", tableAlias, jc.sqlColumn)) } -func (join *getJoin) unmarshal(rawData interface{}, resReflect protoreflect.Message) error { - data, ok := rawData.(*pq.StringArray) - - if !ok { - return fmt.Errorf("expected []string, got %T", rawData) +func (jc jsonColumn) NewRow() ScanDest { + return &jsonFieldRow{ + field: jc.field, } - - elementList := resReflect.Mutable(join.fieldInParent).List() - for _, eventBytes := range *data { - if eventBytes == "" { - continue - } - - rowMessage := elementList.NewElement().Message() - if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { - return fmt.Errorf("joined unmarshal: %w", err) - } - elementList.Append(protoreflect.ValueOf(rowMessage)) - } - - return nil } func NewGetter[ @@ -174,7 +173,6 @@ func NewGetter[ } sc := &Getter[REQ, RES]{ - dataColumn: spec.DataColumn, tableName: spec.TableName, primaryKey: spec.PrimaryKey, auth: spec.Auth, @@ -187,13 +185,15 @@ func NewGetter[ defaultState = true spec.StateResponseField = protoreflect.Name("state") } - sc.stateField = resDesc.Fields().ByName(spec.StateResponseField) - if sc.stateField == nil { + + stateField := resDesc.Fields().ByName(spec.StateResponseField) + if stateField == nil { if defaultState { return nil, fmt.Errorf("no 'state' field in proto message, StateResponseField is left blank") } return nil, fmt.Errorf("no '%s' field in proto message", spec.StateResponseField) } + sc.columns = append(sc.columns, newJsonColumn(spec.DataColumn, stateField)) if spec.DataColumn == "" { return nil, fmt.Errorf("GetSpec missing DataColumn") @@ -221,7 +221,7 @@ func NewGetter[ return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.ArrayJoin.FieldInParent) } - sc.joins = append(sc.joins, &getJoin{ + sc.columns = append(sc.columns, &jsonArrayColumn{ ArrayJoinSpec: *spec.ArrayJoin, fieldInParent: joinField, }) @@ -240,10 +240,52 @@ func (gc *Getter[REQ, RES]) SetQueryLogger(logger QueryLogger) { gc.queryLogger = logger } -func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, resMsg RES) error { +type selectBuilder struct { + *sq.SelectBuilder + aliasSet *aliasSet + rootAlias string + + scanDest []ColumnDest +} +func newSelectBuilder(rootTable string) *selectBuilder { as := newAliasSet() - rootAlias := as.Next(gc.tableName) + rootAlias := as.Next(rootTable) + sb := sq.Select(). + From(fmt.Sprintf("%s AS %s", rootTable, rootAlias)) + + return &selectBuilder{ + SelectBuilder: sb, + aliasSet: as, + rootAlias: rootAlias, + } +} + +type ColumnDest interface { + NewRow() ScanDest +} + +type ScanDest interface { + ScanTo() interface{} + Unmarshal(protoreflect.Message) error +} + +func (sb *selectBuilder) Column(into ColumnDest, stmt string, args ...interface{}) { + sb.SelectBuilder.Column(stmt, args...) + sb.scanDest = append(sb.scanDest, into) +} + +func (sb *selectBuilder) LeftJoin(join string, rest ...interface{}) { + sb.SelectBuilder.LeftJoin(join, rest...) +} + +func (sb *selectBuilder) TableAlias(tableName string) string { + return sb.aliasSet.Next(tableName) +} + +func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, resMsg RES) error { + + sb := newSelectBuilder(gc.tableName) resReflect := resMsg.ProtoReflect() @@ -260,27 +302,22 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, return fmt.Errorf("PrimaryKey() returned no fields") } - rootFilter, err := dbconvert.FieldsToEqMap(rootAlias, primaryKeyFields) + rootFilter, err := dbconvert.FieldsToEqMap(sb.rootAlias, primaryKeyFields) if err != nil { return err } - - selectQuery := sq. - Select(). - Column(fmt.Sprintf("%s.%s", rootAlias, gc.dataColumn)). - From(fmt.Sprintf("%s AS %s", gc.tableName, rootAlias)). - Where(rootFilter) + sb.Where(rootFilter) for pkField := range rootFilter { - selectQuery.GroupBy(pkField) + sb.GroupBy(pkField) } if gc.auth != nil { - authAlias := rootAlias + authAlias := sb.rootAlias for _, join := range gc.authJoin { priorAlias := authAlias - authAlias = as.Next(join.TableName) - selectQuery = selectQuery.LeftJoin(fmt.Sprintf( + authAlias = sb.TableAlias(join.TableName) + sb.LeftJoin(fmt.Sprintf( "%s AS %s ON %s", join.TableName, authAlias, @@ -298,25 +335,24 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, for k, v := range authFilter { claimFilter[fmt.Sprintf("%s.%s", authAlias, k)] = v } - selectQuery.Where(claimFilter) + sb.Where(claimFilter) } } - for _, join := range gc.joins { - joinAlias := as.Next(join.TableName) - join.apply(selectQuery, rootAlias, joinAlias) - + for _, join := range gc.columns { + join.ApplyQuery(sb.rootAlias, sb) } - var foundJSON []byte - cols := make([]interface{}, 0) - cols = append(cols, &foundJSON) - for _, join := range gc.joins { - cols = append(cols, join.scanDest()) + joins := make([]ScanDest, 0, len(gc.columns)) + rowCols := make([]interface{}, 0, len(sb.scanDest)) + for _, inQuery := range sb.scanDest { + colRow := inQuery.NewRow() + joins = append(joins, colRow) + rowCols = append(rowCols, colRow.ScanTo()) } if gc.queryLogger != nil { - gc.queryLogger(selectQuery) + gc.queryLogger(sb.SelectBuilder) } if err := db.Transact(ctx, &sqrlx.TxOptions{ @@ -324,9 +360,9 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, Retryable: true, Isolation: sql.LevelReadCommitted, }, func(ctx context.Context, tx sqrlx.Transaction) error { - row := tx.SelectRow(ctx, selectQuery) + row := tx.SelectRow(ctx, sb.SelectBuilder) - err := row.Scan(cols...) + err := row.Scan(rowCols...) if err != nil { if errors.Is(err, sql.ErrNoRows) { var pkDescription string @@ -349,23 +385,12 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, return nil }); err != nil { - query, _, _ := selectQuery.ToSql() + query, _, _ := sb.ToSql() return fmt.Errorf("%s: %w", query, err) } - if foundJSON == nil { - return status.Error(codes.NotFound, "not found") - } - - stateMsg := resReflect.NewField(gc.stateField) - if err := protojson.Unmarshal(foundJSON, stateMsg.Message().Interface()); err != nil { - return err - } - resReflect.Set(gc.stateField, stateMsg) - - for i, join := range gc.joins { - iData := cols[i+1] - if err := join.unmarshal(iData, resReflect); err != nil { + for _, join := range joins { + if err := join.Unmarshal(resReflect); err != nil { return err } } diff --git a/pquery/getter_test.go b/pquery/getter_test.go index e4c5fef..c703cf8 100644 --- a/pquery/getter_test.go +++ b/pquery/getter_test.go @@ -35,6 +35,12 @@ func TestGetter(t *testing.T) { Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), TypeName: proto.String(".test.Foo"), Number: proto.Int32(1), + }, { + Name: proto.String("bars"), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".test.Bar"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), }}, }, { Name: proto.String("Foo"), @@ -47,12 +53,27 @@ func TestGetter(t *testing.T) { Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), Number: proto.Int32(2), }}, + }, { + Name: proto.String("Bar"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("bar_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + }, { + Name: proto.String("foo_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(2), + }, { + Name: proto.String("name"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(3), + }}, }}, } file, err := protodesc.NewFile(fileDescriptor, protoregistry.GlobalFiles) if err != nil { - t.Fatal(err.Error()) + t.Fatalf("Compiling test proto: %s", err.Error()) } reqDesc := file.Messages().ByName("TestRequest") @@ -74,12 +95,23 @@ func TestGetter(t *testing.T) { "foo_id": req.GetField(t, "foo_id").String(), }, nil }, + + ArrayJoin: &ArrayJoinSpec{ + TableName: "bar", + FieldInParent: "bars", + DataColumn: "state", + On: []JoinField{{ + JoinColumn: "foo_id", + RootColumn: "foo_id", + }}, + }, } gg, err := NewGetter(spec) if err != nil { t.Fatal(err.Error()) } + gg.SetQueryLogger(testLog(t)) ctx := context.Background() conn := pgtest.GetTestDB(t, pgtest.WithSchemaName("pquery")) @@ -92,6 +124,12 @@ func TestGetter(t *testing.T) { ( 'id0', '{"foo_id": "id0", "name": "foo0"}'), ( 'id1', '{"foo_id": "id1", "name": "foo1"}') `, + + `CREATE TABLE bar (bar_id text PRIMARY KEY, foo_id text REFERENCES foo(foo_id), state jsonb)`, + `INSERT INTO bar (bar_id, foo_id, state) VALUES + ( 'bar0', 'id0', '{"bar_id": "bar0", "foo_id": "id0", "name": "bar0"}'), + ( 'bar1', 'id0', '{"bar_id": "bar1", "foo_id": "id1", "name": "bar1"}') + `, ) reqMsg := newDynamicMessage(reqDesc) @@ -105,12 +143,29 @@ func TestGetter(t *testing.T) { resMsg.AssertField(t, "state.foo_id", protoreflect.ValueOf("id0")) resMsg.AssertField(t, "state.name", protoreflect.ValueOf("foo0")) + + bars := resMsg.GetField(t, "bars").List() + if bars.Len() != 2 { + t.Fatalf("expected 2 bars, got %d", bars.Len()) + } + } func newDynamicMessage(md protoreflect.MessageDescriptor) *tDynamicMessage { return &tDynamicMessage{dynamicpb.NewMessage(md)} } +func testLog(t *testing.T) QueryLogger { + return func(qq sqrlx.Sqlizer) { + stmt, args, err := qq.ToSql() + if err != nil { + t.Fatal(err.Error()) + } + t.Logf("Query: %s, args: %v", stmt, args) + + } +} + type tDynamicMessage struct { *dynamicpb.Message } @@ -125,10 +180,17 @@ func (dm *tDynamicMessage) GetField(t *testing.T, name string) protoreflect.Valu t.Errorf("field %s is not valid", p) } msg = field.Message() + if msg == nil { + t.Fatalf("field %s: msg is nil", p) + } + } + fieldDef := msg.Descriptor().Fields().ByName(protoreflect.Name(last)) + if fieldDef == nil { + t.Fatalf("field %s: no such field", name) } - field := msg.Get(msg.Descriptor().Fields().ByName(protoreflect.Name(last))) + field := msg.Get(fieldDef) if !field.IsValid() { - t.Errorf("field %s is not valid", last) + t.Errorf("field %s is not valid", name) } return field } diff --git a/pquery/lister.go b/pquery/lister.go index c1b8d68..e6d165c 100644 --- a/pquery/lister.go +++ b/pquery/lister.go @@ -51,14 +51,6 @@ type TableSpec struct { FallbackSortColumns []pgstore.ProtoFieldSpec } -type Column struct { - Name string - - // The point within the root element which is stored in the column. An empty - // path means this stores the root element, - MountPoint *pgstore.Path -} - type ListSpec[REQ ListRequest, RES ListResponse] struct { TableSpec RequestFilter func(REQ) (map[string]interface{}, error) diff --git a/pquery/query.go b/pquery/query.go index ebad360..8cde5df 100644 --- a/pquery/query.go +++ b/pquery/query.go @@ -4,7 +4,9 @@ import ( "context" "fmt" + "github.com/lib/pq" "github.com/pentops/sqrlx.go/sqrlx" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -56,3 +58,64 @@ func newMethodDescriptor[REQ proto.Message, RES proto.Message]() *methodDescript response: res.ProtoReflect().Descriptor(), } } + +type SelectBuilder interface { + Column(into ColumnDest, stmt string, args ...interface{}) + LeftJoin(join string, rest ...interface{}) + TableAlias(tableName string) string +} + +type ColumnSpec interface { + ApplyQuery(parentAlias string, sb SelectBuilder) +} + +type jsonArrayColumn struct { + ArrayJoinSpec + fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type +} + +func (join *jsonArrayColumn) NewRow() ScanDest { + scanDest := pq.StringArray{} + return &jsonArrayFieldRow{ + fieldInParent: join.fieldInParent, + column: &scanDest, + } +} + +func (join *jsonArrayColumn) ApplyQuery(parentTable string, sb SelectBuilder) { + joinAlias := sb.TableAlias(join.TableName) + sb.Column(join, fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, join.DataColumn)) + sb.LeftJoin(fmt.Sprintf( + "%s AS %s ON %s", + join.TableName, + joinAlias, + join.On.SQL(parentTable, joinAlias), + )) +} + +type jsonArrayFieldRow struct { + fieldInParent protoreflect.FieldDescriptor + column *pq.StringArray +} + +func (join *jsonArrayFieldRow) ScanTo() interface{} { + return join.column +} + +func (join *jsonArrayFieldRow) Unmarshal(resReflect protoreflect.Message) error { + + elementList := resReflect.Mutable(join.fieldInParent).List() + for _, eventBytes := range *join.column { + if eventBytes == "" { + continue + } + + rowMessage := elementList.NewElement().Message() + if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { + return fmt.Errorf("joined unmarshal: %w", err) + } + elementList.Append(protoreflect.ValueOf(rowMessage)) + } + + return nil +} From 03b2f82928f44b5d3385caa25a5df332025ca4b8 Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 15:43:34 -0800 Subject: [PATCH 4/7] cleanup --- pquery/getter.go | 18 +++------- pquery/query.go | 63 --------------------------------- pquery/select_build.go | 79 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 77 deletions(-) create mode 100644 pquery/select_build.go diff --git a/pquery/getter.go b/pquery/getter.go index 1ad45ae..41077cd 100644 --- a/pquery/getter.go +++ b/pquery/getter.go @@ -244,8 +244,7 @@ type selectBuilder struct { *sq.SelectBuilder aliasSet *aliasSet rootAlias string - - scanDest []ColumnDest + columns []ColumnDest } func newSelectBuilder(rootTable string) *selectBuilder { @@ -261,18 +260,9 @@ func newSelectBuilder(rootTable string) *selectBuilder { } } -type ColumnDest interface { - NewRow() ScanDest -} - -type ScanDest interface { - ScanTo() interface{} - Unmarshal(protoreflect.Message) error -} - func (sb *selectBuilder) Column(into ColumnDest, stmt string, args ...interface{}) { sb.SelectBuilder.Column(stmt, args...) - sb.scanDest = append(sb.scanDest, into) + sb.columns = append(sb.columns, into) } func (sb *selectBuilder) LeftJoin(join string, rest ...interface{}) { @@ -344,8 +334,8 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, } joins := make([]ScanDest, 0, len(gc.columns)) - rowCols := make([]interface{}, 0, len(sb.scanDest)) - for _, inQuery := range sb.scanDest { + rowCols := make([]interface{}, 0, len(sb.columns)) + for _, inQuery := range sb.columns { colRow := inQuery.NewRow() joins = append(joins, colRow) rowCols = append(rowCols, colRow.ScanTo()) diff --git a/pquery/query.go b/pquery/query.go index 8cde5df..ebad360 100644 --- a/pquery/query.go +++ b/pquery/query.go @@ -4,9 +4,7 @@ import ( "context" "fmt" - "github.com/lib/pq" "github.com/pentops/sqrlx.go/sqrlx" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -58,64 +56,3 @@ func newMethodDescriptor[REQ proto.Message, RES proto.Message]() *methodDescript response: res.ProtoReflect().Descriptor(), } } - -type SelectBuilder interface { - Column(into ColumnDest, stmt string, args ...interface{}) - LeftJoin(join string, rest ...interface{}) - TableAlias(tableName string) string -} - -type ColumnSpec interface { - ApplyQuery(parentAlias string, sb SelectBuilder) -} - -type jsonArrayColumn struct { - ArrayJoinSpec - fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type -} - -func (join *jsonArrayColumn) NewRow() ScanDest { - scanDest := pq.StringArray{} - return &jsonArrayFieldRow{ - fieldInParent: join.fieldInParent, - column: &scanDest, - } -} - -func (join *jsonArrayColumn) ApplyQuery(parentTable string, sb SelectBuilder) { - joinAlias := sb.TableAlias(join.TableName) - sb.Column(join, fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, join.DataColumn)) - sb.LeftJoin(fmt.Sprintf( - "%s AS %s ON %s", - join.TableName, - joinAlias, - join.On.SQL(parentTable, joinAlias), - )) -} - -type jsonArrayFieldRow struct { - fieldInParent protoreflect.FieldDescriptor - column *pq.StringArray -} - -func (join *jsonArrayFieldRow) ScanTo() interface{} { - return join.column -} - -func (join *jsonArrayFieldRow) Unmarshal(resReflect protoreflect.Message) error { - - elementList := resReflect.Mutable(join.fieldInParent).List() - for _, eventBytes := range *join.column { - if eventBytes == "" { - continue - } - - rowMessage := elementList.NewElement().Message() - if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { - return fmt.Errorf("joined unmarshal: %w", err) - } - elementList.Append(protoreflect.ValueOf(rowMessage)) - } - - return nil -} diff --git a/pquery/select_build.go b/pquery/select_build.go new file mode 100644 index 0000000..97780e0 --- /dev/null +++ b/pquery/select_build.go @@ -0,0 +1,79 @@ +package pquery + +import ( + "fmt" + + "github.com/lib/pq" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type ColumnDest interface { + NewRow() ScanDest +} + +type ScanDest interface { + ScanTo() interface{} + Unmarshal(protoreflect.Message) error +} + +type SelectBuilder interface { + Column(into ColumnDest, stmt string, args ...interface{}) + LeftJoin(join string, rest ...interface{}) + TableAlias(tableName string) string +} + +type ColumnSpec interface { + ApplyQuery(parentAlias string, sb SelectBuilder) +} + +type jsonArrayColumn struct { + ArrayJoinSpec + fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type +} + +func (join *jsonArrayColumn) NewRow() ScanDest { + scanDest := pq.StringArray{} + return &jsonArrayFieldRow{ + fieldInParent: join.fieldInParent, + column: &scanDest, + } +} + +func (join *jsonArrayColumn) ApplyQuery(parentTable string, sb SelectBuilder) { + joinAlias := sb.TableAlias(join.TableName) + sb.Column(join, fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, join.DataColumn)) + sb.LeftJoin(fmt.Sprintf( + "%s AS %s ON %s", + join.TableName, + joinAlias, + join.On.SQL(parentTable, joinAlias), + )) +} + +type jsonArrayFieldRow struct { + fieldInParent protoreflect.FieldDescriptor + column *pq.StringArray +} + +func (join *jsonArrayFieldRow) ScanTo() interface{} { + return join.column +} + +func (join *jsonArrayFieldRow) Unmarshal(resReflect protoreflect.Message) error { + + elementList := resReflect.Mutable(join.fieldInParent).List() + for _, eventBytes := range *join.column { + if eventBytes == "" { + continue + } + + rowMessage := elementList.NewElement().Message() + if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { + return fmt.Errorf("joined unmarshal: %w", err) + } + elementList.Append(protoreflect.ValueOf(rowMessage)) + } + + return nil +} From e36dd55c2c02e53107d14741225d1d0812ab5ca0 Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 16:16:54 -0800 Subject: [PATCH 5/7] Use select builder in list --- internal/integration/pagination_test.go | 4 +- pquery/getter.go | 95 +----------------------- pquery/lister.go | 76 ++++++++++--------- pquery/select_build.go | 98 +++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 128 deletions(-) diff --git a/internal/integration/pagination_test.go b/internal/integration/pagination_test.go index 2b56edc..911ad5f 100644 --- a/internal/integration/pagination_test.go +++ b/internal/integration/pagination_test.go @@ -106,7 +106,7 @@ func TestPagination(t *testing.T) { if err != nil { t.Fatal(err.Error()) } - printQuery(t, query) + printQuery(t, query.SelectBuilder) err = queryer.List(ctx, db, req, res) if err != nil { @@ -240,7 +240,7 @@ func TestEventPagination(t *testing.T) { if err != nil { t.Fatal(err.Error()) } - printQuery(t, query) + printQuery(t, query.SelectBuilder) err = queryer.ListEvents(ctx, db, req, res) if err != nil { diff --git a/pquery/getter.go b/pquery/getter.go index 41077cd..068e285 100644 --- a/pquery/getter.go +++ b/pquery/getter.go @@ -8,12 +8,10 @@ import ( "strings" "github.com/bufbuild/protovalidate-go" - sq "github.com/elgris/sqrl" "github.com/pentops/protostate/internal/dbconvert" "github.com/pentops/sqrlx.go/sqrlx" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -100,30 +98,6 @@ func (gc ArrayJoinSpec) validate() error { return nil } -// jsonFieldRow is a jsonb SQL field mapped to a proto field. -type jsonFieldRow struct { - field protoreflect.FieldDescriptor - data []byte -} - -func (jc *jsonFieldRow) ScanTo() interface{} { - return &jc.data -} - -func (jc *jsonFieldRow) Unmarshal(resReflect protoreflect.Message) error { - - if jc.data == nil { - return status.Error(codes.NotFound, "not found") - } - - stateMsg := resReflect.NewField(jc.field) - if err := protojson.Unmarshal(jc.data, stateMsg.Message().Interface()); err != nil { - return err - } - resReflect.Set(jc.field, stateMsg) - return nil -} - type Getter[ REQ GetRequest, RES proto.Message, @@ -140,28 +114,6 @@ type Getter[ columns []ColumnSpec } -type jsonColumn struct { - sqlColumn string - field protoreflect.FieldDescriptor -} - -func newJsonColumn(sqlColumn string, protoField protoreflect.FieldDescriptor) jsonColumn { - return jsonColumn{ - sqlColumn: sqlColumn, - field: protoField, - } -} - -func (jc jsonColumn) ApplyQuery(tableAlias string, sb SelectBuilder) { - sb.Column(jc, fmt.Sprintf("%s.%s", tableAlias, jc.sqlColumn)) -} - -func (jc jsonColumn) NewRow() ScanDest { - return &jsonFieldRow{ - field: jc.field, - } -} - func NewGetter[ REQ GetRequest, RES GetResponse, @@ -240,39 +192,6 @@ func (gc *Getter[REQ, RES]) SetQueryLogger(logger QueryLogger) { gc.queryLogger = logger } -type selectBuilder struct { - *sq.SelectBuilder - aliasSet *aliasSet - rootAlias string - columns []ColumnDest -} - -func newSelectBuilder(rootTable string) *selectBuilder { - as := newAliasSet() - rootAlias := as.Next(rootTable) - sb := sq.Select(). - From(fmt.Sprintf("%s AS %s", rootTable, rootAlias)) - - return &selectBuilder{ - SelectBuilder: sb, - aliasSet: as, - rootAlias: rootAlias, - } -} - -func (sb *selectBuilder) Column(into ColumnDest, stmt string, args ...interface{}) { - sb.SelectBuilder.Column(stmt, args...) - sb.columns = append(sb.columns, into) -} - -func (sb *selectBuilder) LeftJoin(join string, rest ...interface{}) { - sb.SelectBuilder.LeftJoin(join, rest...) -} - -func (sb *selectBuilder) TableAlias(tableName string) string { - return sb.aliasSet.Next(tableName) -} - func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, resMsg RES) error { sb := newSelectBuilder(gc.tableName) @@ -333,13 +252,7 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, join.ApplyQuery(sb.rootAlias, sb) } - joins := make([]ScanDest, 0, len(gc.columns)) - rowCols := make([]interface{}, 0, len(sb.columns)) - for _, inQuery := range sb.columns { - colRow := inQuery.NewRow() - joins = append(joins, colRow) - rowCols = append(rowCols, colRow.ScanTo()) - } + fields, scanCols := sb.NewRow() if gc.queryLogger != nil { gc.queryLogger(sb.SelectBuilder) @@ -352,7 +265,7 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, }, func(ctx context.Context, tx sqrlx.Transaction) error { row := tx.SelectRow(ctx, sb.SelectBuilder) - err := row.Scan(rowCols...) + err := row.Scan(scanCols...) if err != nil { if errors.Is(err, sql.ErrNoRows) { var pkDescription string @@ -379,8 +292,8 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, return fmt.Errorf("%s: %w", query, err) } - for _, join := range joins { - if err := join.Unmarshal(resReflect); err != nil { + for _, field := range fields { + if err := field.Unmarshal(resReflect); err != nil { return err } } diff --git a/pquery/lister.go b/pquery/lister.go index e6d165c..771945b 100644 --- a/pquery/lister.go +++ b/pquery/lister.go @@ -12,7 +12,6 @@ import ( "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "github.com/bufbuild/protovalidate-go" - "github.com/elgris/sqrl" sq "github.com/elgris/sqrl" "github.com/elgris/sqrl/pg" "github.com/pentops/j5/gen/j5/list/v1/list_j5pb" @@ -21,7 +20,6 @@ import ( "github.com/pentops/sqrlx.go/sqrlx" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" @@ -246,6 +244,10 @@ func (ll *Lister[REQ, RES]) SetQueryLogger(logger QueryLogger) { ll.queryLogger = logger } +type listRow struct { + columns []ScanDest +} + func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg proto.Message, resMsg proto.Message) error { if err := ll.validator.Validate(reqMsg); err != nil { return fmt.Errorf("validating request %s: %w", reqMsg.ProtoReflect().Descriptor().FullName(), err) @@ -259,7 +261,7 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot return fmt.Errorf("get page size: %w", err) } - selectQuery, err := ll.BuildQuery(ctx, req, res) + sb, err := ll.BuildQuery(ctx, req, res) if err != nil { return fmt.Errorf("build query: %w", err) } @@ -270,45 +272,50 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot Isolation: sql.LevelReadCommitted, } - var jsonRows = make([][]byte, 0, pageSize) + listRows := make([]listRow, 0, pageSize) + err = db.Transact(ctx, txOpts, func(ctx context.Context, tx sqrlx.Transaction) error { - rows, err := tx.Query(ctx, selectQuery) + rows, err := tx.Query(ctx, sb.SelectBuilder) if err != nil { return fmt.Errorf("run select: %w", err) } defer rows.Close() for rows.Next() { - var json []byte - if err := rows.Scan(&json); err != nil { + + fields, scanCols := sb.NewRow() + row := listRow{columns: fields} + listRows = append(listRows, row) + + if err := rows.Scan(scanCols...); err != nil { return fmt.Errorf("row scan: %w", err) } - jsonRows = append(jsonRows, json) } return rows.Err() }) if err != nil { - stmt, _, _ := selectQuery.ToSql() + stmt, _, _ := sb.SelectBuilder.ToSql() log.WithField(ctx, "query", stmt).Error("list query") return fmt.Errorf("list query: %w", err) } if ll.queryLogger != nil { - ll.queryLogger(selectQuery) + ll.queryLogger(sb.SelectBuilder) } list := res.Mutable(ll.arrayField).List() res.Set(ll.arrayField, protoreflect.ValueOf(list)) var nextToken string - for idx, rowBytes := range jsonRows { + for idx, rowBytes := range listRows { rowMessage := list.NewElement().Message() - err := protojson.Unmarshal(rowBytes, rowMessage.Interface()) - if err != nil { - return fmt.Errorf("unmarshal into %s from %s: %w", rowMessage.Descriptor().FullName(), string(rowBytes), err) + for i, col := range rowBytes.columns { + if err := col.Unmarshal(rowMessage); err != nil { + return fmt.Errorf("unmarshal column %d: %w", i, err) + } } if idx >= int(pageSize) { @@ -339,12 +346,11 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot return nil } -func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Message, res protoreflect.Message) (*sqrl.SelectBuilder, error) { - as := newAliasSet() - tableAlias := as.Next(ll.tableName) +func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Message, res protoreflect.Message) (*selectBuilder, error) { - selectQuery := sq.Select(fmt.Sprintf("%s.%s", tableAlias, ll.dataColumn)). - From(fmt.Sprintf("%s AS %s", ll.tableName, tableAlias)) + sb := newSelectBuilder(ll.tableName) + root := newJsonColumn(ll.dataColumn, nil) + root.ApplyQuery(sb.rootAlias, sb) sortFields := ll.defaultSortFields sortFields = append(sortFields, ll.tieBreakerFields...) @@ -358,11 +364,11 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes and := sq.And{} for k := range filter { - and = append(and, sq.Expr(fmt.Sprintf("%s.%s = ?", tableAlias, k), filter[k])) + and = append(and, sq.Expr(fmt.Sprintf("%s.%s = ?", sb.rootAlias, k), filter[k])) } if len(and) > 0 { - selectQuery.Where(and) + sb.Where(and) } } @@ -384,7 +390,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes queryFilters := reqQuery.GetFilters() if len(queryFilters) > 0 { - dynFilters, err := ll.buildDynamicFilter(tableAlias, queryFilters) + dynFilters, err := ll.buildDynamicFilter(sb.rootAlias, queryFilters) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build filters: %s", err) } @@ -394,7 +400,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes querySearches := reqQuery.GetSearches() if len(querySearches) > 0 { - searchFilters, err := ll.buildDynamicSearches(tableAlias, querySearches) + searchFilters, err := ll.buildDynamicSearches(sb.rootAlias, querySearches) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build searches: %s", err) } @@ -404,7 +410,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes } for i := range filterFields { - selectQuery.Where(filterFields[i]) + sb.Where(filterFields[i]) } // apply default filters if no filters have been requested @@ -413,14 +419,14 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes for _, spec := range ll.defaultFilterFields { or := sq.Or{} for _, val := range spec.filterVals { - or = append(or, sq.Expr(fmt.Sprintf("jsonb_path_query_array(%s.%s, '%s') @> ?", tableAlias, ll.dataColumn, spec.Path.JSONPathQuery()), pg.JSONB(val))) + or = append(or, sq.Expr(fmt.Sprintf("jsonb_path_query_array(%s.%s, '%s') @> ?", sb.rootAlias, ll.dataColumn, spec.Path.JSONPathQuery()), pg.JSONB(val))) } and = append(and, or) } if len(and) > 0 { - selectQuery.Where(and) + sb.Where(and) } } @@ -429,15 +435,15 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes if sortField.desc { direction = "DESC" } - selectQuery.OrderBy(fmt.Sprintf("%s %s", sortField.Selector(tableAlias), direction)) + sb.OrderBy(fmt.Sprintf("%s %s", sortField.Selector(sb.rootAlias), direction)) } if ll.auth != nil { - authAlias := tableAlias + authAlias := sb.rootAlias for _, join := range ll.authJoin { priorAlias := authAlias - authAlias = as.Next(join.TableName) - selectQuery = selectQuery.LeftJoin(fmt.Sprintf( + authAlias = sb.TableAlias(join.TableName) + sb.LeftJoin(fmt.Sprintf( "%s AS %s ON %s", join.TableName, authAlias, @@ -455,7 +461,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes for k, v := range authFilter { claimFilter[fmt.Sprintf("%s.%s", authAlias, k)] = v } - selectQuery.Where(claimFilter) + sb.Where(claimFilter) } } @@ -464,7 +470,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes return nil, err } - selectQuery.Limit(pageSize + 1) + sb.Limit(pageSize + 1) reqPage, ok := req.Get(ll.pageRequestField).Message().Interface().(*list_j5pb.PageRequest) if ok && reqPage != nil && reqPage.GetToken() != "" { @@ -484,7 +490,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes rhsPlaceholders := make([]string, 0, len(sortFields)) for _, sortField := range sortFields { - rowSelecter := sortField.Selector(tableAlias) + rowSelecter := sortField.Selector(sb.rootAlias) valuePlaceholder := "?" fieldVal, err := sortField.Path.GetValue(rowMessage) @@ -592,14 +598,14 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes // does not actually matter in which order the string field is sorted... // or don't because indexes. - selectQuery = selectQuery.Where( + sb.Where( fmt.Sprintf("(%s) >= (%s)", strings.Join(lhsFields, ","), strings.Join(rhsPlaceholders, ","), ), rhsValues...) } - return selectQuery, nil + return sb, nil } func (ll *Lister[REQ, RES]) getPageSize(req protoreflect.Message) (uint64, error) { diff --git a/pquery/select_build.go b/pquery/select_build.go index 97780e0..786de25 100644 --- a/pquery/select_build.go +++ b/pquery/select_build.go @@ -3,7 +3,10 @@ package pquery import ( "fmt" + sq "github.com/elgris/sqrl" "github.com/lib/pq" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -23,10 +26,105 @@ type SelectBuilder interface { TableAlias(tableName string) string } +type selectBuilder struct { + *sq.SelectBuilder + aliasSet *aliasSet + rootAlias string + columns []ColumnDest +} + +func newSelectBuilder(rootTable string) *selectBuilder { + as := newAliasSet() + rootAlias := as.Next(rootTable) + sb := sq.Select(). + From(fmt.Sprintf("%s AS %s", rootTable, rootAlias)) + + return &selectBuilder{ + SelectBuilder: sb, + aliasSet: as, + rootAlias: rootAlias, + } +} + +func (sb *selectBuilder) NewRow() ([]ScanDest, []interface{}) { + fields := make([]ScanDest, 0, len(sb.columns)) + rowCols := make([]interface{}, 0, len(sb.columns)) + for _, inQuery := range sb.columns { + colRow := inQuery.NewRow() + fields = append(fields, colRow) + rowCols = append(rowCols, colRow.ScanTo()) + } + return fields, rowCols +} + +func (sb *selectBuilder) Column(into ColumnDest, stmt string, args ...interface{}) { + sb.SelectBuilder.Column(stmt, args...) + sb.columns = append(sb.columns, into) +} + +func (sb *selectBuilder) LeftJoin(join string, rest ...interface{}) { + sb.SelectBuilder.LeftJoin(join, rest...) +} + +func (sb *selectBuilder) TableAlias(tableName string) string { + return sb.aliasSet.Next(tableName) +} + type ColumnSpec interface { ApplyQuery(parentAlias string, sb SelectBuilder) } +type jsonColumn struct { + sqlColumn string + field protoreflect.FieldDescriptor +} + +func newJsonColumn(sqlColumn string, protoField protoreflect.FieldDescriptor) jsonColumn { + return jsonColumn{ + sqlColumn: sqlColumn, + field: protoField, + } +} + +func (jc jsonColumn) ApplyQuery(tableAlias string, sb SelectBuilder) { + sb.Column(jc, fmt.Sprintf("%s.%s", tableAlias, jc.sqlColumn)) +} + +func (jc jsonColumn) NewRow() ScanDest { + return &jsonFieldRow{ + field: jc.field, + } +} + +// jsonFieldRow is a jsonb SQL field mapped to a proto field. +type jsonFieldRow struct { + field protoreflect.FieldDescriptor + data []byte +} + +func (jc *jsonFieldRow) ScanTo() interface{} { + return &jc.data +} + +func (jc *jsonFieldRow) Unmarshal(resReflect protoreflect.Message) error { + + if jc.data == nil { + return status.Error(codes.NotFound, "not found") + } + + msg := resReflect + if jc.field != nil { + msg = resReflect.Mutable(jc.field).Message() + resReflect.Set(jc.field, protoreflect.ValueOf(msg)) + } + + if err := protojson.Unmarshal(jc.data, msg.Interface()); err != nil { + return err + } + + return nil +} + type jsonArrayColumn struct { ArrayJoinSpec fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type From 5627a07144e82ce2e3cc51a0087247e9d5eb6e5e Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 16:21:17 -0800 Subject: [PATCH 6/7] nested path to select --- pquery/lister.go | 2 +- pquery/select_build.go | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pquery/lister.go b/pquery/lister.go index 771945b..3a89270 100644 --- a/pquery/lister.go +++ b/pquery/lister.go @@ -349,7 +349,7 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Message, res protoreflect.Message) (*selectBuilder, error) { sb := newSelectBuilder(ll.tableName) - root := newJsonColumn(ll.dataColumn, nil) + root := newJsonColumn(ll.dataColumn) root.ApplyQuery(sb.rootAlias, sb) sortFields := ll.defaultSortFields diff --git a/pquery/select_build.go b/pquery/select_build.go index 786de25..68b8666 100644 --- a/pquery/select_build.go +++ b/pquery/select_build.go @@ -76,13 +76,13 @@ type ColumnSpec interface { type jsonColumn struct { sqlColumn string - field protoreflect.FieldDescriptor + protoPath []protoreflect.FieldDescriptor } -func newJsonColumn(sqlColumn string, protoField protoreflect.FieldDescriptor) jsonColumn { +func newJsonColumn(sqlColumn string, protoPath ...protoreflect.FieldDescriptor) jsonColumn { return jsonColumn{ sqlColumn: sqlColumn, - field: protoField, + protoPath: protoPath, } } @@ -92,14 +92,14 @@ func (jc jsonColumn) ApplyQuery(tableAlias string, sb SelectBuilder) { func (jc jsonColumn) NewRow() ScanDest { return &jsonFieldRow{ - field: jc.field, + protoPath: jc.protoPath, } } // jsonFieldRow is a jsonb SQL field mapped to a proto field. type jsonFieldRow struct { - field protoreflect.FieldDescriptor - data []byte + protoPath []protoreflect.FieldDescriptor + data []byte } func (jc *jsonFieldRow) ScanTo() interface{} { @@ -113,9 +113,10 @@ func (jc *jsonFieldRow) Unmarshal(resReflect protoreflect.Message) error { } msg := resReflect - if jc.field != nil { - msg = resReflect.Mutable(jc.field).Message() - resReflect.Set(jc.field, protoreflect.ValueOf(msg)) + for _, field := range jc.protoPath { + child := resReflect.Mutable(field).Message() + msg.Set(field, protoreflect.ValueOf(child)) + msg = child } if err := protojson.Unmarshal(jc.data, msg.Interface()); err != nil { From 60050254a32e2fe3a1de0e22259465823a132d67 Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Fri, 6 Dec 2024 16:39:50 -0800 Subject: [PATCH 7/7] List methods to functions --- pquery/filter.go | 20 +++++++++++--------- pquery/lister.go | 36 ++++++++++++++++++++---------------- pquery/sort.go | 6 +++--- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/pquery/filter.go b/pquery/filter.go index 22cedf8..2fbc50c 100644 --- a/pquery/filter.go +++ b/pquery/filter.go @@ -288,32 +288,34 @@ func buildDefaultFilters(columnName string, message protoreflect.MessageDescript return filters, nil } -func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*list_j5pb.Filter) ([]sq.Sqlizer, error) { +func buildDynamicFilter(filterMessage protoreflect.MessageDescriptor, tableAlias string, dataColumn string, filters []*list_j5pb.Filter) ([]sq.Sqlizer, error) { out := []sq.Sqlizer{} for i := range filters { switch filters[i].GetType().(type) { case *list_j5pb.Filter_Field: pathSpec := pgstore.ParseJSONPathSpec(filters[i].GetField().GetName()) - spec, err := pgstore.NewJSONPath(ll.arrayField.Message(), pathSpec) + spec, err := pgstore.NewJSONPath(filterMessage, pathSpec) if err != nil { return nil, fmt.Errorf("dynamic filter: find field: %w", err) } + // TODO: get RootColumn from the shortest matching list column + biggerSpec := &pgstore.NestedField{ Path: *spec, - RootColumn: ll.dataColumn, + RootColumn: dataColumn, } var o sq.Sqlizer switch leaf := spec.Leaf().(type) { case protoreflect.OneofDescriptor: - o, err = ll.buildDynamicFilterOneof(tableAlias, biggerSpec, filters[i]) + o, err = buildDynamicFilterOneof(tableAlias, biggerSpec, filters[i]) if err != nil { return nil, fmt.Errorf("dynamic filter: build oneof: %w", err) } case protoreflect.FieldDescriptor: - o, err = ll.buildDynamicFilterField(tableAlias, biggerSpec, filters[i]) + o, err = buildDynamicFilterField(tableAlias, biggerSpec, filters[i]) if err != nil { return nil, fmt.Errorf("dynamic filter: build field: %w", err) } @@ -323,7 +325,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis out = append(out, o) case *list_j5pb.Filter_And: - f, err := ll.buildDynamicFilter(tableAlias, filters[i].GetAnd().GetFilters()) + f, err := buildDynamicFilter(filterMessage, tableAlias, dataColumn, filters[i].GetAnd().GetFilters()) if err != nil { return nil, fmt.Errorf("dynamic filter: and: %w", err) } @@ -332,7 +334,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis out = append(out, and) case *list_j5pb.Filter_Or: - f, err := ll.buildDynamicFilter(tableAlias, filters[i].GetOr().GetFilters()) + f, err := buildDynamicFilter(filterMessage, tableAlias, dataColumn, filters[i].GetOr().GetFilters()) if err != nil { return nil, fmt.Errorf("dynamic filter: or: %w", err) } @@ -346,7 +348,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis return out, nil } -func (ll *Lister[REQ, RES]) buildDynamicFilterField(tableAlias string, spec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { +func buildDynamicFilterField(tableAlias string, spec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { var out sq.And if filter.GetField() == nil { @@ -396,7 +398,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilterField(tableAlias string, spec *pgs return out, nil } -func (ll *Lister[REQ, RES]) buildDynamicFilterOneof(tableAlias string, ospec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { +func buildDynamicFilterOneof(tableAlias string, ospec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { var out sq.And if filter.GetField() == nil { diff --git a/pquery/lister.go b/pquery/lister.go index 3a89270..f9a18a4 100644 --- a/pquery/lister.go +++ b/pquery/lister.go @@ -72,10 +72,7 @@ type ListReflectionSet struct { tsvColumnMap map[string]string - // TODO: This should be an array/map of columns to data types, allowing - // multiple JSONB values, as well as cached field values direcrly on the - // table - dataColumn string + columns []ColumnSpec } func BuildListReflection(req protoreflect.MessageDescriptor, res protoreflect.MessageDescriptor, table TableSpec) (*ListReflectionSet, error) { @@ -86,10 +83,13 @@ func buildListReflection(req protoreflect.MessageDescriptor, res protoreflect.Me var err error ll := ListReflectionSet{ defaultPageSize: uint64(20), - dataColumn: table.DataColumn, } fields := res.Fields() + dataColumn := table.DataColumn + rootColumn := newJsonColumn(dataColumn) + ll.columns = append(ll.columns, rootColumn) + for i := 0; i < fields.Len(); i++ { field := fields.Get(i) msg := field.Message() @@ -127,12 +127,12 @@ func buildListReflection(req protoreflect.MessageDescriptor, res protoreflect.Me return nil, fmt.Errorf("validate list annotations on %s: %w", ll.arrayField.Message().FullName(), err) } - ll.defaultSortFields, err = buildDefaultSorts(ll.dataColumn, ll.arrayField.Message()) + ll.defaultSortFields, err = buildDefaultSorts(dataColumn, ll.arrayField.Message()) if err != nil { return nil, fmt.Errorf("default sorts: %w", err) } - ll.tieBreakerFields, err = buildTieBreakerFields(ll.dataColumn, req, ll.arrayField.Message(), table.FallbackSortColumns) + ll.tieBreakerFields, err = buildTieBreakerFields(dataColumn, req, ll.arrayField.Message(), table.FallbackSortColumns) if err != nil { return nil, fmt.Errorf("tie breaker fields: %w", err) } @@ -141,7 +141,7 @@ func buildListReflection(req protoreflect.MessageDescriptor, res protoreflect.Me return nil, fmt.Errorf("no default sort field found, %s must have at least one field annotated as default sort, or specify a tie breaker in %s", ll.arrayField.Message().FullName(), req.FullName()) } - f, err := buildDefaultFilters(ll.dataColumn, ll.arrayField.Message()) + f, err := buildDefaultFilters(dataColumn, ll.arrayField.Message()) if err != nil { return nil, fmt.Errorf("default filters: %w", err) } @@ -210,6 +210,8 @@ type Lister[REQ ListRequest, RES ListResponse] struct { requestFilter func(REQ) (map[string]interface{}, error) validator *protovalidate.Validator + + rootStateColumn string } func NewLister[ @@ -217,9 +219,10 @@ func NewLister[ RES ListResponse, ](spec ListSpec[REQ, RES]) (*Lister[REQ, RES], error) { ll := &Lister[REQ, RES]{ - tableName: spec.TableName, - auth: spec.Auth, - authJoin: spec.AuthJoin, + tableName: spec.TableName, + auth: spec.Auth, + authJoin: spec.AuthJoin, + rootStateColumn: spec.DataColumn, } descriptors := newMethodDescriptor[REQ, RES]() @@ -349,8 +352,9 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Message, res protoreflect.Message) (*selectBuilder, error) { sb := newSelectBuilder(ll.tableName) - root := newJsonColumn(ll.dataColumn) - root.ApplyQuery(sb.rootAlias, sb) + for _, colSpec := range ll.columns { + colSpec.ApplyQuery(sb.rootAlias, sb) + } sortFields := ll.defaultSortFields sortFields = append(sortFields, ll.tieBreakerFields...) @@ -380,7 +384,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes querySorts := reqQuery.GetSorts() if len(querySorts) > 0 { - dynSorts, err := ll.buildDynamicSortSpec(querySorts) + dynSorts, err := buildDynamicSortSpec(ll.arrayField.Message(), ll.rootStateColumn, querySorts) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build sorts: %s", err) } @@ -390,7 +394,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes queryFilters := reqQuery.GetFilters() if len(queryFilters) > 0 { - dynFilters, err := ll.buildDynamicFilter(sb.rootAlias, queryFilters) + dynFilters, err := buildDynamicFilter(ll.arrayField.Message(), sb.rootAlias, ll.rootStateColumn, queryFilters) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build filters: %s", err) } @@ -419,7 +423,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes for _, spec := range ll.defaultFilterFields { or := sq.Or{} for _, val := range spec.filterVals { - or = append(or, sq.Expr(fmt.Sprintf("jsonb_path_query_array(%s.%s, '%s') @> ?", sb.rootAlias, ll.dataColumn, spec.Path.JSONPathQuery()), pg.JSONB(val))) + or = append(or, sq.Expr(fmt.Sprintf("jsonb_path_query_array(%s.%s, '%s') @> ?", sb.rootAlias, spec.RootColumn, spec.Path.JSONPathQuery()), pg.JSONB(val))) } and = append(and, or) diff --git a/pquery/sort.go b/pquery/sort.go index f7bb30c..82fa36c 100644 --- a/pquery/sort.go +++ b/pquery/sort.go @@ -162,19 +162,19 @@ func buildDefaultSorts(columnName string, message protoreflect.MessageDescriptor return defaultSortFields, nil } -func (ll *Lister[REQ, RES]) buildDynamicSortSpec(sorts []*list_j5pb.Sort) ([]sortSpec, error) { +func buildDynamicSortSpec(rootMessage protoreflect.MessageDescriptor, rootColumn string, sorts []*list_j5pb.Sort) ([]sortSpec, error) { results := []sortSpec{} direction := "" for _, sort := range sorts { pathSpec := pgstore.ParseJSONPathSpec(sort.Field) - spec, err := pgstore.NewJSONPath(ll.arrayField.Message(), pathSpec) + spec, err := pgstore.NewJSONPath(rootMessage, pathSpec) if err != nil { return nil, fmt.Errorf("dynamic filter: find field: %w", err) } biggerSpec := &pgstore.NestedField{ Path: *spec, - RootColumn: ll.dataColumn, + RootColumn: rootColumn, } results = append(results, sortSpec{