This repository has been archived on 2022-08-07. You can view files and clone it, but cannot push or open issues or pull requests.
lrpc/client/main.go

198 lines
4.2 KiB
Go

package client
import (
	"errors"
	"io"
	"net"
	"reflect"
	"sync"

	"github.com/gofrs/uuid"
	"go.arsenm.dev/lrpc/codec"
	"go.arsenm.dev/lrpc/internal/reflectutil"
	"go.arsenm.dev/lrpc/internal/types"
)
// any aliases the empty interface so this package also builds with
// Go 1.17 and earlier, where any is not a predeclared identifier.
type (
	any = interface{}
)
// Client error values
var (
ErrReturnNotChannel = errors.New("function call returns channel but return value is not a channel type")
ErrReturnNotPointer = errors.New("function call returns value but return value is not a pointer")
ErrMismatchedType = errors.New("type of channel does not match type returned by server")
)
// Client is an lrpc client. Requests are written to a single connection and
// responses are routed back to waiting callers by ID.
type Client struct {
	// conn is the underlying connection; Close closes it.
	conn net.Conn

	// codec encodes requests onto and decodes responses from conn.
	codec codec.Codec

	// chMtx serializes access to chs.
	chMtx sync.Mutex

	// chs maps a request or stream ID to the channel that receives the
	// responses tagged with that ID.
	chs map[string]chan *types.Response
}
// New builds a Client on top of conn, using cf to construct the codec for
// that connection, and starts the goroutine that reads responses from it.
func New(conn net.Conn, cf codec.CodecFunc) *Client {
	c := new(Client)
	c.conn = conn
	c.codec = cf(conn)
	c.chs = make(map[string]chan *types.Response)

	go c.handleConn()
	return c
}
// Call invokes method on the server-side receiver rcvr with arg and waits for
// the response.
//
// How the result reaches ret depends on the response:
//   - For an ordinary value, ret must be a non-nil pointer. The value,
//     converted to the pointed-to type when the types differ, is stored
//     through it.
//   - For a channel, ret must be a bidirectional channel. A background
//     goroutine converts each streamed value to the channel's element type
//     and sends it on ret. If ret is found closed, the goroutine asks the
//     server to stop ("lrpc"."ChannelDone") and exits.
//
// A server-side error is returned as an error carrying the server's message.
// If the response has no return value, ret is left untouched and Call
// returns nil.
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
	// Tag the request with a random v4 UUID so handleConn can route the
	// matching response back to this call.
	id, err := uuid.NewV4()
	if err != nil {
		return err
	}
	idStr := id.String()

	// Keep a local reference to the channel. Re-reading c.chs later without
	// the lock would race with other goroutines registering their channels.
	respCh := make(chan *types.Response, 1)
	c.chMtx.Lock()
	c.chs[idStr] = respCh
	c.chMtx.Unlock()

	err = c.codec.Encode(types.Request{
		ID:       idStr,
		Receiver: rcvr,
		Method:   method,
		Arg:      arg,
	})
	if err != nil {
		// The request never went out, so no response will arrive.
		// Unregister the channel instead of leaking it in the map.
		c.chMtx.Lock()
		delete(c.chs, idStr)
		c.chMtx.Unlock()
		return err
	}

	resp := <-respCh

	// Exactly one response is expected per request, so unregister and close
	// the channel now. Both happen under chMtx.
	c.chMtx.Lock()
	close(respCh)
	delete(c.chs, idStr)
	c.chMtx.Unlock()

	if resp.IsError {
		return errors.New(resp.Error)
	}
	if resp.Return == nil {
		return nil
	}

	retVal := reflect.ValueOf(ret)

	if !resp.IsChannel {
		// Ordinary result: store it through the caller's pointer. A nil
		// pointer is rejected here because Elem().Set would panic on it.
		if retVal.Kind() != reflect.Pointer || retVal.IsNil() {
			return ErrReturnNotPointer
		}
		retType := retVal.Type().Elem()
		rVal := reflect.ValueOf(resp.Return)
		if rVal.Type() != retType {
			// Attempt to convert types; fail the call if impossible.
			newVal, err := reflectutil.Convert(rVal, retType)
			if err != nil {
				return err
			}
			rVal = newVal
		}
		retVal.Elem().Set(rVal)
		return nil
	}

	// Streamed result. The goroutine below both sends on ret and TryRecvs
	// from it, so ret must be a bidirectional channel. Otherwise reflect
	// would panic inside that goroutine and take the whole program down.
	if retVal.Kind() != reflect.Chan || retVal.Type().ChanDir() != reflect.BothDir {
		return ErrReturnNotChannel
	}

	// The response's return value is the ID the stream values are tagged with.
	chID, ok := resp.Return.(string)
	if !ok {
		return errors.New("channel response does not contain a string channel ID")
	}

	streamCh := make(chan *types.Response, 5)
	c.chMtx.Lock()
	c.chs[chID] = streamCh
	c.chMtx.Unlock()
	// NOTE(review): stream values that handleConn decodes before this
	// registration find no channel for chID and are dropped.

	chElemType := retVal.Type().Elem()
	go func() {
		for val := range streamCh {
			rVal := reflect.ValueOf(val.Return)
			// A nil value has no type to convert from. Skip it rather than
			// panicking in this goroutine.
			if !rVal.IsValid() {
				continue
			}
			if rVal.Type() != chElemType {
				// Attempt to convert the value; skip it if impossible.
				newVal, err := reflectutil.Convert(rVal, chElemType)
				if err != nil {
					continue
				}
				rVal = newVal
			}

			// TryRecv reports ok == false with a valid zero value only when
			// ret is closed. An open, empty channel yields an invalid Value.
			// A closed ret is the caller's signal to stop the stream.
			// NOTE(review): if ret is buffered and holds an unread value,
			// TryRecv consumes it and that value is lost.
			recvVal, ok := retVal.TryRecv()
			if !ok && recvVal.IsValid() {
				// Best effort: this goroutine has nowhere to report a
				// failure to notify the server.
				_ = c.Call("lrpc", "ChannelDone", idStr, nil)

				c.chMtx.Lock()
				close(streamCh)
				delete(c.chs, chID)
				c.chMtx.Unlock()
				return
			}

			retVal.Send(rVal)
		}
	}()

	return nil
}
// handleConn is the client's response reader. It decodes responses from the
// connection and delivers each one to the channel registered in c.chs under
// the response's ID. Responses with no registered channel are dropped.
//
// The loop ends once the connection reports end-of-stream or has been closed
// (for example via Close). Previously a dead connection made Decode fail on
// every iteration and the goroutine spun forever.
// NOTE(review): calls still waiting for a response are not woken when the
// loop ends.
func (c *Client) handleConn() {
	for {
		resp := &types.Response{}
		if err := c.codec.Decode(resp); err != nil {
			// A closed or exhausted connection will never yield another
			// response, so stop instead of busy-looping on the same error.
			if errors.Is(err, io.EOF) ||
				errors.Is(err, io.ErrUnexpectedEOF) ||
				errors.Is(err, net.ErrClosed) {
				return
			}
			// Any other decode error is treated as a bad message and skipped.
			continue
		}

		// Call writes c.chs from other goroutines, so the lookup must hold
		// chMtx. The lock is also held across the send. Call closes these
		// channels only while holding chMtx, so the send can never hit a
		// channel that was closed after the lookup.
		// A full stream channel therefore blocks this loop with the lock
		// held until its consumer catches up.
		c.chMtx.Lock()
		if ch, ok := c.chs[resp.ID]; ok {
			ch <- resp
		}
		c.chMtx.Unlock()
	}
}
// Close shuts the client down by closing its underlying connection, and
// reports any error from doing so.
func (c *Client) Close() error {
	err := c.conn.Close()
	return err
}