Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions grpcrun/go_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ import (

var (
mu sync.Mutex
log *zap.Logger
node *snowflake.Node
)

func init() {
var err error
mu = sync.Mutex{}
log, _ = zap.NewDevelopment()
zap.ReplaceGlobals(log)
// zap.ReplaceGlobals(&log)
if node, err = snowflake.NewNode(int64(time.Now().Day())); err != nil {
panic(err)
}
Expand All @@ -37,36 +41,38 @@ func init() {
// run.Wait()
// }
type GoGrpc struct {
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
wait sync.WaitGroup
Timeout time.Duration
Task map[string]*GrpcTask
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
wait sync.WaitGroup
time time.Duration
Task map[string]*GrpcTask
}

func NewGoGrpc() *GoGrpc {
mu.Lock()
defer mu.Unlock()
g := GoGrpc{}
g.ctx, g.cancel = context.WithTimeout(context.Background(), 3*time.Second)
g.mu = sync.Mutex{}
g.time = 3 * time.Second
g.wait = sync.WaitGroup{}
g.Task = make(map[string]*GrpcTask, 0)
g.ctx, g.cancel = context.WithTimeout(context.Background(), g.time)
return &g
}

// SetTimeout reset timeout, replace default timeout with a special time duration
func (g *GoGrpc) SetTimeout(timeout time.Duration) {
mu.Lock()
mu.Unlock()
g.ctx, g.cancel = context.WithTimeout(context.Background(), timeout)
g.mu.Lock()
defer g.mu.Unlock()
g.time = timeout
}

func (g *GoGrpc) Run() {
for _, t := range g.Task {
go g.run(t)
for _, task := range g.Task {
go g.run(task)
}
g.Wait()
}

func (g *GoGrpc) Wait() {
Expand All @@ -84,16 +90,20 @@ func (g *GoGrpc) AddTask(task *GrpcTask) {
func (g *GoGrpc) AddNewTask(grpcName string, grpcMethod any, request any) {
g.mu.Lock()
defer g.mu.Unlock()
zap.S()
task := GrpcTask{
ctx: &g.ctx,

if grpcName == "" {
grpcName = node.Generate().String()
}

task := &GrpcTask{
ctx: g.ctx,
grpcMethod: grpcMethod,
request: request,
Name: grpcName,
log: zap.S().Named(grpcName),
log: zap.S(),
}

g.Task[node.Generate().String()] = &task
g.Task[task.Name] = task
g.wait.Add(1)
return
}
Expand All @@ -103,10 +113,12 @@ func (g *GoGrpc) run(t *GrpcTask) {
for {
select {
case <-g.ctx.Done():
t.log.Info("context done")
t.Err = errors.New("context canceled")
return
default:
t.Call()
t.log.Info("success call function")
return
}
}
Expand Down
81 changes: 68 additions & 13 deletions grpcrun/grpcrun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"go.uber.org/zap"
"strconv"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -66,14 +67,22 @@ func Login4(ctx context.Context, req *loginReq) (*loginResp, int) {
return &loginResp{UserId: 21, Token: "test grpc call success"}, 1
}

func Login5(ctx context.Context, req *loginReq) (*loginResp, error) {
fmt.Println("sleep: ", req)
time.Sleep(time.Second)
fmt.Println("over: ", req)
return &loginResp{UserId: 2333, Token: "test grpc call success"}, nil
}

var (
datas []*data
datas []*data
timeouts []*data
)

func TestGrpcTask(t *testing.T) {

for i, d := range datas {
call := grpcrun.NewGrpcTask(&d.ctx, "test{"+strconv.Itoa(i)+"}", d.method, d.req)
call := grpcrun.NewGrpcTask(d.ctx, "test{"+strconv.Itoa(i)+"}", d.method, d.req)
call.Call()

t.Logf("第 %d 次执行\n", i+1)
Expand Down Expand Up @@ -108,19 +117,33 @@ func init() {

// 测试表格
datas = []*data{
newData(ctx, Login, req), // 正常
newData(ctx, Login1, req), // [grpcMethod]必须有2个参数(context.Context, *request)
newData(ctx, Login2, req), // [grpcMethod]的第1个参数必须是:context.Context
newData(ctx, Login3, req), // [grpcMethod]必须有2个返回值(*Response, error)
newData(ctx, Login4, req), // [grpcMethod]的第2个返回值必须是:error
newData(nil, Login, req), // 请正确的传递[Context],不支持:nil
newData(ctx, nil, req), // [grpcMethod]必须是一个GRPC的函数类型,现在是:invalid
newData(ctx, Login, nil), // 请正确的传递[request],不支持:invalid
newData(ctx, "其他类型", req), // [grpcMethod]必须是一个GRPC的函数类型,现在是:string
newData(ctx, Login, "其他类型"), // 请正确的传入[request],不支持:string
newData(ctx, Login, zap.S()), // [request]的参数与[grpcMethod]的参数不匹配:grpcMethod = v3_test.loginReq, request = zap.SugaredLogger
newData(ctx, Login, req), // 正常
newData(ctx, Login1, req), // [grpcMethod]必须有2个参数(context.Context, *request)
newData(ctx, Login2, req), // [grpcMethod]的第1个参数必须是:context.Context
newData(ctx, Login3, req), // [grpcMethod]必须有2个返回值(*Response, error)
newData(ctx, Login4, req), // [grpcMethod]的第2个返回值必须是:error
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(nil, Login, req), // 请正确的传递[Context],不支持:nil
newData(ctx, nil, req), // [grpcMethod]必须是一个GRPC的函数类型,现在是:invalid
newData(ctx, Login, nil), // 请正确的传递[request],不支持:invalid
newData(ctx, "其他类型", req), // [grpcMethod]必须是一个GRPC的函数类型,现在是:string
newData(ctx, Login, "其他类型"), // 请正确的传入[request],不支持:string
newData(ctx, Login, zap.S()), // [request]的参数与[grpcMethod]的参数不匹配:grpcMethod = v3_test.loginReq, request = zap.SugaredLogger

}

timeouts = []*data{
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
newData(ctx, Login5, req), // [grpcMethod]的 timeout
}
}

func TestGoGrpc_AddNewTask(t *testing.T) {
Expand Down Expand Up @@ -148,3 +171,35 @@ func TestGoGrpc_Run(t *testing.T) {
fmt.Println()
}
}

func TestGoGrpc_Timeout(t *testing.T) {
run := grpcrun.NewGoGrpc()
for i, d := range timeouts {
run.AddNewTask("test{"+strconv.Itoa(i)+"}", d.method, d.req)
}

run.Run()

for k, t := range run.Task {
if t.Err != nil {
fmt.Println(k, t.Err.(error))
}
}
// fmt.Println(run.Task["test{5}"].Err.(error))
}

func TestGo(t *testing.T) {
type muNum struct {
mu sync.Mutex
num int
}
n := muNum{num: 1}
for i := 1; i <= 10; i++ {
go func(num int) {
time.Sleep(time.Second)
n.num = num
fmt.Println(num)
}(i)
}
time.Sleep(2 * time.Second)
}
9 changes: 4 additions & 5 deletions grpcrun/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type GrpcTask struct {
grpcMethod any

// GRPC的调用参数
ctx *context.Context
ctx context.Context
request any

// GRPC的调用返回值
Expand All @@ -30,10 +30,9 @@ type GrpcTask struct {
//
// Note:
// @param grpcName string name of the grpc, this should be unique
func NewGrpcTask(ctx *context.Context, grpcName string, grpcMethod any, request any) *GrpcTask {
func NewGrpcTask(ctx context.Context, grpcName string, grpcMethod any, request any) *GrpcTask {
mu.Lock()
defer mu.Unlock()
zap.S()

if grpcName == "" {
grpcName = node.Generate().String()
Expand Down Expand Up @@ -66,7 +65,7 @@ func (c *GrpcTask) call() {

// 调用参数
argv := make([]reflect.Value, 2)
argv[0] = reflect.ValueOf(*c.ctx)
argv[0] = reflect.ValueOf(c.ctx)
argv[1] = reflect.ValueOf(c.request)

// 反射调用
Expand All @@ -90,7 +89,7 @@ func (c *GrpcTask) validate() {
}

// 校验 ctx 类型
ctxV := reflect.ValueOf(c.ctx).Elem()
ctxV := reflect.ValueOf(&c.ctx).Elem()
if ctxV.IsNil() {
c.Err = fmt.Errorf("请正确的传递[Context],不支持:nil")
return
Expand Down