diff --git a/internal/integration/pagination_test.go b/internal/integration/pagination_test.go index 2b56edc..911ad5f 100644 --- a/internal/integration/pagination_test.go +++ b/internal/integration/pagination_test.go @@ -106,7 +106,7 @@ func TestPagination(t *testing.T) { if err != nil { t.Fatal(err.Error()) } - printQuery(t, query) + printQuery(t, query.SelectBuilder) err = queryer.List(ctx, db, req, res) if err != nil { @@ -240,7 +240,7 @@ func TestEventPagination(t *testing.T) { if err != nil { t.Fatal(err.Error()) } - printQuery(t, query) + printQuery(t, query.SelectBuilder) err = queryer.ListEvents(ctx, db, req, res) if err != nil { diff --git a/pquery/filter.go b/pquery/filter.go index 22cedf8..2fbc50c 100644 --- a/pquery/filter.go +++ b/pquery/filter.go @@ -288,32 +288,34 @@ func buildDefaultFilters(columnName string, message protoreflect.MessageDescript return filters, nil } -func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*list_j5pb.Filter) ([]sq.Sqlizer, error) { +func buildDynamicFilter(filterMessage protoreflect.MessageDescriptor, tableAlias string, dataColumn string, filters []*list_j5pb.Filter) ([]sq.Sqlizer, error) { out := []sq.Sqlizer{} for i := range filters { switch filters[i].GetType().(type) { case *list_j5pb.Filter_Field: pathSpec := pgstore.ParseJSONPathSpec(filters[i].GetField().GetName()) - spec, err := pgstore.NewJSONPath(ll.arrayField.Message(), pathSpec) + spec, err := pgstore.NewJSONPath(filterMessage, pathSpec) if err != nil { return nil, fmt.Errorf("dynamic filter: find field: %w", err) } + // TODO: get RootColumn from the shortest matching list column + biggerSpec := &pgstore.NestedField{ Path: *spec, - RootColumn: ll.dataColumn, + RootColumn: dataColumn, } var o sq.Sqlizer switch leaf := spec.Leaf().(type) { case protoreflect.OneofDescriptor: - o, err = ll.buildDynamicFilterOneof(tableAlias, biggerSpec, filters[i]) + o, err = buildDynamicFilterOneof(tableAlias, biggerSpec, filters[i]) if err != nil { return nil, fmt.Errorf("dynamic filter: build oneof: %w", err) } case protoreflect.FieldDescriptor: - o, err = ll.buildDynamicFilterField(tableAlias, biggerSpec, filters[i]) + o, err = buildDynamicFilterField(tableAlias, biggerSpec, filters[i]) if err != nil { return nil, fmt.Errorf("dynamic filter: build field: %w", err) } @@ -323,7 +325,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis out = append(out, o) case *list_j5pb.Filter_And: - f, err := ll.buildDynamicFilter(tableAlias, filters[i].GetAnd().GetFilters()) + f, err := buildDynamicFilter(filterMessage, tableAlias, dataColumn, filters[i].GetAnd().GetFilters()) if err != nil { return nil, fmt.Errorf("dynamic filter: and: %w", err) } @@ -332,7 +334,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis out = append(out, and) case *list_j5pb.Filter_Or: - f, err := ll.buildDynamicFilter(tableAlias, filters[i].GetOr().GetFilters()) + f, err := buildDynamicFilter(filterMessage, tableAlias, dataColumn, filters[i].GetOr().GetFilters()) if err != nil { return nil, fmt.Errorf("dynamic filter: or: %w", err) } @@ -346,7 +348,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis return out, nil } -func (ll *Lister[REQ, RES]) buildDynamicFilterField(tableAlias string, spec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { +func buildDynamicFilterField(tableAlias string, spec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { var out sq.And if filter.GetField() == nil { @@ -396,7 +398,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilterField(tableAlias string, spec *pgs return out, nil } -func (ll *Lister[REQ, RES]) buildDynamicFilterOneof(tableAlias string, ospec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { +func buildDynamicFilterOneof(tableAlias string, ospec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) { var out sq.And if filter.GetField() == nil { diff --git a/pquery/getter.go b/pquery/getter.go index cdb9d09..068e285 100644 --- a/pquery/getter.go +++ b/pquery/getter.go @@ -8,13 +8,10 @@ import ( "strings" "github.com/bufbuild/protovalidate-go" - sq "github.com/elgris/sqrl" - "github.com/lib/pq" "github.com/pentops/protostate/internal/dbconvert" "github.com/pentops/sqrlx.go/sqrlx" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -34,13 +31,16 @@ type GetSpec[ TableName string DataColumn string Auth AuthProvider - AuthJoin []*LeftJoin + AuthJoin []*KeyJoin PrimaryKey func(REQ) (map[string]interface{}, error) StateResponseField protoreflect.Name - Join *GetJoinSpec + ArrayJoin *ArrayJoinSpec + + // ResponseDescriptor must describe the RES message, defaults to new(RES).ProtoReflect().Descriptor() + ResponseDescriptor protoreflect.MessageDescriptor } // JoinConstraint defines a @@ -77,14 +77,14 @@ func (jc JoinFields) SQL(rootAlias string, joinAlias string) string { return strings.Join(conditions, " AND ") } -type GetJoinSpec struct { +type ArrayJoinSpec struct { TableName string DataColumn string On JoinFields FieldInParent protoreflect.Name } -func (gc GetJoinSpec) validate() error { +func (gc ArrayJoinSpec) validate() error { if gc.TableName == "" { return fmt.Errorf("missing TableName") } @@ -102,37 +102,29 @@ type Getter[ REQ GetRequest, RES proto.Message, ] struct { - stateField protoreflect.FieldDescriptor - - dataColumn string tableName string primaryKey func(REQ) (map[string]interface{}, error) auth AuthProvider - authJoin []*LeftJoin + authJoin []*KeyJoin queryLogger QueryLogger validator *protovalidate.Validator - join *getJoin -} - -type getJoin struct { - dataColumn string - tableName string - fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type - on JoinFields + columns []ColumnSpec } func NewGetter[ REQ GetRequest, RES GetResponse, ](spec GetSpec[REQ, RES]) (*Getter[REQ, RES], error) { - descriptors := newMethodDescriptor[REQ, RES]() - resDesc := descriptors.response + resDesc := spec.ResponseDescriptor + if resDesc == nil { + descriptors := newMethodDescriptor[REQ, RES]() + resDesc = descriptors.response + } sc := &Getter[REQ, RES]{ - dataColumn: spec.DataColumn, tableName: spec.TableName, primaryKey: spec.PrimaryKey, auth: spec.Auth, @@ -145,39 +137,46 @@ func NewGetter[ defaultState = true spec.StateResponseField = protoreflect.Name("state") } - sc.stateField = resDesc.Fields().ByName(spec.StateResponseField) - if sc.stateField == nil { + + stateField := resDesc.Fields().ByName(spec.StateResponseField) + if stateField == nil { if defaultState { - return nil, fmt.Errorf("no 'state' field in proto message - did you mean to override StateResponseField?") + return nil, fmt.Errorf("no 'state' field in proto message, StateResponseField is left blank") } return nil, fmt.Errorf("no '%s' field in proto message", spec.StateResponseField) } + sc.columns = append(sc.columns, newJsonColumn(spec.DataColumn, stateField)) + + if spec.DataColumn == "" { + return nil, fmt.Errorf("GetSpec missing DataColumn") + } if spec.PrimaryKey == nil { - return nil, fmt.Errorf("missing PrimaryKey func") + return nil, fmt.Errorf("GetSpec missing PrimaryKey function") } - if spec.Join != nil { + if spec.TableName == "" { + return nil, fmt.Errorf("GetSpec missing TableName") + } - if err := spec.Join.validate(); err != nil { + if spec.ArrayJoin != nil { + if err := spec.ArrayJoin.validate(); err != nil { return nil, fmt.Errorf("invalid join spec: %w", err) } - joinField := resDesc.Fields().ByName(protoreflect.Name(spec.Join.FieldInParent)) + joinField := resDesc.Fields().ByName(protoreflect.Name(spec.ArrayJoin.FieldInParent)) if joinField == nil { - return nil, fmt.Errorf("field %s not found in response message", spec.Join.FieldInParent) + return nil, fmt.Errorf("field %s not found in response message", spec.ArrayJoin.FieldInParent) } if !joinField.IsList() { - return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.Join.FieldInParent) + return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.ArrayJoin.FieldInParent) } - sc.join = &getJoin{ - tableName: spec.Join.TableName, - dataColumn: spec.Join.DataColumn, + sc.columns = append(sc.columns, &jsonArrayColumn{ + ArrayJoinSpec: *spec.ArrayJoin, fieldInParent: joinField, - on: spec.Join.On, - } + }) } var err error @@ -195,8 +194,7 @@ func (gc *Getter[REQ, RES]) SetQueryLogger(logger QueryLogger) { func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, resMsg RES) error { - as := newAliasSet() - rootAlias := as.Next(gc.tableName) + sb := newSelectBuilder(gc.tableName) resReflect := resMsg.ProtoReflect() @@ -209,27 +207,26 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, return err } - rootFilter, err := dbconvert.FieldsToEqMap(rootAlias, primaryKeyFields) + if len(primaryKeyFields) == 0 { + return fmt.Errorf("PrimaryKey() returned no fields") + } + + rootFilter, err := dbconvert.FieldsToEqMap(sb.rootAlias, primaryKeyFields) if err != nil { return err } - - selectQuery := sq. - Select(). - Column(fmt.Sprintf("%s.%s", rootAlias, gc.dataColumn)). - From(fmt.Sprintf("%s AS %s", gc.tableName, rootAlias)). - Where(rootFilter) + sb.Where(rootFilter) for pkField := range rootFilter { - selectQuery.GroupBy(pkField) + sb.GroupBy(pkField) } if gc.auth != nil { - authAlias := rootAlias + authAlias := sb.rootAlias for _, join := range gc.authJoin { priorAlias := authAlias - authAlias = as.Next(join.TableName) - selectQuery = selectQuery.LeftJoin(fmt.Sprintf( + authAlias = sb.TableAlias(join.TableName) + sb.LeftJoin(fmt.Sprintf( "%s AS %s ON %s", join.TableName, authAlias, @@ -247,28 +244,18 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, for k, v := range authFilter { claimFilter[fmt.Sprintf("%s.%s", authAlias, k)] = v } - selectQuery.Where(claimFilter) + sb.Where(claimFilter) } } - if gc.join != nil { - joinAlias := as.Next(gc.join.tableName) - - selectQuery. - Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, gc.join.dataColumn)). - LeftJoin(fmt.Sprintf( - "%s AS %s ON %s", - gc.join.tableName, - joinAlias, - gc.join.on.SQL(rootAlias, joinAlias), - )) + for _, join := range gc.columns { + join.ApplyQuery(sb.rootAlias, sb) } - var foundJSON []byte - var joinedJSON pq.StringArray + fields, scanCols := sb.NewRow() if gc.queryLogger != nil { - gc.queryLogger(selectQuery) + gc.queryLogger(sb.SelectBuilder) } if err := db.Transact(ctx, &sqrlx.TxOptions{ @@ -276,14 +263,9 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, Retryable: true, Isolation: sql.LevelReadCommitted, }, func(ctx context.Context, tx sqrlx.Transaction) error { - row := tx.SelectRow(ctx, selectQuery) + row := tx.SelectRow(ctx, sb.SelectBuilder) - var err error - if gc.join != nil { - err = row.Scan(&foundJSON, &joinedJSON) - } else { - err = row.Scan(&foundJSON) - } + err := row.Scan(scanCols...) if err != nil { if errors.Is(err, sql.ErrNoRows) { var pkDescription string @@ -306,34 +288,14 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, return nil }); err != nil { - query, _, _ := selectQuery.ToSql() + query, _, _ := sb.ToSql() return fmt.Errorf("%s: %w", query, err) } - if foundJSON == nil { - return status.Error(codes.NotFound, "not found") - } - - stateMsg := resReflect.NewField(gc.stateField) - if err := protojson.Unmarshal(foundJSON, stateMsg.Message().Interface()); err != nil { - return err - } - resReflect.Set(gc.stateField, stateMsg) - - if gc.join != nil { - elementList := resReflect.Mutable(gc.join.fieldInParent).List() - for _, eventBytes := range joinedJSON { - if eventBytes == "" { - continue - } - - rowMessage := elementList.NewElement().Message() - if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { - return fmt.Errorf("joined unmarshal: %w", err) - } - elementList.Append(protoreflect.ValueOf(rowMessage)) + for _, field := range fields { + if err := field.Unmarshal(resReflect); err != nil { + return err } - } return nil diff --git a/pquery/getter_test.go b/pquery/getter_test.go new file mode 100644 index 0000000..c703cf8 --- /dev/null +++ b/pquery/getter_test.go @@ -0,0 +1,230 @@ +package pquery + +import ( + "context" + "database/sql" + "strings" + "testing" + + "github.com/pentops/pgtest.go/pgtest" + "github.com/pentops/sqrlx.go/sqrlx" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" +) + +func TestGetter(t *testing.T) { + + fileDescriptor := &descriptorpb.FileDescriptorProto{ + Name: proto.String("test.proto"), + Package: proto.String("test"), + MessageType: []*descriptorpb.DescriptorProto{{ + Name: proto.String("TestRequest"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("foo_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + }}, + }, { + Name: proto.String("TestResponse"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("state"), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".test.Foo"), + Number: proto.Int32(1), + }, { + Name: proto.String("bars"), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".test.Bar"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), + }}, + }, { + Name: proto.String("Foo"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("foo_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + }, { + Name: proto.String("name"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(2), + }}, + }, { + Name: proto.String("Bar"), + Field: []*descriptorpb.FieldDescriptorProto{{ + Name: proto.String("bar_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + }, { + Name: proto.String("foo_id"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(2), + }, { + Name: proto.String("name"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(3), + }}, + }}, + } + + file, err := protodesc.NewFile(fileDescriptor, protoregistry.GlobalFiles) + if err != nil { + t.Fatalf("Compiling test proto: %s", err.Error()) + } + + reqDesc := file.Messages().ByName("TestRequest") + if reqDesc == nil { + t.Fatal("reqDesc is nil") + } + + resDesc := file.Messages().ByName("TestResponse") + if resDesc == nil { + t.Fatal("resDesc is nil") + } + + spec := GetSpec[*tDynamicMessage, *tDynamicMessage]{ + ResponseDescriptor: resDesc, + DataColumn: "state", + TableName: "foo", + PrimaryKey: func(req *tDynamicMessage) (map[string]interface{}, error) { + return map[string]interface{}{ + "foo_id": req.GetField(t, "foo_id").String(), + }, nil + }, + + ArrayJoin: &ArrayJoinSpec{ + TableName: "bar", + FieldInParent: "bars", + DataColumn: "state", + On: []JoinField{{ + JoinColumn: "foo_id", + RootColumn: "foo_id", + }}, + }, + } + + gg, err := NewGetter(spec) + if err != nil { + t.Fatal(err.Error()) + } + gg.SetQueryLogger(testLog(t)) + + ctx := context.Background() + conn := pgtest.GetTestDB(t, pgtest.WithSchemaName("pquery")) + + db := sqrlx.NewPostgres(conn) + + execAll(t, conn, + "CREATE TABLE foo (foo_id text PRIMARY KEY, state jsonb)", + `INSERT INTO foo (foo_id, state) VALUES + ( 'id0', '{"foo_id": "id0", "name": "foo0"}'), + ( 'id1', '{"foo_id": "id1", "name": "foo1"}') + `, + + `CREATE TABLE bar (bar_id text PRIMARY KEY, foo_id text REFERENCES foo(foo_id), state jsonb)`, + `INSERT INTO bar (bar_id, foo_id, state) VALUES + ( 'bar0', 'id0', '{"bar_id": "bar0", "foo_id": "id0", "name": "bar0"}'), + ( 'bar1', 'id0', '{"bar_id": "bar1", "foo_id": "id1", "name": "bar1"}') + `, + ) + + reqMsg := newDynamicMessage(reqDesc) + reqMsg.SetField(t, "foo_id", protoreflect.ValueOf("id0")) + + resMsg := newDynamicMessage(resDesc) + err = gg.Get(ctx, db, reqMsg, resMsg) + if err != nil { + t.Fatal(err.Error()) + } + + resMsg.AssertField(t, "state.foo_id", protoreflect.ValueOf("id0")) + resMsg.AssertField(t, "state.name", protoreflect.ValueOf("foo0")) + + bars := resMsg.GetField(t, "bars").List() + if bars.Len() != 2 { + t.Fatalf("expected 2 bars, got %d", bars.Len()) + } + +} + +func newDynamicMessage(md protoreflect.MessageDescriptor) *tDynamicMessage { + return &tDynamicMessage{dynamicpb.NewMessage(md)} +} + +func testLog(t *testing.T) QueryLogger { + return func(qq sqrlx.Sqlizer) { + stmt, args, err := qq.ToSql() + if err != nil { + t.Fatal(err.Error()) + } + t.Logf("Query: %s, args: %v", stmt, args) + + } +} + +type tDynamicMessage struct { + *dynamicpb.Message +} + +func (dm *tDynamicMessage) GetField(t *testing.T, name string) protoreflect.Value { + path := strings.Split(name, ".") + pathTo, last := path[:len(path)-1], path[len(path)-1] + var msg protoreflect.Message = dm.Message + for _, p := range pathTo { + field := msg.Get(msg.Descriptor().Fields().ByName(protoreflect.Name(p))) + if !field.IsValid() { + t.Errorf("field %s is not valid", p) + } + msg = field.Message() + if msg == nil { + t.Fatalf("field %s: msg is nil", p) + } + } + fieldDef := msg.Descriptor().Fields().ByName(protoreflect.Name(last)) + if fieldDef == nil { + t.Fatalf("field %s: no such field", name) + } + field := msg.Get(fieldDef) + if !field.IsValid() { + t.Errorf("field %s is not valid", name) + } + return field +} + +func (dm *tDynamicMessage) SetField(t *testing.T, name string, value protoreflect.Value) { + path := strings.Split(name, ".") + pathTo, last := path[:len(path)-1], path[len(path)-1] + var msg protoreflect.Message = dm.Message + for _, p := range pathTo { + field := msg.Get(msg.Descriptor().Fields().ByName(protoreflect.Name(p))) + if field.IsValid() { + msg = field.Message() + continue + } + field = protoreflect.ValueOf(msg.NewField(msg.Descriptor().Fields().ByName(protoreflect.Name(p)))) + msg.Set(msg.Descriptor().Fields().ByName(protoreflect.Name(p)), field) + t.Errorf("field %s is not valid", p) + } + + msg.Set(msg.Descriptor().Fields().ByName(protoreflect.Name(last)), value) +} + +func (dm *tDynamicMessage) AssertField(t *testing.T, name string, value protoreflect.Value) { + field := dm.GetField(t, name) + if !field.Equal(value) { + t.Errorf("expected %v, got %v", value, field) + } +} + +func execAll(t *testing.T, conn *sql.DB, queries ...string) { + for _, query := range queries { + _, err := conn.Exec(query) + if err != nil { + t.Fatal(err.Error()) + } + } +} diff --git a/pquery/lister.go b/pquery/lister.go index f04e144..f9a18a4 100644 --- a/pquery/lister.go +++ b/pquery/lister.go @@ -12,7 +12,6 @@ import ( "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "github.com/bufbuild/protovalidate-go" - "github.com/elgris/sqrl" sq "github.com/elgris/sqrl" "github.com/elgris/sqrl/pg" "github.com/pentops/j5/gen/j5/list/v1/list_j5pb" @@ -21,7 +20,6 @@ import ( "github.com/pentops/sqrlx.go/sqrlx" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" @@ -43,7 +41,7 @@ type TableSpec struct { TableName string Auth AuthProvider - AuthJoin []*LeftJoin + AuthJoin []*KeyJoin DataColumn string // TODO: Replace with array Columns []Column @@ -51,14 +49,6 @@ type TableSpec struct { FallbackSortColumns []pgstore.ProtoFieldSpec } -type Column struct { - Name string - - // The point within the root element which is stored in the column. An empty - // path means this stores the root element, - MountPoint *pgstore.Path -} - type ListSpec[REQ ListRequest, RES ListResponse] struct { TableSpec RequestFilter func(REQ) (map[string]interface{}, error) @@ -82,10 +72,7 @@ type ListReflectionSet struct { tsvColumnMap map[string]string - // TODO: This should be an array/map of columns to data types, allowing - // multiple JSONB values, as well as cached field values direcrly on the - // table - dataColumn string + columns []ColumnSpec } func BuildListReflection(req protoreflect.MessageDescriptor, res protoreflect.MessageDescriptor, table TableSpec) (*ListReflectionSet, error) { @@ -96,10 +83,13 @@ func buildListReflection(req protoreflect.MessageDescriptor, res protoreflect.Me var err error ll := ListReflectionSet{ defaultPageSize: uint64(20), - dataColumn: table.DataColumn, } fields := res.Fields() + dataColumn := table.DataColumn + rootColumn := newJsonColumn(dataColumn) + ll.columns = append(ll.columns, rootColumn) + for i := 0; i < fields.Len(); i++ { field := fields.Get(i) msg := field.Message() @@ -137,12 +127,12 @@ func buildListReflection(req protoreflect.MessageDescriptor, res protoreflect.Me return nil, fmt.Errorf("validate list annotations on %s: %w", ll.arrayField.Message().FullName(), err) } - ll.defaultSortFields, err = buildDefaultSorts(ll.dataColumn, ll.arrayField.Message()) + ll.defaultSortFields, err = buildDefaultSorts(dataColumn, ll.arrayField.Message()) if err != nil { return nil, fmt.Errorf("default sorts: %w", err) } - ll.tieBreakerFields, err = buildTieBreakerFields(ll.dataColumn, req, ll.arrayField.Message(), table.FallbackSortColumns) + ll.tieBreakerFields, err = buildTieBreakerFields(dataColumn, req, ll.arrayField.Message(), table.FallbackSortColumns) if err != nil { return nil, fmt.Errorf("tie breaker fields: %w", err) } @@ -151,7 +141,7 @@ func buildListReflection(req protoreflect.MessageDescriptor, res protoreflect.Me return nil, fmt.Errorf("no default sort field found, %s must have at least one field annotated as default sort, or specify a tie breaker in %s", ll.arrayField.Message().FullName(), req.FullName()) } - f, err := buildDefaultFilters(ll.dataColumn, ll.arrayField.Message()) + f, err := buildDefaultFilters(dataColumn, ll.arrayField.Message()) if err != nil { return nil, fmt.Errorf("default filters: %w", err) } @@ -215,11 +205,13 @@ type Lister[REQ ListRequest, RES ListResponse] struct { queryLogger QueryLogger auth AuthProvider - authJoin []*LeftJoin + authJoin []*KeyJoin requestFilter func(REQ) (map[string]interface{}, error) validator *protovalidate.Validator + + rootStateColumn string } func NewLister[ @@ -227,9 +219,10 @@ func NewLister[ RES ListResponse, ](spec ListSpec[REQ, RES]) (*Lister[REQ, RES], error) { ll := &Lister[REQ, RES]{ - tableName: spec.TableName, - auth: spec.Auth, - authJoin: spec.AuthJoin, + tableName: spec.TableName, + auth: spec.Auth, + authJoin: spec.AuthJoin, + rootStateColumn: spec.DataColumn, } descriptors := newMethodDescriptor[REQ, RES]() @@ -254,6 +247,10 @@ func (ll *Lister[REQ, RES]) SetQueryLogger(logger QueryLogger) { ll.queryLogger = logger } +type listRow struct { + columns []ScanDest +} + func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg proto.Message, resMsg proto.Message) error { if err := ll.validator.Validate(reqMsg); err != nil { return fmt.Errorf("validating request %s: %w", reqMsg.ProtoReflect().Descriptor().FullName(), err) @@ -267,7 +264,7 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot return fmt.Errorf("get page size: %w", err) } - selectQuery, err := ll.BuildQuery(ctx, req, res) + sb, err := ll.BuildQuery(ctx, req, res) if err != nil { return fmt.Errorf("build query: %w", err) } @@ -278,45 +275,50 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot Isolation: sql.LevelReadCommitted, } - var jsonRows = make([][]byte, 0, pageSize) + listRows := make([]listRow, 0, pageSize) + err = db.Transact(ctx, txOpts, func(ctx context.Context, tx sqrlx.Transaction) error { - rows, err := tx.Query(ctx, selectQuery) + rows, err := tx.Query(ctx, sb.SelectBuilder) if err != nil { return fmt.Errorf("run select: %w", err) } defer rows.Close() for rows.Next() { - var json []byte - if err := rows.Scan(&json); err != nil { + + fields, scanCols := sb.NewRow() + row := listRow{columns: fields} + listRows = append(listRows, row) + + if err := rows.Scan(scanCols...); err != nil { return fmt.Errorf("row scan: %w", err) } - jsonRows = append(jsonRows, json) } return rows.Err() }) if err != nil { - stmt, _, _ := selectQuery.ToSql() + stmt, _, _ := sb.SelectBuilder.ToSql() log.WithField(ctx, "query", stmt).Error("list query") return fmt.Errorf("list query: %w", err) } if ll.queryLogger != nil { - ll.queryLogger(selectQuery) + ll.queryLogger(sb.SelectBuilder) } list := res.Mutable(ll.arrayField).List() res.Set(ll.arrayField, protoreflect.ValueOf(list)) var nextToken string - for idx, rowBytes := range jsonRows { + for idx, rowBytes := range listRows { rowMessage := list.NewElement().Message() - err := protojson.Unmarshal(rowBytes, rowMessage.Interface()) - if err != nil { - return fmt.Errorf("unmarshal into %s from %s: %w", rowMessage.Descriptor().FullName(), string(rowBytes), err) + for i, col := range rowBytes.columns { + if err := col.Unmarshal(rowMessage); err != nil { + return fmt.Errorf("unmarshal column %d: %w", i, err) + } } if idx >= int(pageSize) { @@ -347,12 +349,12 @@ func (ll *Lister[REQ, RES]) List(ctx context.Context, db Transactor, reqMsg prot return nil } -func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Message, res protoreflect.Message) (*sqrl.SelectBuilder, error) { - as := newAliasSet() - tableAlias := as.Next(ll.tableName) +func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Message, res protoreflect.Message) (*selectBuilder, error) { - selectQuery := sq.Select(fmt.Sprintf("%s.%s", tableAlias, ll.dataColumn)). - From(fmt.Sprintf("%s AS %s", ll.tableName, tableAlias)) + sb := newSelectBuilder(ll.tableName) + for _, colSpec := range ll.columns { + colSpec.ApplyQuery(sb.rootAlias, sb) + } sortFields := ll.defaultSortFields sortFields = append(sortFields, ll.tieBreakerFields...) @@ -366,11 +368,11 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes and := sq.And{} for k := range filter { - and = append(and, sq.Expr(fmt.Sprintf("%s.%s = ?", tableAlias, k), filter[k])) + and = append(and, sq.Expr(fmt.Sprintf("%s.%s = ?", sb.rootAlias, k), filter[k])) } if len(and) > 0 { - selectQuery.Where(and) + sb.Where(and) } } @@ -382,7 +384,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes querySorts := reqQuery.GetSorts() if len(querySorts) > 0 { - dynSorts, err := ll.buildDynamicSortSpec(querySorts) + dynSorts, err := buildDynamicSortSpec(ll.arrayField.Message(), ll.rootStateColumn, querySorts) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build sorts: %s", err) } @@ -392,7 +394,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes queryFilters := reqQuery.GetFilters() if len(queryFilters) > 0 { - dynFilters, err := ll.buildDynamicFilter(tableAlias, queryFilters) + dynFilters, err := buildDynamicFilter(ll.arrayField.Message(), sb.rootAlias, ll.rootStateColumn, queryFilters) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build filters: %s", err) } @@ -402,7 +404,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes querySearches := reqQuery.GetSearches() if len(querySearches) > 0 { - searchFilters, err := ll.buildDynamicSearches(tableAlias, querySearches) + searchFilters, err := ll.buildDynamicSearches(sb.rootAlias, querySearches) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "build searches: %s", err) } @@ -412,7 +414,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes } for i := range filterFields { - selectQuery.Where(filterFields[i]) + sb.Where(filterFields[i]) } // apply default filters if no filters have been requested @@ -421,14 +423,14 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes for _, spec := range ll.defaultFilterFields { or := sq.Or{} for _, val := range spec.filterVals { - or = append(or, sq.Expr(fmt.Sprintf("jsonb_path_query_array(%s.%s, '%s') @> ?", tableAlias, ll.dataColumn, spec.Path.JSONPathQuery()), pg.JSONB(val))) + or = append(or, sq.Expr(fmt.Sprintf("jsonb_path_query_array(%s.%s, '%s') @> ?", sb.rootAlias, spec.RootColumn, spec.Path.JSONPathQuery()), pg.JSONB(val))) } and = append(and, or) } if len(and) > 0 { - selectQuery.Where(and) + sb.Where(and) } } @@ -437,15 +439,15 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes if sortField.desc { direction = "DESC" } - selectQuery.OrderBy(fmt.Sprintf("%s %s", sortField.Selector(tableAlias), direction)) + sb.OrderBy(fmt.Sprintf("%s %s", sortField.Selector(sb.rootAlias), direction)) } if ll.auth != nil { - authAlias := tableAlias + authAlias := sb.rootAlias for _, join := range ll.authJoin { priorAlias := authAlias - authAlias = as.Next(join.TableName) - selectQuery = selectQuery.LeftJoin(fmt.Sprintf( + authAlias = sb.TableAlias(join.TableName) + sb.LeftJoin(fmt.Sprintf( "%s AS %s ON %s", join.TableName, authAlias, @@ -463,7 +465,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes for k, v := range authFilter { claimFilter[fmt.Sprintf("%s.%s", authAlias, k)] = v } - selectQuery.Where(claimFilter) + sb.Where(claimFilter) } } @@ -472,7 +474,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes return nil, err } - selectQuery.Limit(pageSize + 1) + sb.Limit(pageSize + 1) reqPage, ok := req.Get(ll.pageRequestField).Message().Interface().(*list_j5pb.PageRequest) if ok && reqPage != nil && reqPage.GetToken() != "" { @@ -492,7 +494,7 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes rhsPlaceholders := make([]string, 0, len(sortFields)) for _, sortField := range sortFields { - rowSelecter := sortField.Selector(tableAlias) + rowSelecter := sortField.Selector(sb.rootAlias) valuePlaceholder := "?" fieldVal, err := sortField.Path.GetValue(rowMessage) @@ -600,14 +602,14 @@ func (ll *Lister[REQ, RES]) BuildQuery(ctx context.Context, req protoreflect.Mes // does not actually matter in which order the string field is sorted... // or don't because indexes. - selectQuery = selectQuery.Where( + sb.Where( fmt.Sprintf("(%s) >= (%s)", strings.Join(lhsFields, ","), strings.Join(rhsPlaceholders, ","), ), rhsValues...) } - return selectQuery, nil + return sb, nil } func (ll *Lister[REQ, RES]) getPageSize(req protoreflect.Message) (uint64, error) { diff --git a/pquery/query.go b/pquery/query.go index 7306609..ebad360 100644 --- a/pquery/query.go +++ b/pquery/query.go @@ -34,10 +34,10 @@ func (f AuthProviderFunc) AuthFilter(ctx context.Context) (map[string]string, er return f(ctx) } -// LeftJoin is a specification for joining in the form -// ON . =
. -// Main is defined in the outer struct holding this LeftJoin -type LeftJoin struct { +// KeyJoin is a join on a primary (or unique) key in the RHS table. +// LEFT JOIN ON . =
. +// Main is defined in the outer struct holding this KeyJoin +type KeyJoin struct { TableName string On JoinFields } diff --git a/pquery/select_build.go b/pquery/select_build.go new file mode 100644 index 0000000..68b8666 --- /dev/null +++ b/pquery/select_build.go @@ -0,0 +1,178 @@ +package pquery + +import ( + "fmt" + + sq "github.com/elgris/sqrl" + "github.com/lib/pq" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type ColumnDest interface { + NewRow() ScanDest +} + +type ScanDest interface { + ScanTo() interface{} + Unmarshal(protoreflect.Message) error +} + +type SelectBuilder interface { + Column(into ColumnDest, stmt string, args ...interface{}) + LeftJoin(join string, rest ...interface{}) + TableAlias(tableName string) string +} + +type selectBuilder struct { + *sq.SelectBuilder + aliasSet *aliasSet + rootAlias string + columns []ColumnDest +} + +func newSelectBuilder(rootTable string) *selectBuilder { + as := newAliasSet() + rootAlias := as.Next(rootTable) + sb := sq.Select(). + From(fmt.Sprintf("%s AS %s", rootTable, rootAlias)) + + return &selectBuilder{ + SelectBuilder: sb, + aliasSet: as, + rootAlias: rootAlias, + } +} + +func (sb *selectBuilder) NewRow() ([]ScanDest, []interface{}) { + fields := make([]ScanDest, 0, len(sb.columns)) + rowCols := make([]interface{}, 0, len(sb.columns)) + for _, inQuery := range sb.columns { + colRow := inQuery.NewRow() + fields = append(fields, colRow) + rowCols = append(rowCols, colRow.ScanTo()) + } + return fields, rowCols +} + +func (sb *selectBuilder) Column(into ColumnDest, stmt string, args ...interface{}) { + sb.SelectBuilder.Column(stmt, args...) + sb.columns = append(sb.columns, into) +} + +func (sb *selectBuilder) LeftJoin(join string, rest ...interface{}) { + sb.SelectBuilder.LeftJoin(join, rest...) +} + +func (sb *selectBuilder) TableAlias(tableName string) string { + return sb.aliasSet.Next(tableName) +} + +type ColumnSpec interface { + ApplyQuery(parentAlias string, sb SelectBuilder) +} + +type jsonColumn struct { + sqlColumn string + protoPath []protoreflect.FieldDescriptor +} + +func newJsonColumn(sqlColumn string, protoPath ...protoreflect.FieldDescriptor) jsonColumn { + return jsonColumn{ + sqlColumn: sqlColumn, + protoPath: protoPath, + } +} + +func (jc jsonColumn) ApplyQuery(tableAlias string, sb SelectBuilder) { + sb.Column(jc, fmt.Sprintf("%s.%s", tableAlias, jc.sqlColumn)) +} + +func (jc jsonColumn) NewRow() ScanDest { + return &jsonFieldRow{ + protoPath: jc.protoPath, + } +} + +// jsonFieldRow is a jsonb SQL field mapped to a proto field. +type jsonFieldRow struct { + protoPath []protoreflect.FieldDescriptor + data []byte +} + +func (jc *jsonFieldRow) ScanTo() interface{} { + return &jc.data +} + +func (jc *jsonFieldRow) Unmarshal(resReflect protoreflect.Message) error { + + if jc.data == nil { + return status.Error(codes.NotFound, "not found") + } + + msg := resReflect + for _, field := range jc.protoPath { + child := resReflect.Mutable(field).Message() + msg.Set(field, protoreflect.ValueOf(child)) + msg = child + } + + if err := protojson.Unmarshal(jc.data, msg.Interface()); err != nil { + return err + } + + return nil +} + +type jsonArrayColumn struct { + ArrayJoinSpec + fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type +} + +func (join *jsonArrayColumn) NewRow() ScanDest { + scanDest := pq.StringArray{} + return &jsonArrayFieldRow{ + fieldInParent: join.fieldInParent, + column: &scanDest, + } +} + +func (join *jsonArrayColumn) ApplyQuery(parentTable string, sb SelectBuilder) { + joinAlias := sb.TableAlias(join.TableName) + sb.Column(join, fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, join.DataColumn)) + sb.LeftJoin(fmt.Sprintf( + "%s AS %s ON %s", + join.TableName, + joinAlias, + join.On.SQL(parentTable, joinAlias), + )) +} + +type jsonArrayFieldRow struct { + fieldInParent protoreflect.FieldDescriptor + column *pq.StringArray +} + +func (join *jsonArrayFieldRow) ScanTo() interface{} { + return join.column +} + +func (join *jsonArrayFieldRow) Unmarshal(resReflect protoreflect.Message) error { + + elementList := resReflect.Mutable(join.fieldInParent).List() + for _, eventBytes := range *join.column { + if eventBytes == "" { + continue + } + + rowMessage := elementList.NewElement().Message() + if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil { + return fmt.Errorf("joined unmarshal: %w", err) + } + elementList.Append(protoreflect.ValueOf(rowMessage)) + } + + return nil +} diff --git a/pquery/sort.go b/pquery/sort.go index f7bb30c..82fa36c 100644 --- a/pquery/sort.go +++ b/pquery/sort.go @@ -162,19 +162,19 @@ func buildDefaultSorts(columnName string, message protoreflect.MessageDescriptor return defaultSortFields, nil } -func (ll *Lister[REQ, RES]) buildDynamicSortSpec(sorts []*list_j5pb.Sort) ([]sortSpec, error) { +func buildDynamicSortSpec(rootMessage protoreflect.MessageDescriptor, rootColumn string, sorts []*list_j5pb.Sort) ([]sortSpec, error) { results := []sortSpec{} direction := "" for _, sort := range sorts { pathSpec := pgstore.ParseJSONPathSpec(sort.Field) - spec, err := pgstore.NewJSONPath(ll.arrayField.Message(), pathSpec) + spec, err := pgstore.NewJSONPath(rootMessage, pathSpec) if err != nil { return nil, fmt.Errorf("dynamic filter: find field: %w", err) } biggerSpec := &pgstore.NestedField{ Path: *spec, - RootColumn: ll.dataColumn, + RootColumn: rootColumn, } results = append(results, sortSpec{ diff --git a/psm/state_query.go b/psm/state_query.go index 75ede22..586c8e2 100644 --- a/psm/state_query.go +++ b/psm/state_query.go @@ -89,7 +89,7 @@ func (gc *StateQuerySet[ type StateQueryOptions struct { Auth pquery.AuthProvider - AuthJoin *pquery.LeftJoin + AuthJoin *pquery.KeyJoin SkipEvents bool } @@ -156,7 +156,7 @@ func BuildStateQuerySet[ } if options.AuthJoin != nil { - getSpec.AuthJoin = []*pquery.LeftJoin{options.AuthJoin} + getSpec.AuthJoin = []*pquery.KeyJoin{options.AuthJoin} } if options.Auth != nil { @@ -196,7 +196,7 @@ func BuildStateQuerySet[ if smSpec.Event.Root == nil { return nil, fmt.Errorf("missing EventDataColumn in state spec for %s", smSpec.State.TableName) } - getSpec.Join = &pquery.GetJoinSpec{ + getSpec.ArrayJoin = &pquery.ArrayJoinSpec{ TableName: smSpec.Event.TableName, DataColumn: smSpec.Event.Root.ColumnName, FieldInParent: eventsInGet,