diff --git a/client/main.go b/client/main.go index 79504f4..691b7c1 100644 --- a/client/main.go +++ b/client/main.go @@ -1,6 +1,7 @@ package client import ( + "context" "errors" "net" "reflect" @@ -46,7 +47,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client { } // Call calls a method on the server -func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error { +func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error { // Create new v4 UUOD id, err := uuid.NewV4() if err != nil { @@ -54,6 +55,8 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err } idStr := id.String() + ctxDoneVal := reflect.ValueOf(ctx.Done()) + // Create new channel using the generated ID c.chMtx.Lock() c.chs[idStr] = make(chan *types.Response, 1) @@ -106,12 +109,12 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err c.chs[chID] = make(chan *types.Response, 5) c.chMtx.Unlock() - channelClosed := false go func() { // Get type of channel elements chElemType := retVal.Type().Elem() // For every value received from channel for val := range c.chs[chID] { + //s := time.Now() if val.ChannelDone { // Close and delete channel c.chMtx.Lock() @@ -121,9 +124,6 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err // Close return channel retVal.Close() - - channelClosed = true - break } // Get reflect value from channel response @@ -139,26 +139,21 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err rVal = newVal } - // Send value to channel - retVal.Send(rVal) - } - }() + chosen, _, _ := reflect.Select([]reflect.SelectCase{ + {Dir: reflect.SelectSend, Chan: retVal, Send: rVal}, + {Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}}, + }) + if chosen == 1 { + c.Call(context.Background(), "lrpc", "ChannelDone", id, nil) + // Close and delete channel + c.chMtx.Lock() + close(c.chs[chID]) + delete(c.chs, chID) + c.chMtx.Unlock() - go func() { - for { - val, ok := retVal.Recv() - if !ok && val.IsValid() { - break + retVal.Close() } } - if !channelClosed { - c.Call("lrpc", "ChannelDone", id, nil) - // Close and delete channel - c.chMtx.Lock() - close(c.chs[chID]) - delete(c.chs, chID) - c.chMtx.Unlock() - } }() } else { // IF return value is not a pointer, return error diff --git a/examples/client/main.go b/examples/client/main.go index 0c34f11..3f5fa98 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/gob" "fmt" "net" @@ -12,21 +13,23 @@ import ( func main() { gob.Register([2]int{}) + ctx := context.Background() + conn, _ := net.Dial("tcp", "localhost:9090") c := client.New(conn, codec.Gob) defer c.Close() var add int - c.Call("Arith", "Add", [2]int{5, 5}, &add) + c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add) var sub int - c.Call("Arith", "Sub", [2]int{5, 5}, &sub) + c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub) var mul int - c.Call("Arith", "Mul", [2]int{5, 5}, &mul) + c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul) var div int - c.Call("Arith", "Div", [2]int{5, 5}, &div) + c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div) fmt.Printf( "add: %d, sub: %d, mul: %d, div: %d\n",