1. 初始化模块
# The module path must be "sse": the Go sources import
# "sse/internal/db" and "sse/internal/models".
cd test
go mod init sse
2. 安装 xo 并生成 model
# Install the xo code generator into $GOBIN (default: $(go env GOPATH)/bin),
# which must be on PATH for the next command to find it.
go install github.com/xo/xo@latest

# Generate Go model code from the live Postgres schema, then list the result.
DSN="postgres://myuser:mypassword@localhost:5432/shimazu?sslmode=disable"
xo schema "$DSN" -o internal/models
ls internal/models
3. db.go 文件:internal/db/db.go
package db

import (
	"database/sql"

	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/driver/pgdriver"
)

// SetupDatabase returns a *bun.DB for the local "shimazu" Postgres database,
// using the pgdriver connector and the Postgres dialect. No Ping is issued
// here, so connectivity problems surface on the first query.
// NOTE(review): the credentials are hard-coded; fine for a demo only.
func SetupDatabase() *bun.DB {
	const dsn = "postgres://myuser:mypassword@localhost:5432/shimazu?sslmode=disable"
	connector := pgdriver.NewConnector(pgdriver.WithDSN(dsn))
	return bun.NewDB(sql.OpenDB(connector), pgdialect.New())
}
4. main.go 文件:

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/go-chi/cors"
	"github.com/uptrace/bun"

	. "sse/internal/db"
	. "sse/internal/models"
)

func streamComments(w http.ResponseWriter, r *http.Request, db *bun.DB) { id := chi.URLParam(r, “id”) userID, err := strconv.Atoi(id) if err != nil { log.Printf(“Invalid user ID: %s”, id) http.Error(w, “Invalid user ID”, http.StatusBadRequest) return }

w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

flusher, ok := w.(http.Flusher)
if !ok {
	log.Printf("Streaming unsupported")
	http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
	return
}

clientChan := make(chan Comment)
errorChan := make(chan error)
defer close(clientChan)
defer close(errorChan)

ctx, cancel := context.WithCancel(r.Context())
defer cancel()

go func() {
	listenForNewComments(ctx, db, clientChan, errorChan, userID)
}()

for {
	select {
	case comment, ok := <-clientChan:
		if !ok {
			log.Printf("Client channel closed")
			return
		}
		commentJSON, err := json.Marshal(comment)
		if err != nil {
			log.Printf("Error encoding comment: %v", err)
			fmt.Fprintf(w, "event: error\ndata: Error encoding comment\n\n")
		} else {
			log.Printf("New comment: %s", string(commentJSON))
			// 构造一个符合 SSE 格式的消息字符串。这个字符串以 "data: " 开头,后面跟着 JSON 数据,最后以两个换行符 "\n\n" 结束
			fmt.Fprintf(w, "event: new_comment\ndata: %s\n\n", string(commentJSON))
		}

		flusher.Flush()
	case err := <-errorChan:
		log.Printf("Error: %v", err)
		fmt.Fprintf(w, "event: error\ndata: %s\n\n", err.Error())
		flusher.Flush()
	case <-ctx.Done():
		log.Printf("Context done")
		return
	}
}

}

// listenForNewComments polls, every 5 seconds, for comments of userID that
// were created after the newest one seen so far. It sends them on clientChan
// oldest first and reports query failures on errorChan. It returns as soon as
// ctx is cancelled. Every send also watches ctx, so the goroutine can never
// stay blocked after the receiving handler has gone away.
func listenForNewComments(ctx context.Context, db *bun.DB, clientChan chan<- Comment, errorChan chan<- error, userID int) {
	// Only stream comments created after the client connected. Starting from
	// the zero time.Time (0001-01-01 00:00:00 +0000 UTC) instead would replay
	// the user's whole history.
	lastCreatedAt := time.Now()

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		var comments []Comment
		err := db.NewSelect().
			Model(&comments).
			Where("created_at > ? AND user_id = ?", lastCreatedAt, userID).
			// Chronological order: a client that prepends each comment then ends
			// up with the newest one on top.
			Order("created_at ASC", "id ASC").
			Scan(ctx)

		switch {
		case ctx.Err() != nil:
			// The query was aborted by cancellation; nobody is listening anymore.
			return
		case err != nil:
			log.Printf("error querying database: %v", err)
			select {
			case errorChan <- fmt.Errorf("error querying database: %w", err):
			case <-ctx.Done():
				return
			}
		default:
			for _, comment := range comments {
				select {
				case clientChan <- comment:
				case <-ctx.Done():
					return
				}
				if comment.CreatedAt.After(lastCreatedAt) {
					lastCreatedAt = comment.CreatedAt
				}
			}
		}

		// Wait for the next poll, but wake up immediately on cancellation
		// instead of sleeping through it.
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}

func getComment(w http.ResponseWriter, r *http.Request, db *bun.DB) { id := chi.URLParam(r, “id”) userID, err := strconv.Atoi(id) if err != nil { log.Printf(“Invalid user ID: %s”, id) http.Error(w, “Invalid user ID”, http.StatusBadRequest) return }

var comments []Comment
err = db.NewSelect().Model(&comments).Where("user_id = ?", userID).OrderExpr("created_at DESC").Limit(2).Scan(context.Background())
if err != nil {
	log.Printf("Error Comment not found")
	http.Error(w, "Comment not found", http.StatusNotFound)
	return
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(comments)

}

func setupRouter(db *bun.DB) *chi.Mux { r := chi.NewRouter() r.Use(middleware.Logger)

// 基本的 CORS 设置
r.Use(cors.Handler(cors.Options{
	AllowedOrigins:   []string{"*"}, // 允许所有来源,你可以指定具体的域名
	AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
	AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
	ExposedHeaders:   []string{"Link"},
	AllowCredentials: false,
	MaxAge:           300, // 最大缓存时间(秒)
}))

r.Get("/comments/users/{id}", func(w http.ResponseWriter, r *http.Request) {
	getComment(w, r, db)
})

r.Get("/comments/stream/users/{id}", func(w http.ResponseWriter, r *http.Request) {
	streamComments(w, r, db)
})

return r

}

func main() { db := SetupDatabase() defer db.Close()

r := setupRouter(db)
log.Println("Server is running on http://localhost:9080")
http.ListenAndServe(":9080", r)

}

5. html
```html
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>实时评论更新</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            line-height: 1.6;
            margin: 0;
            padding: 20px;
            background-color: #f4f4f4;
        }

        h1 {
            color: #333;
        }

        #comments {
            background-color: white;
            border: 1px solid #ddd;
            border-radius: 5px;
            padding: 20px;
            margin-top: 20px;
        }

        .comment {
            border-bottom: 1px solid #eee;
            padding: 10px 0;
        }

        .comment:last-child {
            border-bottom: none;
        }
    </style>
</head>

<body>
    <h1>实时评论更新</h1>
    <div id="comments"></div>

    <script>
        const commentsContainer = document.getElementById('comments');
        // IDs already on the page: the initial fetch and the stream can overlap.
        const renderedIds = new Set();
        const eventSource = new EventSource('http://localhost:9080/comments/stream/users/1');

        // The server pushes each new comment as a "new_comment" event with JSON data.
        eventSource.addEventListener('new_comment', (event) => {
            const comment = JSON.parse(event.data);
            console.log('New comment:', comment);
            addCommentToPage(comment);
        });

        // "error" fires both for server-sent "event: error" messages (MessageEvents
        // carrying data) and for connection failures (plain Events).
        eventSource.addEventListener('error', (error) => {
            if (error instanceof MessageEvent) {
                // Reported by the server, e.g. a failed query. The server keeps
                // polling, so keep the stream open.
                console.error('Server error:', error.data);
                return;
            }
            console.error('SSE error:', error);
            // Closing stops the browser's automatic reconnection.
            eventSource.close();
        });

        // Appends a <tagName> child holding `text` as plain text.
        function appendText(parent, tagName, text) {
            const element = document.createElement(tagName);
            element.textContent = text;
            parent.appendChild(element);
        }

        // Renders one comment at the top of the list. Fields are inserted as text,
        // not HTML, so comment content cannot inject markup or scripts.
        function addCommentToPage(comment) {
            if (renderedIds.has(comment.id)) {
                return;
            }
            renderedIds.add(comment.id);

            const commentElement = document.createElement('div');
            commentElement.className = 'comment';
            appendText(commentElement, 'h3', `comment_id: ${comment.id}`);
            appendText(commentElement, 'p', `user_id: ${comment.user_id}`);
            appendText(commentElement, 'p', `report_id: ${comment.report_id}`);
            appendText(commentElement, 'p', `comment_detail: ${comment.comment_detail}`);
            appendText(commentElement, 'small', `创建时间: ${new Date(comment.created_at).toLocaleString()}`);
            commentsContainer.insertBefore(commentElement, commentsContainer.firstChild);
        }

        // Load existing comments when the page opens. The API returns them newest
        // first and each one is inserted at the top, so render them oldest first
        // to leave the newest comment on top.
        fetch('http://localhost:9080/comments/users/1')
            .then(response => {
                if (!response.ok) {
                    throw new Error(`HTTP ${response.status}`);
                }
                return response.json();
            })
            .then(comments => {
                (comments || []).slice().reverse().forEach(addCommentToPage);
            })
            .catch(error => console.error('Error loading initial comments:', error));
    </script>
</body>

</html>
```
6. 运行服务端
# Add the missing module requirements (bun, chi, cors, ...) to go.mod / go.sum.
go mod tidy
# Build and start the server; it listens on :9080.
go run main.go
2024/07/18 16:45:41 Server is running on http://localhost:9080
7. API 测试
curl -L -X GET 'http://localhost:9080/comments/users/1' \
  -H 'User-Agent: Apifox/1.0.0 (https://apifox.com)' \
  -H 'Content-Type: application/json' \
  -H 'Accept: */*' \
  -H 'Host: localhost:9080' \
  -H 'Connection: keep-alive' \
  --data-raw ''
[{"id":9,"report_id":100,"user_id":1,"comment_detail":"test comment from user :no.9","created_at":"2024-07-18T16:47:23.914387+09:00","updated_at":"2024-07-18T16:47:23.91369+09:00"},{"id":8,"report_id":100,"user_id":1,"comment_detail":"test comment from user :no.8","created_at":"2024-07-18T16:45:58.197449+09:00","updated_at":"2024-07-18T16:45:58.196751+09:00"}]
8. 浏览器测试:向 comments 表插入(insert)一条 user_id 为 1 的数据(页面订阅的是用户 1 的评论流) ![[Pasted image 20240718171423.png]] 之后,JS 接收到 new_comment 事件并更新页面 HTML ![[Pasted image 20240718171521.png]]

附录

  1. comments表定义
-- One row per comment; report_id and user_id reference reports and users.
CREATE TABLE public."comments" (
	id serial4 NOT NULL,
	report_id int4 NOT NULL,
	user_id int4 NOT NULL,
	comment_detail text NOT NULL,
	created_at timestamptz DEFAULT now() NOT NULL,
	updated_at timestamptz DEFAULT now() NOT NULL,
	CONSTRAINT comments_pkey PRIMARY KEY (id),
	CONSTRAINT comments_report_id_fkey FOREIGN KEY (report_id) REFERENCES public.reports(id),
	CONSTRAINT comments_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id)
);

-- Every API query filters by user_id and orders or filters by created_at:
-- the latest-comments endpoint, and the SSE poller's "created_at > ?" query,
-- which runs every few seconds for each connected client. PostgreSQL does not
-- index foreign-key columns automatically, so add a composite index.
CREATE INDEX comments_user_id_created_at_idx ON public."comments" (user_id, created_at);
2. comment 的 model 定义
package models

// Code generated by xo. DO NOT EDIT.

import (
	"context"
	"time"
)

// Comment represents a row from 'public.comments'.
//
// NOTE(review): this file is generated by xo ("DO NOT EDIT"); any comment
// changes made here are lost the next time the models are regenerated.
type Comment struct {
	ID            int       `json:"id"`             // id
	ReportID      int       `json:"report_id"`      // weekly report ID
	UserID        int       `json:"user_id"`        // user ID (the person who wrote the comment)
	CommentDetail string    `json:"comment_detail"` // comment text
	CreatedAt     time.Time `json:"created_at"`     // created_at
	UpdatedAt     time.Time `json:"updated_at"`     // updated_at
	// xo fields: unexported, so encoding/json leaves them out of API responses.
	// Presumably xo's "row exists" / "row deleted" bookkeeping flags — confirm
	// against the generated methods.
	_exists, _deleted bool
}
3. 重构代码(main.go)
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/cors"
	"github.com/uptrace/bun"

	. "sse/internal/db"
	. "sse/internal/models"
)

// pollInterval is the delay between two database polls made on behalf of a
// connected SSE client.
const pollInterval = time.Second * 5

// commentLimit is how many of a user's most recent comments getComment returns.
const commentLimit = 2

// streamComments serves a Server-Sent Events stream of new comments for the
// user identified by the {id} URL parameter. Comments are sent as
// "new_comment" events and polling failures as "error" events. The stream
// ends when the client disconnects (request context cancelled) or the poller
// stops.
func streamComments(w http.ResponseWriter, r *http.Request, db *bun.DB) {
	userID, err := getUserID(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Verify streaming support before writing SSE headers, so a failure is
	// reported as an ordinary plain-text error response.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	setupSSEHeaders(w)
	// Flush the headers right away so the client's EventSource opens now
	// instead of when the first event is written.
	flusher.Flush()

	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	commentChan, errorChan := listenForNewComments(ctx, db, userID)

	for {
		select {
		case comment, ok := <-commentChan:
			if !ok {
				return
			}
			sendSSEMessage(w, flusher, "new_comment", comment)
		case err, ok := <-errorChan:
			// The poller closes both channels when it exits. A receive from the
			// closed errorChan yields a nil error, and calling Error on it panics.
			if !ok {
				return
			}
			sendSSEMessage(w, flusher, "error", err.Error())
		case <-ctx.Done():
			return
		}
	}
}

// listenForNewComments starts a goroutine that polls, every pollInterval, for
// comments of userID created after the newest one seen so far (initially: the
// moment of the call). It returns the channels the goroutine reports on: new
// comments, and query errors. The goroutine exits when ctx is cancelled and
// closes both channels on exit. Every send also watches ctx, so a caller that
// cancels ctx and stops receiving never leaves the goroutine blocked.
func listenForNewComments(ctx context.Context, db *bun.DB, userID int) (<-chan Comment, <-chan error) {
	commentChan := make(chan Comment)
	errorChan := make(chan error)

	go func() {
		defer close(commentChan)
		defer close(errorChan)

		// Only stream comments created after the client connected.
		lastCreatedAt := time.Now()

		ticker := time.NewTicker(pollInterval)
		defer ticker.Stop()

		for {
			comments, err := fetchNewComments(ctx, db, userID, lastCreatedAt)
			switch {
			case ctx.Err() != nil:
				// The query was aborted by cancellation; nobody is listening.
				return
			case err != nil:
				select {
				case errorChan <- err:
				case <-ctx.Done():
					return
				}
			default:
				for _, comment := range comments {
					select {
					case commentChan <- comment:
					case <-ctx.Done():
						return
					}
					if comment.CreatedAt.After(lastCreatedAt) {
						lastCreatedAt = comment.CreatedAt
					}
				}
			}

			// Wait for the next poll, but wake up immediately on cancellation
			// instead of sleeping through it.
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
		}
	}()

	return commentChan, errorChan
}

// getComment responds with the user's commentLimit most recent comments as a
// JSON array. The user is identified by the {id} URL parameter.
func getComment(w http.ResponseWriter, r *http.Request, db *bun.DB) {
	userID, err := getUserID(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	comments, err := fetchComments(r.Context(), db, userID, commentLimit)
	if err != nil {
		// A failed query is a server-side problem, not a missing resource.
		log.Printf("Error fetching comments for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch comments", http.StatusInternalServerError)
		return
	}
	// A user without comments gets [] rather than null (a nil slice encodes
	// as null), so clients can always iterate the response.
	if comments == nil {
		comments = []Comment{}
	}

	respondWithJSON(w, comments)
}

// setupRouter assembles the HTTP router. It applies request logging and
// permissive CORS, and mounts the latest-comments JSON endpoint and the SSE
// comment stream.
func setupRouter(db *bun.DB) *chi.Mux {
	corsOptions := cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
		ExposedHeaders:   []string{"Link"},
		AllowCredentials: false,
		MaxAge:           300,
	}

	mux := chi.NewRouter()
	mux.Use(middleware.Logger)
	mux.Use(cors.Handler(corsOptions))

	mux.Get("/comments/users/{id}", func(w http.ResponseWriter, req *http.Request) {
		getComment(w, req, db)
	})
	mux.Get("/comments/stream/users/{id}", func(w http.ResponseWriter, req *http.Request) {
		streamComments(w, req, db)
	})

	return mux
}

// main opens the database, builds the router and serves HTTP on :9080 until
// the listener fails.
func main() {
	db := SetupDatabase()

	r := setupRouter(db)
	log.Println("Server is running on http://localhost:9080")

	// ListenAndServe always returns a non-nil error (for example, when the
	// port is already in use), so it must be reported. log.Fatalf does not run
	// deferred calls, which is why the database is closed explicitly first.
	err := http.ListenAndServe(":9080", r)
	if cerr := db.Close(); cerr != nil {
		log.Printf("Error closing database: %v", cerr)
	}
	log.Fatalf("Server stopped: %v", err)
}

// Helper functions

// getUserID parses the {id} URL parameter as a decimal user ID. Callers send
// the returned error text to the client, so it names the rejected value but
// omits strconv's internal wording.
func getUserID(r *http.Request) (int, error) {
	id := chi.URLParam(r, "id")
	userID, err := strconv.Atoi(id)
	if err != nil {
		return 0, fmt.Errorf("invalid user ID %q", id)
	}
	return userID, nil
}

// setupSSEHeaders marks the response as an uncached Server-Sent Events stream
// on a kept-alive connection, readable from any origin.
func setupSSEHeaders(w http.ResponseWriter) {
	header := w.Header()
	for _, kv := range [][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"Access-Control-Allow-Origin", "*"},
	} {
		header.Set(kv[0], kv[1])
	}
}

// sendSSEMessage writes one SSE event named eventType, whose data line is the
// JSON encoding of data, and then flushes it to the client. If data cannot be
// JSON-encoded, an "error" event with a fixed message is written instead.
// Either way, the output is flushed.
func sendSSEMessage(w http.ResponseWriter, flusher http.Flusher, eventType string, data interface{}) {
	defer flusher.Flush()

	payload, err := json.Marshal(data)
	if err != nil {
		log.Printf("Error encoding data: %v", err)
		fmt.Fprintf(w, "event: error\ndata: Error encoding data\n\n")
		return
	}
	fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, payload)
}

// fetchNewComments returns userID's comments created strictly after
// lastCreatedAt, in chronological order (ties broken by id). Streaming them
// in this order means a client that prepends each comment shows the newest
// one on top. On error the result is nil.
func fetchNewComments(ctx context.Context, db *bun.DB, userID int, lastCreatedAt time.Time) ([]Comment, error) {
	var comments []Comment
	err := db.NewSelect().
		Model(&comments).
		Where("created_at > ? AND user_id = ?", lastCreatedAt, userID).
		Order("created_at ASC", "id ASC").
		Scan(ctx)
	if err != nil {
		return nil, err
	}
	return comments, nil
}

// fetchComments loads at most limit comments written by userID, newest first.
func fetchComments(ctx context.Context, db *bun.DB, userID, limit int) ([]Comment, error) {
	var result []Comment
	query := db.NewSelect().Model(&result)
	query = query.Where("user_id = ?", userID)
	query = query.OrderExpr("created_at DESC").Limit(limit)
	err := query.Scan(ctx)
	return result, err
}

// respondWithJSON writes data as an application/json response body, followed
// by a newline. The value is marshalled before anything is written, so a
// value that cannot be encoded yields a logged 500 response instead of a
// silent 200 with an empty body.
func respondWithJSON(w http.ResponseWriter, data interface{}) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("Error encoding JSON response: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Keep the trailing newline that json.Encoder.Encode used to emit.
	body = append(body, '\n')
	if _, err := w.Write(body); err != nil {
		log.Printf("Error writing JSON response: %v", err)
	}
}
重构版本做了以下改进:
- 将一些常量提取出来,如 pollInterval 和 commentLimit。
- 将一些重复的逻辑抽取成独立的函数,如 getUserID、setupSSEHeaders、sendSSEMessage 等。
- 将数据库查询操作封装成独立的函数 fetchNewComments 和 fetchComments。
- 简化了 streamComments 和 getComment 函数的结构,使其更加清晰。
- 将 listenForNewComments 函数修改为返回通道,而不是接受通道作为参数;通道改由发送方(轮询 goroutine)在退出时关闭。
- 使用 context.Context 来管理 goroutine 的生命周期。
- 统一了错误处理方式(注意:重构版本删去了原版中的大部分 log.Printf 日志,日志记录实际上减少了)。