Compare commits

...

2 Commits

Author SHA1 Message Date
Elara f609d5a97f Add introspection functions 2022-05-07 14:59:04 -07:00
Elara ff5f211a83 Use type uint8 to replace boolean fields in response 2022-05-07 14:01:10 -07:00
5 changed files with 150 additions and 41 deletions

View File

@ -105,7 +105,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
c.chMtx.Unlock()
// If response is an error, return error
if resp.IsError {
if resp.Type == types.ResponseTypeError {
return errors.New(resp.Error)
}
@ -118,7 +118,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
retVal := reflect.ValueOf(ret)
// If response is a channel
if resp.IsChannel {
if resp.Type == types.ResponseTypeChannel {
// If return value is not a channel, return error
if retVal.Kind() != reflect.Chan {
return ErrReturnNotChannel
@ -137,7 +137,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
// For every value received from channel
for val := range c.chs[chID] {
//s := time.Now()
if val.ChannelDone {
if val.Type == types.ResponseTypeChannelDone {
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
@ -177,7 +177,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
}
}
}()
} else {
} else if resp.Type == types.ResponseTypeNormal {
// IF return value is not a pointer, return error
if retVal.Kind() != reflect.Ptr {
return ErrReturnNotPointer

View File

@ -26,6 +26,9 @@ import (
"github.com/vmihailenco/msgpack/v5"
)
// <= go1.17 compatibility
type any = interface{}
// CodecFunc is a function that returns a new Codec
// bound to the given io.ReadWriter
type CodecFunc func(io.ReadWriter) Codec

View File

@ -26,6 +26,9 @@ import (
"github.com/mitchellh/mapstructure"
)
// <= go1.17 compatibility
type any = interface{}
// Convert attempts to convert the given value to the given type
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
// Get input type

View File

@ -18,6 +18,9 @@
package types
// <= go1.17 compatibility
type any = interface{}
// Request represents a request sent to the server
type Request struct {
ID string
@ -26,12 +29,19 @@ type Request struct {
Arg any
}
type ResponseType uint8
const (
ResponseTypeNormal ResponseType = iota
ResponseTypeError
ResponseTypeChannel
ResponseTypeChannelDone
)
// Response represents a response returned by the server
type Response struct {
ID string
ChannelDone bool
IsChannel bool
IsError bool
Error string
Return any
Type ResponseType
ID string
Error string
Return any
}

View File

@ -36,13 +36,10 @@ import (
type any = interface{}
var (
ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrTooManyInputs = errors.New("method may not have more than two inputs")
ErrTooManyOutputs = errors.New("method may not have more than two return values")
ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidSecondReturn = errors.New("second return value must be error")
ErrInvalidFirstInput = errors.New("first input must be *Context")
ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidMethod = errors.New("method invalid for lrpc call")
)
// Server is an lrpc server
@ -114,22 +111,13 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
return nil, nil, ErrNoSuchMethod
}
// Get method's type
// If method invalid, return error
if !mtdValid(mtd) {
return nil, nil, ErrInvalidMethod
}
// Get method type
mtdType := mtd.Type()
if mtdType.NumIn() > 2 {
return nil, nil, ErrTooManyInputs
} else if mtdType.NumIn() < 1 {
return nil, nil, ErrInvalidFirstInput
}
if mtdType.NumOut() > 2 {
return nil, nil, ErrTooManyOutputs
}
// Check to ensure first parameter is context
if mtdType.In(0) != reflect.TypeOf(&Context{}) {
return nil, nil, ErrInvalidFirstInput
}
//TODO: if arg not nil but fn has no arg, err
@ -215,7 +203,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// If second return value is not an error, the function is invalid
if !ok {
a, err = nil, ErrInvalidSecondReturn
a, err = nil, ErrInvalidMethod
}
}
@ -231,7 +219,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// If second return value is not an error, the function is invalid
err, ok = out1.(error)
if !ok {
a, err = nil, ErrInvalidSecondReturn
a, err = nil, ErrInvalidMethod
}
}
@ -312,7 +300,7 @@ func (s *Server) handleConn(c codec.Codec) {
// If function has created a channel
if ctx.isChannel {
// Set IsChannel to true
res.IsChannel = true
res.Type = types.ResponseTypeChannel
// Overwrite return value with channel ID
res.Return = ctx.channelID
@ -342,8 +330,8 @@ func (s *Server) handleConn(c codec.Codec) {
codecMtx.Lock()
c.Encode(types.Response{
ID: ctx.channelID,
ChannelDone: true,
Type: types.ResponseTypeChannelDone,
ID: ctx.channelID,
})
codecMtx.Unlock()
}()
@ -361,10 +349,10 @@ func (s *Server) handleConn(c codec.Codec) {
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
// Encode error response using codec
c.Encode(types.Response{
ID: req.ID,
IsError: true,
Error: err.Error(),
Return: val,
Type: types.ResponseTypeError,
ID: req.ID,
Error: err.Error(),
Return: val,
})
}
@ -388,3 +376,108 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
delete(l.srv.contexts, id)
l.srv.contextsMtx.Unlock()
}
// MethodDesc describes methods on a receiver
type MethodDesc struct {
Name string
Args []string
Returns []string
}
// Introspect returns method descriptions for the given receiver
func (l lrpc) Introspect(_ *Context, name string) ([]MethodDesc, error) {
// Attempt to get receiver
rcvr, ok := l.srv.rcvrs[name]
if !ok {
return nil, ErrNoSuchReceiver
}
// Get receiver type value
rcvrType := rcvr.Type()
// Create slice for output
var out []MethodDesc
// For every method on receiver
for i := 0; i < rcvr.NumMethod(); i++ {
// Get receiver method
mtd := rcvr.Method(i)
// If invalid, skip
if !mtdValid(mtd) {
continue
}
// Get method type
mtdType := mtd.Type()
// Get amount of arguments
numIn := mtdType.NumIn()
args := make([]string, numIn-1)
// For every argument, store type in slice
// Skip first argument, as it is *Context
for i := 1; i < numIn; i++ {
args[i-1] = mtdType.In(i).String()
}
// Get amount of returns
numOut := mtdType.NumOut()
returns := make([]string, numOut)
// For every return, store type in slice
for i := 0; i < numOut; i++ {
returns[i] = mtdType.Out(i).String()
}
out = append(out, MethodDesc{
Name: rcvrType.Method(i).Name,
Args: args,
Returns: returns,
})
}
return out, nil
}
// IntrospectAll runs Introspect on all registered receivers and returns all results
func (l lrpc) IntrospectAll(_ *Context) (map[string][]MethodDesc, error) {
// Create map for output
out := make(map[string][]MethodDesc, len(l.srv.rcvrs))
// For every registered receiver
for name := range l.srv.rcvrs {
// Introspect receiver
descs, err := l.Introspect(nil, name)
if err != nil {
return nil, err
}
// Set results in map
out[name] = descs
}
return out, nil
}
func mtdValid(mtd reflect.Value) bool {
// Get method's type
mtdType := mtd.Type()
// If method has more than 2 or less than 1 input, it is invalid
if mtdType.NumIn() > 2 || mtdType.NumIn() < 1 {
return false
}
// If method has more than 2 outputs, it is invalid
if mtdType.NumOut() > 2 {
return false
}
// Check to ensure first parameter is context
if mtdType.In(0) != reflect.TypeOf((*Context)(nil)) {
return false
}
// If method has 2 outputs
if mtdType.NumOut() == 2 {
// Check to ensure the second one is an error
if mtdType.Out(1).Name() != "error" {
return false
}
}
return true
}