diff --git a/api/api.go b/api/api.go index 59c4ce3..b01f9b8 100644 --- a/api/api.go +++ b/api/api.go @@ -5,13 +5,14 @@ import ( "net" "go.arsenm.dev/itd/internal/rpc" + "storj.io/drpc" "storj.io/drpc/drpcconn" ) const DefaultAddr = "/tmp/itd/socket" type Client struct { - conn *drpcconn.Conn + conn drpc.Conn client rpc.DRPCITDClient } @@ -20,11 +21,15 @@ func New(sockPath string) (*Client, error) { if err != nil { return nil, err } - dconn := drpcconn.New(conn) + + mconn, err := newMuxConn(conn) + if err != nil { + return nil, err + } return &Client{ - conn: dconn, - client: rpc.NewDRPCITDClient(dconn), + conn: mconn, + client: rpc.NewDRPCITDClient(mconn), }, nil } diff --git a/api/drpc.go b/api/drpc.go new file mode 100644 index 0000000..d5db5bf --- /dev/null +++ b/api/drpc.go @@ -0,0 +1,65 @@ +package api + +import ( + "context" + "io" + + "github.com/hashicorp/yamux" + "storj.io/drpc" + "storj.io/drpc/drpcconn" +) + +var _ drpc.Conn = &muxConn{} + +type muxConn struct { + conn io.ReadWriteCloser + sess *yamux.Session + closed chan struct{} +} + +func newMuxConn(conn io.ReadWriteCloser) (*muxConn, error) { + sess, err := yamux.Client(conn, nil) + if err != nil { + return nil, err + } + + return &muxConn{ + conn: conn, + sess: sess, + closed: make(chan struct{}), + }, nil +} + +func (m *muxConn) Close() error { + defer close(m.closed) + + err := m.sess.Close() + if err != nil { + return err + } + return m.conn.Close() +} + +func (m *muxConn) Closed() <-chan struct{} { + return m.closed +} + +func (m *muxConn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { + conn, err := m.sess.Open() + if err != nil { + return err + } + defer conn.Close() + dconn := drpcconn.New(conn) + return dconn.Invoke(ctx, rpc, enc, in, out) +} + +func (m *muxConn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { + conn, err := m.sess.Open() + if err != nil { + return nil, err + } + + dconn := drpcconn.New(conn) + return dconn.NewStream(ctx, rpc, enc) +} diff --git a/go.mod b/go.mod index f153556..d3f109a 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/cheggaaa/pb/v3 v3.1.0 github.com/gen2brain/dlgs v0.0.0-20220603100644-40c77870fa8d github.com/godbus/dbus/v5 v5.1.0 + github.com/hashicorp/yamux v0.1.1 github.com/knadh/koanf v1.4.4 github.com/mattn/go-isatty v0.0.17 github.com/mozillazg/go-pinyin v0.19.0 diff --git a/go.sum b/go.sum index 781eae2..1592110 100644 --- a/go.sum +++ b/go.sum @@ -297,6 +297,8 @@ github.com/hashicorp/vault/api v1.0.4/go.mod h1:gDcqh3WGcR1cpF5AJz/B1UFheUEneMoI github.com/hashicorp/vault/sdk v0.1.13/go.mod h1:B+hVj7TpuQY1Y/GPbCpffmgd+tSEwvhkWnjtSYCaS2M= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= +github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= +github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hjson/hjson-go/v4 v4.0.0 h1:wlm6IYYqHjOdXH1gHev4VoXCaW20HdQAGCxdOEEg2cs= github.com/hjson/hjson-go/v4 v4.0.0/go.mod h1:KaYt3bTw3zhBjYqnXkYywcYctk0A2nxeEFTse3rH13E= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/socket.go b/socket.go index 0a355f3..0e8921c 100644 --- a/socket.go +++ b/socket.go @@ -27,6 +27,7 @@ import ( "path/filepath" "time" + "github.com/hashicorp/yamux" "github.com/rs/zerolog/log" "go.arsenm.dev/infinitime" "go.arsenm.dev/infinitime/blefs" @@ -79,7 +80,30 @@ func startSocket(ctx context.Context, dev *infinitime.Device) error { return err } - go drpcserver.New(mux).Serve(ctx, ln) + srv := drpcserver.New(mux) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal().Err(err).Msg("Error accepting connection") + } + + sess, err := yamux.Server(conn, nil) + if err != nil { + log.Fatal().Err(err).Msg("Error creating multiplexed session") + } + + for { + conn, err := sess.Accept() + if err != nil { + log.Fatal().Err(err).Msg("Error accepting stream") + } + + go srv.ServeOne(ctx, conn) + } + } + }() // Log socket start log.Info().Str("path", k.String("socket.path")).Msg("Started control socket")