Compare commits

...

2 Commits

2 changed files with 24 additions and 26 deletions

View File

@ -1,6 +1,7 @@
package client
import (
"context"
"errors"
"net"
"reflect"
@ -46,7 +47,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client {
}
// Call calls a method on the server
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error {
// Create new v4 UUOD
id, err := uuid.NewV4()
if err != nil {
@ -54,6 +55,8 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
}
idStr := id.String()
ctxDoneVal := reflect.ValueOf(ctx.Done())
// Create new channel using the generated ID
c.chMtx.Lock()
c.chs[idStr] = make(chan *types.Response, 1)
@ -106,12 +109,12 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
c.chs[chID] = make(chan *types.Response, 5)
c.chMtx.Unlock()
channelClosed := false
go func() {
// Get type of channel elements
chElemType := retVal.Type().Elem()
// For every value received from channel
for val := range c.chs[chID] {
//s := time.Now()
if val.ChannelDone {
// Close and delete channel
c.chMtx.Lock()
@ -121,9 +124,6 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
// Close return channel
retVal.Close()
channelClosed = true
break
}
// Get reflect value from channel response
@ -139,26 +139,21 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
rVal = newVal
}
// Send value to channel
retVal.Send(rVal)
}
}()
chosen, _, _ := reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
})
if chosen == 1 {
c.Call(context.Background(), "lrpc", "ChannelDone", chID, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
go func() {
for {
val, ok := retVal.Recv()
if !ok && val.IsValid() {
break
retVal.Close()
}
}
if !channelClosed {
c.Call("lrpc", "ChannelDone", id, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
}
}()
} else {
// IF return value is not a pointer, return error

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/gob"
"fmt"
"net"
@ -12,21 +13,23 @@ import (
func main() {
gob.Register([2]int{})
ctx := context.Background()
conn, _ := net.Dial("tcp", "localhost:9090")
c := client.New(conn, codec.Gob)
defer c.Close()
var add int
c.Call("Arith", "Add", [2]int{5, 5}, &add)
c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
var sub int
c.Call("Arith", "Sub", [2]int{5, 5}, &sub)
c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
var mul int
c.Call("Arith", "Mul", [2]int{5, 5}, &mul)
c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
var div int
c.Call("Arith", "Div", [2]int{5, 5}, &div)
c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
fmt.Printf(
"add: %d, sub: %d, mul: %d, div: %d\n",