Resolve data races using mutex

This commit is contained in:
Elara 2022-05-04 16:15:35 -07:00
parent 7ef9e56505
commit b1e7ded874
2 changed files with 17 additions and 2 deletions

View File

@ -47,7 +47,7 @@ type Client struct {
conn io.ReadWriteCloser conn io.ReadWriteCloser
codec codec.Codec codec codec.Codec
chMtx sync.Mutex chMtx *sync.Mutex
chs map[string]chan *types.Response chs map[string]chan *types.Response
} }
@ -57,6 +57,7 @@ func New(conn io.ReadWriteCloser, cf codec.CodecFunc) *Client {
conn: conn, conn: conn,
codec: cf(conn), codec: cf(conn),
chs: map[string]chan *types.Response{}, chs: map[string]chan *types.Response{},
chMtx: &sync.Mutex{},
} }
go out.handleConn() go out.handleConn()
@ -92,7 +93,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
} }
// Get response from channel // Get response from channel
resp := <-c.chs[idStr] c.chMtx.Lock()
respCh := c.chs[idStr]
c.chMtx.Unlock()
resp := <-respCh
// Close and delete channel // Close and delete channel
c.chMtx.Lock() c.chMtx.Lock()
@ -210,11 +214,14 @@ func (c *Client) handleConn() {
continue continue
} }
c.chMtx.Lock()
// Get channel from map, skip if it doesn't exist // Get channel from map, skip if it doesn't exist
ch, ok := c.chs[resp.ID] ch, ok := c.chs[resp.ID]
if !ok { if !ok {
c.chMtx.Unlock()
continue continue
} }
c.chMtx.Unlock()
// Send response to channel // Send response to channel
ch <- resp ch <- resp

View File

@ -280,6 +280,8 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
// handleConn handles a listener connection // handleConn handles a listener connection
func (s *Server) handleConn(c codec.Codec) { func (s *Server) handleConn(c codec.Codec) {
codecMtx := &sync.Mutex{}
for { for {
var call types.Request var call types.Request
// Read request using codec // Read request using codec
@ -322,11 +324,13 @@ func (s *Server) handleConn(c codec.Codec) {
go func() { go func() {
// For every value received from channel // For every value received from channel
for val := range ctx.channel { for val := range ctx.channel {
codecMtx.Lock()
// Encode response using codec // Encode response using codec
c.Encode(types.Response{ c.Encode(types.Response{
ID: ctx.channelID, ID: ctx.channelID,
Return: val, Return: val,
}) })
codecMtx.Unlock()
} }
// Cancel context // Cancel context
@ -336,15 +340,19 @@ func (s *Server) handleConn(c codec.Codec) {
delete(s.contexts, ctx.channelID) delete(s.contexts, ctx.channelID)
s.contextsMtx.Unlock() s.contextsMtx.Unlock()
codecMtx.Lock()
c.Encode(types.Response{ c.Encode(types.Response{
ID: ctx.channelID, ID: ctx.channelID,
ChannelDone: true, ChannelDone: true,
}) })
codecMtx.Unlock()
}() }()
} }
// Encode response using codec // Encode response using codec
codecMtx.Lock()
c.Encode(res) c.Encode(res)
codecMtx.Unlock()
} }
} }
} }