From 3bcc01fdb6a0408b9bee76894e2c48864d71514f Mon Sep 17 00:00:00 2001 From: Arsen Musayelyan Date: Thu, 12 May 2022 17:13:44 -0700 Subject: [PATCH] Propagate context to requests --- examples/server/main.go | 3 ++- server/context.go | 23 ++++++++++++++++++++- server/server.go | 45 ++++++++++++++++++++++++++--------------- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/examples/server/main.go b/examples/server/main.go index 57d9465..30dae2b 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/gob" "net" @@ -33,5 +34,5 @@ func main() { s.Register(Arith{}) ln, _ := net.Listen("tcp", ":9090") - s.Serve(ln, codec.Gob) + s.Serve(context.Background(), ln, codec.Gob) } diff --git a/server/context.go b/server/context.go index 7cf2fe5..d389c11 100644 --- a/server/context.go +++ b/server/context.go @@ -37,6 +37,24 @@ type Context struct { doneCh chan struct{} canceled bool + + ctx context.Context +} + +func newContext(ctx context.Context, codec codec.Codec) *Context { + out := &Context{ + doneCh: make(chan struct{}), + codec: codec, + ctx: ctx, + } + if ctx == nil { + out.ctx = context.Background() + } + go func() { + <-out.ctx.Done() + out.cancel() + }() + return out } // MakeChannel changes the function it's called in into a @@ -87,7 +105,10 @@ func (ctx *Context) Done() <-chan struct{} { } // Cancel cancels the context -func (ctx *Context) Cancel() { +func (ctx *Context) cancel() { + if ctx.canceled { + return + } ctx.canceled = true close(ctx.doneCh) } diff --git a/server/server.go b/server/server.go index 169bc9c..96e411f 100644 --- a/server/server.go +++ b/server/server.go @@ -19,6 +19,7 @@ package server import ( + "context" "errors" "io" "net" @@ -67,7 +68,7 @@ func New() *Server { // Close closes the server func (s *Server) Close() { for _, ctx := range s.contexts { - ctx.Cancel() + ctx.cancel() } } @@ -98,7 +99,7 @@ func (s *Server) Register(v any) error { } // execute runs a method of a registered value -func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) { +func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) { // Try to get value from receivers map val, ok := s.rcvrs[typ] if !ok { @@ -140,11 +141,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any arg = val.Interface() } - // Create new context - ctx = &Context{ - doneCh: make(chan struct{}, 1), - codec: c, - } + ctx = newContext(pCtx, c) // Get reflect value of context ctxVal := reflect.ValueOf(ctx) @@ -232,23 +229,30 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any // Serve starts the server using the provided listener // and codec function -func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) { +func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) { + go func() { + <-ctx.Done() + ln.Close() + }() + for { conn, err := ln.Accept() - if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } else if err != nil { continue } // Create new instance of codec bound to conn c := cf(conn) // Handle connection - go s.handleConn(c) + go s.handleConn(ctx, c) } } // ServeWS starts a server using WebSocket. This may be useful for // clients written in other languages, such as JS for a browser. -func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) { +func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) { // Create new WebSocket server ws := websocket.Server{} @@ -259,15 +263,23 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) { // Set server handler ws.Handler = func(c *websocket.Conn) { - s.handleConn(cf(c)) + s.handleConn(c.Request().Context(), cf(c)) + } + + server := &http.Server{ + Addr: addr, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + Handler: http.HandlerFunc(ws.ServeHTTP), } // Listen and serve on given address - return http.ListenAndServe(addr, http.HandlerFunc(ws.ServeHTTP)) + return server.ListenAndServe() } // handleConn handles a listener connection -func (s *Server) handleConn(c codec.Codec) { +func (s *Server) handleConn(pCtx context.Context, c codec.Codec) { codecMtx := &sync.Mutex{} for { @@ -283,6 +295,7 @@ func (s *Server) handleConn(c codec.Codec) { // Execute decoded call val, ctx, err := s.execute( + pCtx, call.Receiver, call.Method, call.Arg, @@ -322,7 +335,7 @@ func (s *Server) handleConn(c codec.Codec) { } // Cancel context - ctx.Cancel() + ctx.cancel() // Delete context from map s.contextsMtx.Lock() delete(s.contexts, ctx.channelID) @@ -370,7 +383,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) { } // Cancel context - ctx.Cancel() + ctx.cancel() // Delete context from map l.srv.contextsMtx.Lock() delete(l.srv.contexts, id)