mirror of
https://github.com/a-h/templ.git
synced 2025-02-06 09:45:21 +00:00
208 lines
6.9 KiB
Go
208 lines
6.9 KiB
Go
package templ_test
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
|
||
"github.com/a-h/templ"
|
||
"github.com/google/go-cmp/cmp"
|
||
)
|
||
|
||
func TestHandler(t *testing.T) {
|
||
hello := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
|
||
if _, err := io.WriteString(w, "Hello"); err != nil {
|
||
t.Fatalf("failed to write string: %v", err)
|
||
}
|
||
return nil
|
||
})
|
||
errorComponent := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
|
||
if _, err := io.WriteString(w, "Hello"); err != nil {
|
||
t.Fatalf("failed to write string: %v", err)
|
||
}
|
||
return errors.New("handler error")
|
||
})
|
||
|
||
tests := []struct {
|
||
name string
|
||
input *templ.ComponentHandler
|
||
expectedStatus int
|
||
expectedMIMEType string
|
||
expectedBody string
|
||
}{
|
||
{
|
||
name: "handlers return OK by default",
|
||
input: templ.Handler(hello),
|
||
expectedStatus: http.StatusOK,
|
||
expectedMIMEType: "text/html; charset=utf-8",
|
||
expectedBody: "Hello",
|
||
},
|
||
{
|
||
name: "handlers return OK by default",
|
||
input: templ.Handler(templ.Raw(`♠ ‘ ♠ ‘`)),
|
||
expectedStatus: http.StatusOK,
|
||
expectedMIMEType: "text/html; charset=utf-8",
|
||
expectedBody: "♠ ‘ ♠ ‘",
|
||
},
|
||
{
|
||
name: "handlers can be configured to return an alternative status code",
|
||
input: templ.Handler(hello, templ.WithStatus(http.StatusNotFound)),
|
||
expectedStatus: http.StatusNotFound,
|
||
expectedMIMEType: "text/html; charset=utf-8",
|
||
expectedBody: "Hello",
|
||
},
|
||
{
|
||
name: "handlers can be configured to return an alternative status code and content type",
|
||
input: templ.Handler(hello, templ.WithStatus(http.StatusOK), templ.WithContentType("text/csv")),
|
||
expectedStatus: http.StatusOK,
|
||
expectedMIMEType: "text/csv",
|
||
expectedBody: "Hello",
|
||
},
|
||
{
|
||
name: "handlers that fail return a 500 error",
|
||
input: templ.Handler(errorComponent),
|
||
expectedStatus: http.StatusInternalServerError,
|
||
expectedMIMEType: "text/plain; charset=utf-8",
|
||
expectedBody: "templ: failed to render template\n",
|
||
},
|
||
{
|
||
name: "error handling can be customised",
|
||
input: templ.Handler(errorComponent, templ.WithErrorHandler(func(r *http.Request, err error) http.Handler {
|
||
// Because the error is received, it's possible to log the detail of the request.
|
||
// log.Printf("template render error for %v %v: %v", r.Method, r.URL.String(), err)
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
if _, err := io.WriteString(w, "custom body"); err != nil {
|
||
t.Fatalf("failed to write string: %v", err)
|
||
}
|
||
})
|
||
})),
|
||
expectedStatus: http.StatusBadRequest,
|
||
expectedMIMEType: "text/html; charset=utf-8",
|
||
expectedBody: "custom body",
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
tt := tt
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
w := httptest.NewRecorder()
|
||
r := httptest.NewRequest("GET", "/test", nil)
|
||
tt.input.ServeHTTP(w, r)
|
||
if got := w.Result().StatusCode; tt.expectedStatus != got {
|
||
t.Errorf("expected status %d, got %d", tt.expectedStatus, got)
|
||
}
|
||
if mimeType := w.Result().Header.Get("Content-Type"); tt.expectedMIMEType != mimeType {
|
||
t.Errorf("expected content-type %s, got %s", tt.expectedMIMEType, mimeType)
|
||
}
|
||
body, err := io.ReadAll(w.Result().Body)
|
||
if err != nil {
|
||
t.Errorf("failed to read body: %v", err)
|
||
}
|
||
if diff := cmp.Diff(tt.expectedBody, string(body)); diff != "" {
|
||
t.Error(diff)
|
||
}
|
||
})
|
||
}
|
||
|
||
t.Run("streaming mode allows responses to be flushed", func(t *testing.T) {
|
||
w := httptest.NewRecorder()
|
||
r := httptest.NewRequest("GET", "/test", nil)
|
||
|
||
component := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
|
||
// Write part 1.
|
||
if _, err := io.WriteString(w, "Part 1"); err != nil {
|
||
return err
|
||
}
|
||
// Flush.
|
||
if f, ok := w.(http.Flusher); ok {
|
||
f.Flush()
|
||
}
|
||
// Check partial response.
|
||
wr := w.(*httptest.ResponseRecorder)
|
||
actualBody := wr.Body.String()
|
||
if diff := cmp.Diff("Part 1", actualBody); diff != "" {
|
||
t.Error(diff)
|
||
}
|
||
// Write part 2.
|
||
if _, err := io.WriteString(w, "\nPart 2"); err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
})
|
||
|
||
templ.Handler(component, templ.WithStatus(http.StatusCreated), templ.WithStreaming()).ServeHTTP(w, r)
|
||
if got := w.Result().StatusCode; http.StatusCreated != got {
|
||
t.Errorf("expected status %d, got %d", http.StatusCreated, got)
|
||
}
|
||
if mimeType := w.Result().Header.Get("Content-Type"); "text/html; charset=utf-8" != mimeType {
|
||
t.Errorf("expected content-type %s, got %s", "text/html; charset=utf-8", mimeType)
|
||
}
|
||
body, err := io.ReadAll(w.Result().Body)
|
||
if err != nil {
|
||
t.Errorf("failed to read body: %v", err)
|
||
}
|
||
if diff := cmp.Diff("Part 1\nPart 2", string(body)); diff != "" {
|
||
t.Error(diff)
|
||
}
|
||
})
|
||
t.Run("streaming mode handles errors", func(t *testing.T) {
|
||
w := httptest.NewRecorder()
|
||
r := httptest.NewRequest("GET", "/test", nil)
|
||
|
||
expectedErr := errors.New("streaming error")
|
||
|
||
component := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
|
||
if _, err := io.WriteString(w, "Body"); err != nil {
|
||
return err
|
||
}
|
||
return expectedErr
|
||
})
|
||
|
||
var errorHandlerCalled bool
|
||
errorHandler := func(r *http.Request, err error) http.Handler {
|
||
if expectedErr != err {
|
||
t.Errorf("expected error %v, got %v", expectedErr, err)
|
||
}
|
||
errorHandlerCalled = true
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// This will be ignored, because the header has already been written.
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
// This will be written, but will be appended to the written body.
|
||
if _, err := io.WriteString(w, "Error message"); err != nil {
|
||
t.Errorf("failed to write error message: %v", err)
|
||
}
|
||
})
|
||
}
|
||
|
||
h := templ.Handler(component,
|
||
templ.WithStatus(http.StatusCreated),
|
||
templ.WithStreaming(),
|
||
templ.WithErrorHandler(errorHandler),
|
||
)
|
||
h.ServeHTTP(w, r)
|
||
|
||
if !errorHandlerCalled {
|
||
t.Error("expected error handler to be called")
|
||
}
|
||
// Expect the status code to be 201, not 400, because in streaming mode,
|
||
// we have to write the header before we can call the error handler.
|
||
if actualResponseCode := w.Result().StatusCode; http.StatusCreated != actualResponseCode {
|
||
t.Errorf("expected status %d, got %d", http.StatusCreated, actualResponseCode)
|
||
}
|
||
// Expect the body to be "BodyError message", not just "Error message" because
|
||
// in streaming mode, we've already written part of the body to the response, unlike in
|
||
// standard mode where the body is written to a buffer before the response is written,
|
||
// ensuring that partial responses are not sent.
|
||
actualBody, err := io.ReadAll(w.Result().Body)
|
||
if err != nil {
|
||
t.Errorf("failed to read body: %v", err)
|
||
}
|
||
if diff := cmp.Diff("BodyError message", string(actualBody)); diff != "" {
|
||
t.Error(diff)
|
||
}
|
||
})
|
||
}
|