Propagate context to requests

This commit is contained in:
Elara 2022-05-12 17:13:44 -07:00
parent af77b121f8
commit 3bcc01fdb6
3 changed files with 53 additions and 18 deletions

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/gob"
"net"
@ -33,5 +34,5 @@ func main() {
s.Register(Arith{})
ln, _ := net.Listen("tcp", ":9090")
s.Serve(ln, codec.Gob)
s.Serve(context.Background(), ln, codec.Gob)
}

View File

@ -37,6 +37,24 @@ type Context struct {
doneCh chan struct{}
canceled bool
ctx context.Context
}
func newContext(ctx context.Context, codec codec.Codec) *Context {
out := &Context{
doneCh: make(chan struct{}),
codec: codec,
ctx: ctx,
}
if ctx == nil {
out.ctx = context.Background()
}
go func() {
<-out.ctx.Done()
out.cancel()
}()
return out
}
// MakeChannel changes the function it's called in into a
@ -87,7 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
}
// Cancel cancels the context
func (ctx *Context) Cancel() {
func (ctx *Context) cancel() {
if ctx.canceled {
return
}
ctx.canceled = true
close(ctx.doneCh)
}

View File

@ -19,6 +19,7 @@
package server
import (
"context"
"errors"
"io"
"net"
@ -67,7 +68,7 @@ func New() *Server {
// Close closes the server
func (s *Server) Close() {
for _, ctx := range s.contexts {
ctx.Cancel()
ctx.cancel()
}
}
@ -98,7 +99,7 @@ func (s *Server) Register(v any) error {
}
// execute runs a method of a registered value
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
// Try to get value from receivers map
val, ok := s.rcvrs[typ]
if !ok {
@ -140,11 +141,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
arg = val.Interface()
}
// Create new context
ctx = &Context{
doneCh: make(chan struct{}, 1),
codec: c,
}
ctx = newContext(pCtx, c)
// Get reflect value of context
ctxVal := reflect.ValueOf(ctx)
@ -232,23 +229,30 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// Serve starts the server using the provided listener
// and codec function
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
go func() {
<-ctx.Done()
ln.Close()
}()
for {
conn, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
} else if err != nil {
continue
}
// Create new instance of codec bound to conn
c := cf(conn)
// Handle connection
go s.handleConn(c)
go s.handleConn(ctx, c)
}
}
// ServeWS starts a server using WebSocket. This may be useful for
// clients written in other languages, such as JS for a browser.
func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
// Create new WebSocket server
ws := websocket.Server{}
@ -259,15 +263,23 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
// Set server handler
ws.Handler = func(c *websocket.Conn) {
s.handleConn(cf(c))
s.handleConn(c.Request().Context(), cf(c))
}
server := &http.Server{
Addr: addr,
BaseContext: func(net.Listener) context.Context {
return ctx
},
Handler: http.HandlerFunc(ws.ServeHTTP),
}
// Listen and serve on given address
return http.ListenAndServe(addr, http.HandlerFunc(ws.ServeHTTP))
return server.ListenAndServe()
}
// handleConn handles a listener connection
func (s *Server) handleConn(c codec.Codec) {
func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
codecMtx := &sync.Mutex{}
for {
@ -283,6 +295,7 @@ func (s *Server) handleConn(c codec.Codec) {
// Execute decoded call
val, ctx, err := s.execute(
pCtx,
call.Receiver,
call.Method,
call.Arg,
@ -322,7 +335,7 @@ func (s *Server) handleConn(c codec.Codec) {
}
// Cancel context
ctx.Cancel()
ctx.cancel()
// Delete context from map
s.contextsMtx.Lock()
delete(s.contexts, ctx.channelID)
@ -370,7 +383,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
}
// Cancel context
ctx.Cancel()
ctx.cancel()
// Delete context from map
l.srv.contextsMtx.Lock()
delete(l.srv.contexts, id)