Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/integration/pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestPagination(t *testing.T) {
if err != nil {
t.Fatal(err.Error())
}
printQuery(t, query)
printQuery(t, query.SelectBuilder)

err = queryer.List(ctx, db, req, res)
if err != nil {
Expand Down Expand Up @@ -240,7 +240,7 @@ func TestEventPagination(t *testing.T) {
if err != nil {
t.Fatal(err.Error())
}
printQuery(t, query)
printQuery(t, query.SelectBuilder)

err = queryer.ListEvents(ctx, db, req, res)
if err != nil {
Expand Down
20 changes: 11 additions & 9 deletions pquery/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,32 +288,34 @@ func buildDefaultFilters(columnName string, message protoreflect.MessageDescript
return filters, nil
}

func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*list_j5pb.Filter) ([]sq.Sqlizer, error) {
func buildDynamicFilter(filterMessage protoreflect.MessageDescriptor, tableAlias string, dataColumn string, filters []*list_j5pb.Filter) ([]sq.Sqlizer, error) {
out := []sq.Sqlizer{}

for i := range filters {
switch filters[i].GetType().(type) {
case *list_j5pb.Filter_Field:
pathSpec := pgstore.ParseJSONPathSpec(filters[i].GetField().GetName())
spec, err := pgstore.NewJSONPath(ll.arrayField.Message(), pathSpec)
spec, err := pgstore.NewJSONPath(filterMessage, pathSpec)
if err != nil {
return nil, fmt.Errorf("dynamic filter: find field: %w", err)
}

// TODO: get RootColumn from the shortest matching list column

biggerSpec := &pgstore.NestedField{
Path: *spec,
RootColumn: ll.dataColumn,
RootColumn: dataColumn,
}

var o sq.Sqlizer
switch leaf := spec.Leaf().(type) {
case protoreflect.OneofDescriptor:
o, err = ll.buildDynamicFilterOneof(tableAlias, biggerSpec, filters[i])
o, err = buildDynamicFilterOneof(tableAlias, biggerSpec, filters[i])
if err != nil {
return nil, fmt.Errorf("dynamic filter: build oneof: %w", err)
}
case protoreflect.FieldDescriptor:
o, err = ll.buildDynamicFilterField(tableAlias, biggerSpec, filters[i])
o, err = buildDynamicFilterField(tableAlias, biggerSpec, filters[i])
if err != nil {
return nil, fmt.Errorf("dynamic filter: build field: %w", err)
}
Expand All @@ -323,7 +325,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis

out = append(out, o)
case *list_j5pb.Filter_And:
f, err := ll.buildDynamicFilter(tableAlias, filters[i].GetAnd().GetFilters())
f, err := buildDynamicFilter(filterMessage, tableAlias, dataColumn, filters[i].GetAnd().GetFilters())
if err != nil {
return nil, fmt.Errorf("dynamic filter: and: %w", err)
}
Expand All @@ -332,7 +334,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis

out = append(out, and)
case *list_j5pb.Filter_Or:
f, err := ll.buildDynamicFilter(tableAlias, filters[i].GetOr().GetFilters())
f, err := buildDynamicFilter(filterMessage, tableAlias, dataColumn, filters[i].GetOr().GetFilters())
if err != nil {
return nil, fmt.Errorf("dynamic filter: or: %w", err)
}
Expand All @@ -346,7 +348,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilter(tableAlias string, filters []*lis
return out, nil
}

func (ll *Lister[REQ, RES]) buildDynamicFilterField(tableAlias string, spec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) {
func buildDynamicFilterField(tableAlias string, spec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) {
var out sq.And

if filter.GetField() == nil {
Expand Down Expand Up @@ -396,7 +398,7 @@ func (ll *Lister[REQ, RES]) buildDynamicFilterField(tableAlias string, spec *pgs
return out, nil
}

func (ll *Lister[REQ, RES]) buildDynamicFilterOneof(tableAlias string, ospec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) {
func buildDynamicFilterOneof(tableAlias string, ospec *pgstore.NestedField, filter *list_j5pb.Filter) (sq.Sqlizer, error) {
var out sq.And

if filter.GetField() == nil {
Expand Down
152 changes: 57 additions & 95 deletions pquery/getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@ import (
"strings"

"github.com/bufbuild/protovalidate-go"
sq "github.com/elgris/sqrl"
"github.com/lib/pq"
"github.com/pentops/protostate/internal/dbconvert"
"github.com/pentops/sqrlx.go/sqrlx"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
Expand All @@ -34,13 +31,16 @@ type GetSpec[
TableName string
DataColumn string
Auth AuthProvider
AuthJoin []*LeftJoin
AuthJoin []*KeyJoin

PrimaryKey func(REQ) (map[string]interface{}, error)

StateResponseField protoreflect.Name

Join *GetJoinSpec
ArrayJoin *ArrayJoinSpec

// ResponseDescriptor must describe the RES message, defaults to new(RES).ProtoReflect().Descriptor()
ResponseDescriptor protoreflect.MessageDescriptor
}

// JoinConstraint defines a
Expand Down Expand Up @@ -77,14 +77,14 @@ func (jc JoinFields) SQL(rootAlias string, joinAlias string) string {
return strings.Join(conditions, " AND ")
}

type GetJoinSpec struct {
type ArrayJoinSpec struct {
TableName string
DataColumn string
On JoinFields
FieldInParent protoreflect.Name
}

func (gc GetJoinSpec) validate() error {
func (gc ArrayJoinSpec) validate() error {
if gc.TableName == "" {
return fmt.Errorf("missing TableName")
}
Expand All @@ -102,37 +102,29 @@ type Getter[
REQ GetRequest,
RES proto.Message,
] struct {
stateField protoreflect.FieldDescriptor

dataColumn string
tableName string
primaryKey func(REQ) (map[string]interface{}, error)
auth AuthProvider
authJoin []*LeftJoin
authJoin []*KeyJoin

queryLogger QueryLogger

validator *protovalidate.Validator

join *getJoin
}

type getJoin struct {
dataColumn string
tableName string
fieldInParent protoreflect.FieldDescriptor // wraps the ListFooEventResponse type
on JoinFields
columns []ColumnSpec
}

func NewGetter[
REQ GetRequest,
RES GetResponse,
](spec GetSpec[REQ, RES]) (*Getter[REQ, RES], error) {
descriptors := newMethodDescriptor[REQ, RES]()
resDesc := descriptors.response
resDesc := spec.ResponseDescriptor
if resDesc == nil {
descriptors := newMethodDescriptor[REQ, RES]()
resDesc = descriptors.response
}

sc := &Getter[REQ, RES]{
dataColumn: spec.DataColumn,
tableName: spec.TableName,
primaryKey: spec.PrimaryKey,
auth: spec.Auth,
Expand All @@ -145,39 +137,46 @@ func NewGetter[
defaultState = true
spec.StateResponseField = protoreflect.Name("state")
}
sc.stateField = resDesc.Fields().ByName(spec.StateResponseField)
if sc.stateField == nil {

stateField := resDesc.Fields().ByName(spec.StateResponseField)
if stateField == nil {
if defaultState {
return nil, fmt.Errorf("no 'state' field in proto message - did you mean to override StateResponseField?")
return nil, fmt.Errorf("no 'state' field in proto message, StateResponseField is left blank")
}
return nil, fmt.Errorf("no '%s' field in proto message", spec.StateResponseField)
}
sc.columns = append(sc.columns, newJsonColumn(spec.DataColumn, stateField))

if spec.DataColumn == "" {
return nil, fmt.Errorf("GetSpec missing DataColumn")
}

if spec.PrimaryKey == nil {
return nil, fmt.Errorf("missing PrimaryKey func")
return nil, fmt.Errorf("GetSpec missing PrimaryKey function")
}

if spec.Join != nil {
if spec.TableName == "" {
return nil, fmt.Errorf("GetSpec missing TableName")
}

if err := spec.Join.validate(); err != nil {
if spec.ArrayJoin != nil {
if err := spec.ArrayJoin.validate(); err != nil {
return nil, fmt.Errorf("invalid join spec: %w", err)
}

joinField := resDesc.Fields().ByName(protoreflect.Name(spec.Join.FieldInParent))
joinField := resDesc.Fields().ByName(protoreflect.Name(spec.ArrayJoin.FieldInParent))
if joinField == nil {
return nil, fmt.Errorf("field %s not found in response message", spec.Join.FieldInParent)
return nil, fmt.Errorf("field %s not found in response message", spec.ArrayJoin.FieldInParent)
}

if !joinField.IsList() {
return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.Join.FieldInParent)
return nil, fmt.Errorf("field %s, in join spec, is not a list", spec.ArrayJoin.FieldInParent)
}

sc.join = &getJoin{
tableName: spec.Join.TableName,
dataColumn: spec.Join.DataColumn,
sc.columns = append(sc.columns, &jsonArrayColumn{
ArrayJoinSpec: *spec.ArrayJoin,
fieldInParent: joinField,
on: spec.Join.On,
}
})
}

var err error
Expand All @@ -195,8 +194,7 @@ func (gc *Getter[REQ, RES]) SetQueryLogger(logger QueryLogger) {

func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ, resMsg RES) error {

as := newAliasSet()
rootAlias := as.Next(gc.tableName)
sb := newSelectBuilder(gc.tableName)

resReflect := resMsg.ProtoReflect()

Expand All @@ -209,27 +207,26 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ,
return err
}

rootFilter, err := dbconvert.FieldsToEqMap(rootAlias, primaryKeyFields)
if len(primaryKeyFields) == 0 {
return fmt.Errorf("PrimaryKey() returned no fields")
}

rootFilter, err := dbconvert.FieldsToEqMap(sb.rootAlias, primaryKeyFields)
if err != nil {
return err
}

selectQuery := sq.
Select().
Column(fmt.Sprintf("%s.%s", rootAlias, gc.dataColumn)).
From(fmt.Sprintf("%s AS %s", gc.tableName, rootAlias)).
Where(rootFilter)
sb.Where(rootFilter)

for pkField := range rootFilter {
selectQuery.GroupBy(pkField)
sb.GroupBy(pkField)
}

if gc.auth != nil {
authAlias := rootAlias
authAlias := sb.rootAlias
for _, join := range gc.authJoin {
priorAlias := authAlias
authAlias = as.Next(join.TableName)
selectQuery = selectQuery.LeftJoin(fmt.Sprintf(
authAlias = sb.TableAlias(join.TableName)
sb.LeftJoin(fmt.Sprintf(
"%s AS %s ON %s",
join.TableName,
authAlias,
Expand All @@ -247,43 +244,28 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ,
for k, v := range authFilter {
claimFilter[fmt.Sprintf("%s.%s", authAlias, k)] = v
}
selectQuery.Where(claimFilter)
sb.Where(claimFilter)
}
}

if gc.join != nil {
joinAlias := as.Next(gc.join.tableName)

selectQuery.
Column(fmt.Sprintf("ARRAY_AGG(%s.%s)", joinAlias, gc.join.dataColumn)).
LeftJoin(fmt.Sprintf(
"%s AS %s ON %s",
gc.join.tableName,
joinAlias,
gc.join.on.SQL(rootAlias, joinAlias),
))
for _, join := range gc.columns {
join.ApplyQuery(sb.rootAlias, sb)
}

var foundJSON []byte
var joinedJSON pq.StringArray
fields, scanCols := sb.NewRow()

if gc.queryLogger != nil {
gc.queryLogger(selectQuery)
gc.queryLogger(sb.SelectBuilder)
}

if err := db.Transact(ctx, &sqrlx.TxOptions{
ReadOnly: true,
Retryable: true,
Isolation: sql.LevelReadCommitted,
}, func(ctx context.Context, tx sqrlx.Transaction) error {
row := tx.SelectRow(ctx, selectQuery)
row := tx.SelectRow(ctx, sb.SelectBuilder)

var err error
if gc.join != nil {
err = row.Scan(&foundJSON, &joinedJSON)
} else {
err = row.Scan(&foundJSON)
}
err := row.Scan(scanCols...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
var pkDescription string
Expand All @@ -306,34 +288,14 @@ func (gc *Getter[REQ, RES]) Get(ctx context.Context, db Transactor, reqMsg REQ,

return nil
}); err != nil {
query, _, _ := selectQuery.ToSql()
query, _, _ := sb.ToSql()
return fmt.Errorf("%s: %w", query, err)
}

if foundJSON == nil {
return status.Error(codes.NotFound, "not found")
}

stateMsg := resReflect.NewField(gc.stateField)
if err := protojson.Unmarshal(foundJSON, stateMsg.Message().Interface()); err != nil {
return err
}
resReflect.Set(gc.stateField, stateMsg)

if gc.join != nil {
elementList := resReflect.Mutable(gc.join.fieldInParent).List()
for _, eventBytes := range joinedJSON {
if eventBytes == "" {
continue
}

rowMessage := elementList.NewElement().Message()
if err := protojson.Unmarshal([]byte(eventBytes), rowMessage.Interface()); err != nil {
return fmt.Errorf("joined unmarshal: %w", err)
}
elementList.Append(protoreflect.ValueOf(rowMessage))
for _, field := range fields {
if err := field.Unmarshal(resReflect); err != nil {
return err
}

}

return nil
Expand Down
Loading