diff --git a/api/helper.go b/api/helper.go new file mode 100644 index 0000000..89b4451 --- /dev/null +++ b/api/helper.go @@ -0,0 +1,19 @@ +package api + +import ( + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type Session struct { + id string + aConn *websocket.Conn + bConn *websocket.Conn + lastInteractionTime time.Time + lastInteractedPartyIsA bool +} + +var openSessions = make(map[string]*Session) +var sessionsMu sync.RWMutex diff --git a/api/http_handler.go b/api/http_handler.go index e9a688a..1a2dffe 100644 --- a/api/http_handler.go +++ b/api/http_handler.go @@ -4,6 +4,9 @@ import ( "fmt" "log" "net/http" + "time" + + "echo/misc" "github.com/gorilla/mux" "github.com/gorilla/websocket" @@ -22,29 +25,62 @@ func serveStaticFile(fileName string) http.HandlerFunc { } func createSession(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, ` - { - "sessionId":"XyZ123" + + sessionsMu.RLock() + var id string + for { + id, _ = misc.RandomString(10) + + _, ok := openSessions[id] + if !ok { + break // make sure we don't have duplicated ids + } + } - `) + sessionsMu.RUnlock() + + session := Session{ + id: id, + lastInteractionTime: time.Now(), + lastInteractedPartyIsA: true, + } + + sessionsMu.Lock() + openSessions[session.id] = &session + sessionsMu.Unlock() + + fmt.Fprintf(w, + `{ + "sessionId":%s +}`, session.id) } var upgrader = websocket.Upgrader{} func signalWS(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - party := vars["party"] - session := vars["session"] + partyStr := vars["party"] + sessionId := vars["session"] - fmt.Printf("Initiating a websocket, party=%s session=%s\n", party, session) + fmt.Printf("Initiating a websocket, party=%s session=%s\n", partyStr, sessionId) + partyIsA := true + switch partyStr { + case "A": + + case "B": + partyIsA = false + default: + fmt.Printf("Party is invalid partyStr=%s", partyStr) + return + } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - fmt.Printf("Error initiating websocket, r.URL=%s e=%s", r.URL.String(), err) + fmt.Printf("Error initiating websocket, r.URL=%s party=%v e=%s", r.URL.String(), partyIsA, err) } - websocketSignaler(conn) + websocketSignaler(conn, partyIsA, sessionId) } diff --git a/api/ws_handler.go b/api/ws_handler.go index e81645d..203b182 100644 --- a/api/ws_handler.go +++ b/api/ws_handler.go @@ -1,21 +1,76 @@ package api -import "github.com/gorilla/websocket" +import ( + "fmt" + "time" -func websocketSignaler(conn *websocket.Conn) { + "github.com/gorilla/websocket" +) + +func websocketSignaler(conn *websocket.Conn, partyIsA bool, sessionId string) { defer conn.Close() + sessionsMu.RLock() + openSession, ok := openSessions[sessionId] + sessionsMu.RUnlock() + + if !ok { + // conn.WriteMessage(1, "{Session not found or something}") + + return + } + + fmt.Println("Session", openSession) + + if partyIsA { + if openSession.aConn != nil { + openSession.aConn.Close() + } + openSession.aConn = conn + if !openSession.lastInteractedPartyIsA { + openSession.lastInteractedPartyIsA = true + openSession.lastInteractionTime = time.Now() + } + } else { + if openSession.bConn != nil { + openSession.bConn.Close() + } + openSession.bConn = conn + + if openSession.lastInteractedPartyIsA { + openSession.lastInteractedPartyIsA = false + openSession.lastInteractionTime = time.Now() + } + } + for { msgType, msg, err := conn.ReadMessage() if err != nil { + break + } + + if !ok { + // conn.WriteMessage(1, "{Session not found or something}") + return } - // Let's just echo for now as a placeholder + var otherConn *websocket.Conn + if partyIsA { + otherConn = openSession.bConn + } else { + otherConn = openSession.aConn + } - err = conn.WriteMessage(msgType, msg) + if otherConn == nil { + // conn.WriteMessage(msgType, "other party is not connected to websocket") + continue + } + + err = otherConn.WriteMessage(msgType, msg) if err != nil { - return + // do something here } } + } diff --git a/misc/utils.go b/misc/utils.go new file mode 100644 index 0000000..0f39c90 --- /dev/null +++ b/misc/utils.go @@ -0,0 +1,19 @@ +package misc + +import "crypto/rand" + +const letters = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func RandomString(n int) (string, error) { + bytes := make([]byte, n) + _, err := rand.Read(bytes) + if err != nil { + return "", err + } + + for i := range n { + bytes[i] = letters[int(bytes[i])%len(letters)] + } + + return string(bytes), nil +}