From f609d5a97fd6392b0ada76bf35c2b30e250e59ef Mon Sep 17 00:00:00 2001 From: Arsen Musayelyan Date: Sat, 7 May 2022 14:59:04 -0700 Subject: [PATCH] Add introspection functions --- client/client.go | 2 +- codec/codec.go | 3 + internal/reflectutil/utils.go | 3 + internal/types/types.go | 14 ++-- server/server.go | 141 ++++++++++++++++++++++++++++------ 5 files changed, 133 insertions(+), 30 deletions(-) diff --git a/client/client.go b/client/client.go index f3fbc59..3b660d5 100644 --- a/client/client.go +++ b/client/client.go @@ -177,7 +177,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, } } }() - } else { + } else if resp.Type == types.ResponseTypeNormal { // IF return value is not a pointer, return error if retVal.Kind() != reflect.Ptr { return ErrReturnNotPointer diff --git a/codec/codec.go b/codec/codec.go index de1096e..edb0fe0 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -26,6 +26,9 @@ import ( "github.com/vmihailenco/msgpack/v5" ) +// <= go1.17 compatibility +type any = interface{} + // CodecFunc is a function that returns a new Codec // bound to the given io.ReadWriter type CodecFunc func(io.ReadWriter) Codec diff --git a/internal/reflectutil/utils.go b/internal/reflectutil/utils.go index 6a4f94b..0c7a6c1 100644 --- a/internal/reflectutil/utils.go +++ b/internal/reflectutil/utils.go @@ -26,6 +26,9 @@ import ( "github.com/mitchellh/mapstructure" ) +// <= go1.17 compatibility +type any = interface{} + // Convert attempts to convert the given value to the given type func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) { // Get input type diff --git a/internal/types/types.go b/internal/types/types.go index 5632484..faed6bb 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -18,6 +18,9 @@ package types +// <= go1.17 compatibility +type any = interface{} + // Request represents a request sent to the server type Request struct { ID string @@ -29,15 +32,16 @@ type Request struct { type ResponseType uint8 const ( - ResponseTypeError ResponseType = iota + ResponseTypeNormal ResponseType = iota + ResponseTypeError ResponseTypeChannel ResponseTypeChannelDone ) // Response represents a response returned by the server type Response struct { - Type ResponseType - ID string - Error string - Return any + Type ResponseType + ID string + Error string + Return any } diff --git a/server/server.go b/server/server.go index 4243c1a..169bc9c 100644 --- a/server/server.go +++ b/server/server.go @@ -36,13 +36,10 @@ import ( type any = interface{} var ( - ErrInvalidType = errors.New("type must be struct or pointer to struct") - ErrTooManyInputs = errors.New("method may not have more than two inputs") - ErrTooManyOutputs = errors.New("method may not have more than two return values") - ErrNoSuchReceiver = errors.New("no such receiver registered") - ErrNoSuchMethod = errors.New("no such method was found") - ErrInvalidSecondReturn = errors.New("second return value must be error") - ErrInvalidFirstInput = errors.New("first input must be *Context") + ErrInvalidType = errors.New("type must be struct or pointer to struct") + ErrNoSuchReceiver = errors.New("no such receiver registered") + ErrNoSuchMethod = errors.New("no such method was found") + ErrInvalidMethod = errors.New("method invalid for lrpc call") ) // Server is an lrpc server @@ -114,22 +111,13 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any return nil, nil, ErrNoSuchMethod } - // Get method's type + // If method invalid, return error + if !mtdValid(mtd) { + return nil, nil, ErrInvalidMethod + } + + // Get method type mtdType := mtd.Type() - if mtdType.NumIn() > 2 { - return nil, nil, ErrTooManyInputs - } else if mtdType.NumIn() < 1 { - return nil, nil, ErrInvalidFirstInput - } - - if mtdType.NumOut() > 2 { - return nil, nil, ErrTooManyOutputs - } - - // Check to ensure first parameter is context - if mtdType.In(0) != reflect.TypeOf(&Context{}) { - return nil, nil, ErrInvalidFirstInput - } //TODO: if arg not nil but fn has no arg, err @@ -215,7 +203,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any // If second return value is not an error, the function is invalid if !ok { - a, err = nil, ErrInvalidSecondReturn + a, err = nil, ErrInvalidMethod } } @@ -231,7 +219,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any // If second return value is not an error, the function is invalid err, ok = out1.(error) if !ok { - a, err = nil, ErrInvalidSecondReturn + a, err = nil, ErrInvalidMethod } } @@ -388,3 +376,108 @@ func (l lrpc) ChannelDone(_ *Context, id string) { delete(l.srv.contexts, id) l.srv.contextsMtx.Unlock() } + +// MethodDesc describes methods on a receiver +type MethodDesc struct { + Name string + Args []string + Returns []string +} + +// Introspect returns method descriptions for the given receiver +func (l lrpc) Introspect(_ *Context, name string) ([]MethodDesc, error) { + // Attempt to get receiver + rcvr, ok := l.srv.rcvrs[name] + if !ok { + return nil, ErrNoSuchReceiver + } + // Get receiver type value + rcvrType := rcvr.Type() + + // Create slice for output + var out []MethodDesc + // For every method on receiver + for i := 0; i < rcvr.NumMethod(); i++ { + // Get receiver method + mtd := rcvr.Method(i) + // If invalid, skip + if !mtdValid(mtd) { + continue + } + // Get method type + mtdType := mtd.Type() + + // Get amount of arguments + numIn := mtdType.NumIn() + args := make([]string, numIn-1) + // For every argument, store type in slice + // Skip first argument, as it is *Context + for i := 1; i < numIn; i++ { + args[i-1] = mtdType.In(i).String() + } + + // Get amount of returns + numOut := mtdType.NumOut() + returns := make([]string, numOut) + // For every return, store type in slice + for i := 0; i < numOut; i++ { + returns[i] = mtdType.Out(i).String() + } + + out = append(out, MethodDesc{ + Name: rcvrType.Method(i).Name, + Args: args, + Returns: returns, + }) + } + + return out, nil +} + +// IntrospectAll runs Introspect on all registered receivers and returns all results +func (l lrpc) IntrospectAll(_ *Context) (map[string][]MethodDesc, error) { + // Create map for output + out := make(map[string][]MethodDesc, len(l.srv.rcvrs)) + // For every registered receiver + for name := range l.srv.rcvrs { + // Introspect receiver + descs, err := l.Introspect(nil, name) + if err != nil { + return nil, err + } + + // Set results in map + out[name] = descs + } + return out, nil +} + +func mtdValid(mtd reflect.Value) bool { + // Get method's type + mtdType := mtd.Type() + + // If method has more than 2 or less than 1 input, it is invalid + if mtdType.NumIn() > 2 || mtdType.NumIn() < 1 { + return false + } + + // If method has more than 2 outputs, it is invalid + if mtdType.NumOut() > 2 { + return false + } + + // Check to ensure first parameter is context + if mtdType.In(0) != reflect.TypeOf((*Context)(nil)) { + return false + } + + // If method has 2 outputs + if mtdType.NumOut() == 2 { + // Check to ensure the second one is an error + if mtdType.Out(1).Name() != "error" { + return false + } + } + + return true +}