import (
"errors"
"fmt"
"io"
"net/http"
"net/url"
"github.com/gin-gonic/gin"
"github.com/gobwas/ws"
)
// WebsocketProxy upgrades the incoming request to a WebSocket connection,
// dials the upstream WebSocket service and relays raw frames in both
// directions until either side sends a close frame or an I/O error occurs.
//
// The query string of the incoming request is forwarded to the upstream URL.
// A relay that ends with errConnectionClose is treated as a normal shutdown;
// any other error is recorded on the gin context, prefixed with "client: " or
// "server: " depending on which connection produced it.
func WebsocketProxy(c *gin.Context) {
	connClient, _, _, err := ws.UpgradeHTTP(c.Request, c.Writer)
	if err != nil {
		// NOTE(review): depending on the gobwas/ws version the connection may
		// already be hijacked when the handshake is rejected; close it if it
		// was returned so it is not leaked.
		if connClient != nil {
			connClient.Close()
		}
		c.Status(http.StatusInternalServerError)
		return
	}
	defer connClient.Close()

	u, err := url.Parse("ws://websocket.default.svc.cluster.local:8080/ws")
	if err != nil {
		c.Status(http.StatusInternalServerError)
		c.Error(err)
		return
	}
	u.RawQuery = c.Request.URL.RawQuery

	dial := ws.Dialer{Header: ws.HandshakeHeaderHTTP(http.Header{})}
	connServer, br, _, err := dial.Dial(c.Request.Context(), u.String())
	if err != nil {
		// NOTE(review): the client connection has already been upgraded here,
		// so this status presumably never reaches the client; it is kept for
		// whatever middleware inspects the writer afterwards.
		c.Status(http.StatusInternalServerError)
		c.Error(fmt.Errorf("server: %w", err))
		return
	}
	defer connServer.Close()

	// Dial returns a non-nil buffered reader when the upstream sent data
	// together with the handshake response. Those bytes must be consumed before
	// reading from the raw connection, otherwise the first frames are lost.
	var serverReader io.Reader = connServer
	if br != nil {
		serverReader = br
	}

	// Each relay goroutine sends exactly one result, so a capacity of two
	// guarantees neither of them can block after this handler has returned.
	results := make(chan error, 2)
	go func() {
		results <- relayWebsocketFrames(connServer, connClient, "server", "client")
	}()
	go func() {
		results <- relayWebsocketFrames(connClient, serverReader, "client", "server")
	}()

	// The first direction to finish ends the session; the deferred Close calls
	// unblock the other goroutine, whose result is parked in the buffer.
	err = <-results
	if errors.Is(err, errConnectionClose) {
		return
	}
	c.Status(http.StatusInternalServerError)
	c.Error(err)
}

// relayWebsocketFrames forwards WebSocket frames from src to dst without
// decoding them and returns as soon as a close frame has been forwarded or an
// I/O error occurs.
//
// Read failures are wrapped with srcName and write failures with dstName. A
// forwarded close frame yields errConnectionClose wrapped with dstName. The
// returned error is never nil.
func relayWebsocketFrames(dst io.Writer, src io.Reader, dstName, srcName string) error {
	// Payloads are streamed through a fixed-size buffer so that the frame
	// length announced by the peer never drives the size of an allocation.
	const bufSize = 32 * 1024
	buf := make([]byte, bufSize)

	for {
		header, err := ws.ReadHeader(src)
		if err != nil {
			return fmt.Errorf("%s: %w", srcName, err)
		}
		if err := ws.WriteHeader(dst, header); err != nil {
			return fmt.Errorf("%s: %w", dstName, err)
		}

		// Bytes are copied verbatim: a masked payload stays masked with the
		// key carried in the forwarded header.
		for remaining := header.Length; remaining > 0; {
			chunk := buf
			if remaining < int64(len(chunk)) {
				chunk = chunk[:remaining]
			}
			if _, err := io.ReadFull(src, chunk); err != nil {
				return fmt.Errorf("%s: %w", srcName, err)
			}
			if _, err := dst.Write(chunk); err != nil {
				return fmt.Errorf("%s: %w", dstName, err)
			}
			remaining -= int64(len(chunk))
		}

		if header.OpCode == ws.OpClose {
			return fmt.Errorf("%s: %w", dstName, errConnectionClose)
		}
	}
}