diff --git a/client/client.go b/client/client.go index 389fc39..f2a22c6 100644 --- a/client/client.go +++ b/client/client.go @@ -26,7 +26,6 @@ import ( "sync" "go.arsenm.dev/lrpc/codec" - "go.arsenm.dev/lrpc/internal/reflectutil" "go.arsenm.dev/lrpc/internal/types" "github.com/gofrs/uuid" @@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, c.chs[idStr] = make(chan *types.Response, 1) c.chMtx.Unlock() + argData, err := c.codec.Marshal(arg) + if err != nil { + return err + } + // Encode request using codec err = c.codec.Encode(types.Request{ ID: idStr, Receiver: rcvr, Method: method, - Arg: arg, + Arg: argData, }) if err != nil { return err @@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, return ErrReturnNotChannel } // Get channel ID returned in response - chID := resp.Return.(string) + var chID string + err = c.codec.Unmarshal(resp.Return, &chID) + if resp.Return == nil { + return nil + } // Create new channel using channel ID c.chMtx.Lock() @@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, retVal.Close() break } - // Get reflect value from channel response - rVal := reflect.ValueOf(val.Return) - // If return value is not the same as the channel - if rVal.Type() != chElemType { - // Attempt to convert value, skip if impossible - newVal, err := reflectutil.Convert(rVal, chElemType) - if err != nil { - continue - } - rVal = newVal + outVal := reflect.New(chElemType) + err = c.codec.Unmarshal(val.Return, outVal.Interface()) + if err != nil { + continue } + outVal = outVal.Elem() chosen, _, _ := reflect.Select([]reflect.SelectCase{ - {Dir: reflect.SelectSend, Chan: retVal, Send: rVal}, + {Dir: reflect.SelectSend, Chan: retVal, Send: outVal}, {Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}}, }) if chosen == 1 { @@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, } }() } else if resp.Type == types.ResponseTypeNormal { - // IF return value is not a pointer, return error - if retVal.Kind() != reflect.Ptr { - return ErrReturnNotPointer + err = c.codec.Unmarshal(resp.Return, ret) + if err != nil { + return err } - - // Get return type - retType := retVal.Type().Elem() - // Get refkect value from response - rVal := reflect.ValueOf(resp.Return) - - // If types do not match - if rVal.Type() != retType { - // Attempt to convert types, return error if not possible - newVal, err := reflectutil.Convert(rVal, retType) - if err != nil { - return err - } - rVal = newVal - } - - // Set return value to received value - retVal.Elem().Set(rVal) } return nil diff --git a/codec/codec.go b/codec/codec.go index edb0fe0..067be1b 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -19,6 +19,7 @@ package codec import ( + "bytes" "encoding/gob" "encoding/json" "io" @@ -38,42 +39,76 @@ type CodecFunc func(io.ReadWriter) Codec type Codec interface { Encode(val any) error Decode(val any) error + Unmarshal(data []byte, v any) error + Marshal(v any) ([]byte, error) } // Default is the default CodecFunc var Default = Msgpack +type JsonCodec struct { + *json.Encoder + *json.Decoder +} + +func (JsonCodec) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (JsonCodec) Marshal(v any) ([]byte, error) { + return json.Marshal(v) +} + // JSON is a CodecFunc that creates a JSON Codec func JSON(rw io.ReadWriter) Codec { - type jsonCodec struct { - *json.Encoder - *json.Decoder - } - return jsonCodec{ + return JsonCodec{ Encoder: json.NewEncoder(rw), Decoder: json.NewDecoder(rw), } } +type MsgpackCodec struct { + *msgpack.Encoder + *msgpack.Decoder +} + +func (MsgpackCodec) Unmarshal(data []byte, v any) error { + return msgpack.Unmarshal(data, v) +} + +func (MsgpackCodec) Marshal(v any) ([]byte, error) { + return msgpack.Marshal(v) +} + // Msgpack is a CodecFunc that creates a Msgpack Codec func Msgpack(rw io.ReadWriter) Codec { - type msgpackCodec struct { - *msgpack.Encoder - *msgpack.Decoder - } - return msgpackCodec{ + return MsgpackCodec{ Encoder: msgpack.NewEncoder(rw), Decoder: msgpack.NewDecoder(rw), } } +type GobCodec struct { + *gob.Encoder + *gob.Decoder +} + +func (GobCodec) Unmarshal(data []byte, v any) error { + return gob.NewDecoder(bytes.NewReader(data)).Decode(v) +} + +func (GobCodec) Marshal(v any) ([]byte, error) { + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode(v) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + // Gob is a CodecFunc that creates a Gob Codec func Gob(rw io.ReadWriter) Codec { - type gobCodec struct { - *gob.Encoder - *gob.Decoder - } - return gobCodec{ + return GobCodec{ Encoder: gob.NewEncoder(rw), Decoder: gob.NewDecoder(rw), } diff --git a/internal/types/types.go b/internal/types/types.go index faed6bb..7a981e9 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -26,7 +26,7 @@ type Request struct { ID string Receiver string Method string - Arg any + Arg []byte } type ResponseType uint8 @@ -43,5 +43,5 @@ type Response struct { Type ResponseType ID string Error string - Return any + Return []byte } diff --git a/lrpc_test.go b/lrpc_test.go index 72cb76a..586a048 100644 --- a/lrpc_test.go +++ b/lrpc_test.go @@ -122,7 +122,7 @@ func TestCodecs(t *testing.T) { if err != nil { t.Errorf("codec/%s: %v", name, err) } - + if add != 4 { t.Errorf("codec/%s: add: expected 4, got %d", name, add) } diff --git a/server/server.go b/server/server.go index e5f0cc2..f630ce7 100644 --- a/server/server.go +++ b/server/server.go @@ -28,7 +28,6 @@ import ( "sync" "go.arsenm.dev/lrpc/codec" - "go.arsenm.dev/lrpc/internal/reflectutil" "go.arsenm.dev/lrpc/internal/types" "golang.org/x/net/websocket" ) @@ -37,12 +36,11 @@ import ( type any = interface{} var ( - ErrInvalidType = errors.New("type must be struct or pointer to struct") - ErrNoSuchReceiver = errors.New("no such receiver registered") - ErrNoSuchMethod = errors.New("no such method was found") - ErrInvalidMethod = errors.New("method invalid for lrpc call") - ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments") - ErrArgNotProvided = errors.New("method expected an argument, but none was provided") + ErrInvalidType = errors.New("type must be struct or pointer to struct") + ErrNoSuchReceiver = errors.New("no such receiver registered") + ErrNoSuchMethod = errors.New("no such method was found") + ErrInvalidMethod = errors.New("method invalid for lrpc call") + ErrArgNotProvided = errors.New("method expected an argument, but none was provided") ) // Server is an lrpc server @@ -101,7 +99,7 @@ func (s *Server) Register(v any) error { } // execute runs a method of a registered value -func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) { +func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) { // Try to get value from receivers map val, ok := s.rcvrs[typ] if !ok { @@ -122,29 +120,19 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, // Get method type mtdType := mtd.Type() - // Return error if argument provided but isn't expected - if mtdType.NumIn() == 1 && arg != nil { - return nil, nil, ErrUnexpectedArgument - } + //TODO: if arg not nil but fn has no arg, err - // IF argument is []any - anySlice, ok := arg.([]any) - if ok { - // Convert slice to the method's arg type and - // set arg to the newly-converted slice - arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1)) - } + argType := mtdType.In(1) + argVal := reflect.New(argType) + arg := argVal.Interface() - // Get argument value - argVal := reflect.ValueOf(arg) - // If argument's type does not match method's argument type - if arg != nil && argVal.Type() != mtdType.In(1) { - val, err = reflectutil.Convert(argVal, mtdType.In(1)) - if err != nil { - return nil, nil, err - } - arg = val.Interface() + err = c.Unmarshal(data, arg) + if err != nil { + return nil, nil, err } + + arg = argVal.Elem().Interface() + ctx = newContext(pCtx, c) // Get reflect value of context @@ -327,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) { if err != nil { s.sendErr(c, call, val, err) } else { + valData, err := c.Marshal(val) + if err != nil { + s.sendErr(c, call, val, err) + continue + } + // Create response res := types.Response{ ID: call.ID, - Return: val, + Return: valData, } // If function has created a channel if ctx.isChannel { + idData, err := c.Marshal(ctx.channelID) + if err != nil { + s.sendErr(c, call, val, err) + continue + } + // Set IsChannel to true res.Type = types.ResponseTypeChannel // Overwrite return value with channel ID - res.Return = ctx.channelID + res.Return = idData // Store context in map for future use s.contextsMtx.Lock() @@ -349,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) { // For every value received from channel for val := range ctx.channel { codecMtx.Lock() + + valData, err := c.Marshal(val) + if err != nil { + continue + } + // Encode response using codec c.Encode(types.Response{ ID: ctx.channelID, - Return: val, + Return: valData, }) + codecMtx.Unlock() } @@ -383,12 +390,14 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) { // sendErr sends an error response func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) { + valData, _ := c.Marshal(val) + // Encode error response using codec c.Encode(types.Response{ Type: types.ResponseTypeError, ID: req.ID, Error: err.Error(), - Return: val, + Return: valData, }) }