diff --git a/client/main.go b/client/main.go index 698b4c2..f6b7b19 100644 --- a/client/main.go +++ b/client/main.go @@ -111,6 +111,18 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err chElemType := retVal.Type().Elem() // For every value received from channel for val := range c.chs[chID] { + if val.ChannelDone { + // Close and delete channel + c.chMtx.Lock() + close(c.chs[chID]) + delete(c.chs, chID) + c.chMtx.Unlock() + + // Close return channel + retVal.Close() + + break + } // Get reflect value from channel response rVal := reflect.ValueOf(val.Return) @@ -124,25 +136,25 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err rVal = newVal } - // Try to read from the channel - recvVal, ok := retVal.TryRecv() - // IF the channel cannot be read but the value is valid, - // the channel must be closed - if !ok && recvVal.IsValid() { - // Send done signal - c.Call("lrpc", "ChannelDone", idStr, nil) - // Close and delete channel - c.chMtx.Lock() - close(c.chs[chID]) - delete(c.chs, chID) - c.chMtx.Unlock() - break - } - // Send value to channel retVal.Send(rVal) } }() + + go func() { + for { + val, ok := retVal.Recv() + if !ok && val.IsValid() { + break + } + } + c.Call("lrpc", "ChannelDone", id, nil) + // Close and delete channel + c.chMtx.Lock() + close(c.chs[chID]) + delete(c.chs, chID) + c.chMtx.Unlock() + }() } else { // IF return value is not a pointer, return error if retVal.Kind() != reflect.Ptr { diff --git a/internal/types/types.go b/internal/types/types.go index 6ce5e51..ae09981 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -11,6 +11,7 @@ type Request struct { // Response represents a response returned by the server type Response struct { ID string + ChannelDone bool IsChannel bool IsError bool Error string diff --git a/server/server.go b/server/server.go index af37c14..ce47106 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,7 @@ package server import ( "errors" + "io" "net" "reflect" "sync" @@ -243,7 +244,9 @@ func (s *Server) handleConn(c codec.Codec) { var call types.Request // Read request using codec err := c.Decode(&call) - if err != nil { + if err == io.EOF { + break + } else if err != nil { s.sendErr(c, call, nil, err) continue } @@ -285,6 +288,18 @@ func (s *Server) handleConn(c codec.Codec) { Return: val, }) } + + // Cancel context + ctx.Cancel() + // Delete context from map + s.contextsMtx.Lock() + delete(s.contexts, ctx.channelID) + s.contextsMtx.Unlock() + + c.Encode(types.Response{ + ID: ctx.channelID, + ChannelDone: true, + }) }() } @@ -321,5 +336,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) { // Cancel context ctx.Cancel() // Delete context from map + l.srv.contextsMtx.Lock() delete(l.srv.contexts, id) + l.srv.contextsMtx.Unlock() }