Compare commits

...

2 Commits

Author SHA1 Message Date
Elara f609d5a97f Add introspection functions 2022-05-07 14:59:04 -07:00
Elara ff5f211a83 Use type uint8 to replace boolean fields in response 2022-05-07 14:01:10 -07:00
5 changed files with 150 additions and 41 deletions

View File

@ -105,7 +105,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
c.chMtx.Unlock() c.chMtx.Unlock()
// If response is an error, return error // If response is an error, return error
if resp.IsError { if resp.Type == types.ResponseTypeError {
return errors.New(resp.Error) return errors.New(resp.Error)
} }
@ -118,7 +118,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
retVal := reflect.ValueOf(ret) retVal := reflect.ValueOf(ret)
// If response is a channel // If response is a channel
if resp.IsChannel { if resp.Type == types.ResponseTypeChannel {
// If return value is not a channel, return error // If return value is not a channel, return error
if retVal.Kind() != reflect.Chan { if retVal.Kind() != reflect.Chan {
return ErrReturnNotChannel return ErrReturnNotChannel
@ -137,7 +137,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
// For every value received from channel // For every value received from channel
for val := range c.chs[chID] { for val := range c.chs[chID] {
//s := time.Now() //s := time.Now()
if val.ChannelDone { if val.Type == types.ResponseTypeChannelDone {
// Close and delete channel // Close and delete channel
c.chMtx.Lock() c.chMtx.Lock()
close(c.chs[chID]) close(c.chs[chID])
@ -177,7 +177,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
} }
} }
}() }()
} else { } else if resp.Type == types.ResponseTypeNormal {
// IF return value is not a pointer, return error // IF return value is not a pointer, return error
if retVal.Kind() != reflect.Ptr { if retVal.Kind() != reflect.Ptr {
return ErrReturnNotPointer return ErrReturnNotPointer

View File

@ -26,6 +26,9 @@ import (
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
) )
// <= go1.17 compatibility
type any = interface{}
// CodecFunc is a function that returns a new Codec // CodecFunc is a function that returns a new Codec
// bound to the given io.ReadWriter // bound to the given io.ReadWriter
type CodecFunc func(io.ReadWriter) Codec type CodecFunc func(io.ReadWriter) Codec

View File

@ -26,6 +26,9 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
// <= go1.17 compatibility
type any = interface{}
// Convert attempts to convert the given value to the given type // Convert attempts to convert the given value to the given type
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) { func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
// Get input type // Get input type

View File

@ -18,6 +18,9 @@
package types package types
// <= go1.17 compatibility
type any = interface{}
// Request represents a request sent to the server // Request represents a request sent to the server
type Request struct { type Request struct {
ID string ID string
@ -26,12 +29,19 @@ type Request struct {
Arg any Arg any
} }
type ResponseType uint8
const (
ResponseTypeNormal ResponseType = iota
ResponseTypeError
ResponseTypeChannel
ResponseTypeChannelDone
)
// Response represents a response returned by the server // Response represents a response returned by the server
type Response struct { type Response struct {
ID string Type ResponseType
ChannelDone bool ID string
IsChannel bool Error string
IsError bool Return any
Error string
Return any
} }

View File

@ -36,13 +36,10 @@ import (
type any = interface{} type any = interface{}
var ( var (
ErrInvalidType = errors.New("type must be struct or pointer to struct") ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrTooManyInputs = errors.New("method may not have more than two inputs") ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrTooManyOutputs = errors.New("method may not have more than two return values") ErrNoSuchMethod = errors.New("no such method was found")
ErrNoSuchReceiver = errors.New("no such receiver registered") ErrInvalidMethod = errors.New("method invalid for lrpc call")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidSecondReturn = errors.New("second return value must be error")
ErrInvalidFirstInput = errors.New("first input must be *Context")
) )
// Server is an lrpc server // Server is an lrpc server
@ -114,22 +111,13 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
return nil, nil, ErrNoSuchMethod return nil, nil, ErrNoSuchMethod
} }
// Get method's type // If method invalid, return error
if !mtdValid(mtd) {
return nil, nil, ErrInvalidMethod
}
// Get method type
mtdType := mtd.Type() mtdType := mtd.Type()
if mtdType.NumIn() > 2 {
return nil, nil, ErrTooManyInputs
} else if mtdType.NumIn() < 1 {
return nil, nil, ErrInvalidFirstInput
}
if mtdType.NumOut() > 2 {
return nil, nil, ErrTooManyOutputs
}
// Check to ensure first parameter is context
if mtdType.In(0) != reflect.TypeOf(&Context{}) {
return nil, nil, ErrInvalidFirstInput
}
//TODO: if arg not nil but fn has no arg, err //TODO: if arg not nil but fn has no arg, err
@ -215,7 +203,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// If second return value is not an error, the function is invalid // If second return value is not an error, the function is invalid
if !ok { if !ok {
a, err = nil, ErrInvalidSecondReturn a, err = nil, ErrInvalidMethod
} }
} }
@ -231,7 +219,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// If second return value is not an error, the function is invalid // If second return value is not an error, the function is invalid
err, ok = out1.(error) err, ok = out1.(error)
if !ok { if !ok {
a, err = nil, ErrInvalidSecondReturn a, err = nil, ErrInvalidMethod
} }
} }
@ -312,7 +300,7 @@ func (s *Server) handleConn(c codec.Codec) {
// If function has created a channel // If function has created a channel
if ctx.isChannel { if ctx.isChannel {
// Set IsChannel to true // Set IsChannel to true
res.IsChannel = true res.Type = types.ResponseTypeChannel
// Overwrite return value with channel ID // Overwrite return value with channel ID
res.Return = ctx.channelID res.Return = ctx.channelID
@ -342,8 +330,8 @@ func (s *Server) handleConn(c codec.Codec) {
codecMtx.Lock() codecMtx.Lock()
c.Encode(types.Response{ c.Encode(types.Response{
ID: ctx.channelID, Type: types.ResponseTypeChannelDone,
ChannelDone: true, ID: ctx.channelID,
}) })
codecMtx.Unlock() codecMtx.Unlock()
}() }()
@ -361,10 +349,10 @@ func (s *Server) handleConn(c codec.Codec) {
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) { func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
// Encode error response using codec // Encode error response using codec
c.Encode(types.Response{ c.Encode(types.Response{
ID: req.ID, Type: types.ResponseTypeError,
IsError: true, ID: req.ID,
Error: err.Error(), Error: err.Error(),
Return: val, Return: val,
}) })
} }
@ -388,3 +376,108 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
delete(l.srv.contexts, id) delete(l.srv.contexts, id)
l.srv.contextsMtx.Unlock() l.srv.contextsMtx.Unlock()
} }
// MethodDesc describes methods on a receiver
type MethodDesc struct {
Name string
Args []string
Returns []string
}
// Introspect returns method descriptions for the given receiver
func (l lrpc) Introspect(_ *Context, name string) ([]MethodDesc, error) {
// Attempt to get receiver
rcvr, ok := l.srv.rcvrs[name]
if !ok {
return nil, ErrNoSuchReceiver
}
// Get receiver type value
rcvrType := rcvr.Type()
// Create slice for output
var out []MethodDesc
// For every method on receiver
for i := 0; i < rcvr.NumMethod(); i++ {
// Get receiver method
mtd := rcvr.Method(i)
// If invalid, skip
if !mtdValid(mtd) {
continue
}
// Get method type
mtdType := mtd.Type()
// Get amount of arguments
numIn := mtdType.NumIn()
args := make([]string, numIn-1)
// For every argument, store type in slice
// Skip first argument, as it is *Context
for i := 1; i < numIn; i++ {
args[i-1] = mtdType.In(i).String()
}
// Get amount of returns
numOut := mtdType.NumOut()
returns := make([]string, numOut)
// For every return, store type in slice
for i := 0; i < numOut; i++ {
returns[i] = mtdType.Out(i).String()
}
out = append(out, MethodDesc{
Name: rcvrType.Method(i).Name,
Args: args,
Returns: returns,
})
}
return out, nil
}
// IntrospectAll runs Introspect on all registered receivers and returns all results
func (l lrpc) IntrospectAll(_ *Context) (map[string][]MethodDesc, error) {
// Create map for output
out := make(map[string][]MethodDesc, len(l.srv.rcvrs))
// For every registered receiver
for name := range l.srv.rcvrs {
// Introspect receiver
descs, err := l.Introspect(nil, name)
if err != nil {
return nil, err
}
// Set results in map
out[name] = descs
}
return out, nil
}
func mtdValid(mtd reflect.Value) bool {
// Get method's type
mtdType := mtd.Type()
// If method has more than 2 or less than 1 input, it is invalid
if mtdType.NumIn() > 2 || mtdType.NumIn() < 1 {
return false
}
// If method has more than 2 outputs, it is invalid
if mtdType.NumOut() > 2 {
return false
}
// Check to ensure first parameter is context
if mtdType.In(0) != reflect.TypeOf((*Context)(nil)) {
return false
}
// If method has 2 outputs
if mtdType.NumOut() == 2 {
// Check to ensure the second one is an error
if mtdType.Out(1).Name() != "error" {
return false
}
}
return true
}