Propagate context to requests

This commit is contained in:
Elara 2022-05-12 17:13:44 -07:00
parent af77b121f8
commit 3bcc01fdb6
3 changed files with 53 additions and 18 deletions

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/gob" "encoding/gob"
"net" "net"
@ -33,5 +34,5 @@ func main() {
s.Register(Arith{}) s.Register(Arith{})
ln, _ := net.Listen("tcp", ":9090") ln, _ := net.Listen("tcp", ":9090")
s.Serve(ln, codec.Gob) s.Serve(context.Background(), ln, codec.Gob)
} }

View File

@ -37,6 +37,24 @@ type Context struct {
doneCh chan struct{} doneCh chan struct{}
canceled bool canceled bool
ctx context.Context
}
func newContext(ctx context.Context, codec codec.Codec) *Context {
out := &Context{
doneCh: make(chan struct{}),
codec: codec,
ctx: ctx,
}
if ctx == nil {
out.ctx = context.Background()
}
go func() {
<-out.ctx.Done()
out.cancel()
}()
return out
} }
// MakeChannel changes the function it's called in into a // MakeChannel changes the function it's called in into a
@ -87,7 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
} }
// Cancel cancels the context // Cancel cancels the context
func (ctx *Context) Cancel() { func (ctx *Context) cancel() {
if ctx.canceled {
return
}
ctx.canceled = true ctx.canceled = true
close(ctx.doneCh) close(ctx.doneCh)
} }

View File

@ -19,6 +19,7 @@
package server package server
import ( import (
"context"
"errors" "errors"
"io" "io"
"net" "net"
@ -67,7 +68,7 @@ func New() *Server {
// Close closes the server // Close closes the server
func (s *Server) Close() { func (s *Server) Close() {
for _, ctx := range s.contexts { for _, ctx := range s.contexts {
ctx.Cancel() ctx.cancel()
} }
} }
@ -98,7 +99,7 @@ func (s *Server) Register(v any) error {
} }
// execute runs a method of a registered value // execute runs a method of a registered value
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) { func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
// Try to get value from receivers map // Try to get value from receivers map
val, ok := s.rcvrs[typ] val, ok := s.rcvrs[typ]
if !ok { if !ok {
@ -140,11 +141,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
arg = val.Interface() arg = val.Interface()
} }
// Create new context ctx = newContext(pCtx, c)
ctx = &Context{
doneCh: make(chan struct{}, 1),
codec: c,
}
// Get reflect value of context // Get reflect value of context
ctxVal := reflect.ValueOf(ctx) ctxVal := reflect.ValueOf(ctx)
@ -232,23 +229,30 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// Serve starts the server using the provided listener // Serve starts the server using the provided listener
// and codec function // and codec function
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) { func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
go func() {
<-ctx.Done()
ln.Close()
}()
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if errors.Is(err, net.ErrClosed) {
break
} else if err != nil {
continue continue
} }
// Create new instance of codec bound to conn // Create new instance of codec bound to conn
c := cf(conn) c := cf(conn)
// Handle connection // Handle connection
go s.handleConn(c) go s.handleConn(ctx, c)
} }
} }
// ServeWS starts a server using WebSocket. This may be useful for // ServeWS starts a server using WebSocket. This may be useful for
// clients written in other languages, such as JS for a browser. // clients written in other languages, such as JS for a browser.
func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) { func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
// Create new WebSocket server // Create new WebSocket server
ws := websocket.Server{} ws := websocket.Server{}
@ -259,15 +263,23 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
// Set server handler // Set server handler
ws.Handler = func(c *websocket.Conn) { ws.Handler = func(c *websocket.Conn) {
s.handleConn(cf(c)) s.handleConn(c.Request().Context(), cf(c))
}
server := &http.Server{
Addr: addr,
BaseContext: func(net.Listener) context.Context {
return ctx
},
Handler: http.HandlerFunc(ws.ServeHTTP),
} }
// Listen and serve on given address // Listen and serve on given address
return http.ListenAndServe(addr, http.HandlerFunc(ws.ServeHTTP)) return server.ListenAndServe()
} }
// handleConn handles a listener connection // handleConn handles a listener connection
func (s *Server) handleConn(c codec.Codec) { func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
codecMtx := &sync.Mutex{} codecMtx := &sync.Mutex{}
for { for {
@ -283,6 +295,7 @@ func (s *Server) handleConn(c codec.Codec) {
// Execute decoded call // Execute decoded call
val, ctx, err := s.execute( val, ctx, err := s.execute(
pCtx,
call.Receiver, call.Receiver,
call.Method, call.Method,
call.Arg, call.Arg,
@ -322,7 +335,7 @@ func (s *Server) handleConn(c codec.Codec) {
} }
// Cancel context // Cancel context
ctx.Cancel() ctx.cancel()
// Delete context from map // Delete context from map
s.contextsMtx.Lock() s.contextsMtx.Lock()
delete(s.contexts, ctx.channelID) delete(s.contexts, ctx.channelID)
@ -370,7 +383,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
} }
// Cancel context // Cancel context
ctx.Cancel() ctx.cancel()
// Delete context from map // Delete context from map
l.srv.contextsMtx.Lock() l.srv.contextsMtx.Lock()
delete(l.srv.contexts, id) delete(l.srv.contexts, id)