Unit Testing AWS S3 Downloads in Go
An example of using the `github.com/aws/aws-sdk-go/awstesting/unit` package.
Problem
How do you unit test a Go function that wraps aws-sdk-go’s `s3manager#Downloader.Download` without issuing real HTTP requests to the AWS API (and without using an additional tool like LocalStack)?
Solution
Make the implementation’s `*s3manager.Downloader` configurable; in testing, use the `github.com/aws/aws-sdk-go/awstesting/unit` package to create a custom `*s3manager.Downloader` that uses a local `testdata` directory as a mock AWS S3 bucket.
Example
Consider a contrived `s3object` package. The package provides an `S3Object` type, which features a `Download` method that wraps `s3manager#Downloader.Download` and adds a bit of extra logic: it derives a local file name from the object’s URL and writes the object to that file.

By default, its `New` constructor configures a `*s3manager.Downloader` on the user’s behalf. However, the constructor also accepts a `WithDownloader` functional option for configuring the use of a non-default `*s3manager.Downloader`.
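Stripped of the AWS types, the functional-option pattern described above comes down to a constructor that assigns a default and then applies each option in order. Here is a minimal, dependency-free sketch; `object`, `downloader`, `withDownloader`, and `newObject` are hypothetical stand-ins for `S3Object`, `*s3manager.Downloader`, `WithDownloader`, and `New`:

```go
package main

import "fmt"

// downloader is a hypothetical stand-in for *s3manager.Downloader.
type downloader struct{ name string }

// object is a hypothetical stand-in for s3object.S3Object.
type object struct {
	downloader *downloader
}

// objectOption mirrors S3ObjectOption: a function that mutates a new *object.
type objectOption func(o *object)

// withDownloader mirrors WithDownloader: it replaces the default downloader.
func withDownloader(d *downloader) objectOption {
	return func(o *object) { o.downloader = d }
}

// newObject mirrors New: assign the default first, then apply each option.
func newObject(opts ...objectOption) *object {
	o := &object{downloader: &downloader{name: "default"}}
	for _, opt := range opts {
		opt(o)
	}
	return o
}

func main() {
	fmt.Println(newObject().downloader.name)                                          // default
	fmt.Println(newObject(withDownloader(&downloader{name: "test"})).downloader.name) // test
}
```

Because options run after the default is assigned, calling the constructor with no options keeps the default, while `withDownloader` overrides it. The same mechanism lets the test further down inject its own downloader.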
```go
package s3object

import (
	"net/url"
	"os"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
)

// S3Object is an S3 object.
type S3Object struct {
	downloader *s3manager.Downloader
}

// S3ObjectOption is a functional option used to configure a new *S3Object.
type S3ObjectOption = func(s3o *S3Object)

// WithDownloader is an S3ObjectOption for configuring the use of a specific
// *s3manager.Downloader, such as a test double.
func WithDownloader(d *s3manager.Downloader) S3ObjectOption {
	return func(s3o *S3Object) {
		s3o.downloader = d
	}
}

// New returns a new *S3Object. By default, it uses an *s3manager.Downloader
// backed by a us-east-1 session; opts may override it. (u is currently unused.)
func New(u *url.URL, opts ...S3ObjectOption) *S3Object {
	sess := session.Must(session.NewSession(&aws.Config{
		Region: aws.String("us-east-1"),
	}))

	s3o := &S3Object{
		downloader: s3manager.NewDownloader(sess),
	}

	// Apply options after the default is set so they can replace it.
	for _, opt := range opts {
		opt(s3o)
	}

	return s3o
}

// Download downloads the S3 object identified by u (s3://bucket/key).
// It saves the object to a local file named after the last element of the URL's path.
func (s3o *S3Object) Download(u *url.URL) error {
	fileName := u.Path[strings.LastIndex(u.Path, "/")+1:]

	file, err := os.Create(fileName)
	if err != nil {
		return err
	}
	// Defer Close only once os.Create has succeeded.
	defer file.Close()

	_, err = s3o.downloader.Download(file, &s3.GetObjectInput{
		Bucket: aws.String(u.Host),
		// s3://bucket/foo/bar.txt refers to the key "foo/bar.txt",
		// so trim the URL path's leading slash.
		Key: aws.String(strings.TrimPrefix(u.Path, "/")),
	})

	return err
}
```
Using the `github.com/aws/aws-sdk-go/awstesting/unit` package, a test `*s3manager.Downloader` can be created; the `testDownloader()` function below does exactly that. It builds an S3 client from `unit.Session`, clears the client’s `Send` handlers (the step that would normally issue the HTTP request), and registers a replacement handler that answers each request from a file in the local `testdata` directory. That directory therefore acts as a mock S3 bucket.

The `s3object.New` constructor’s support for the `WithDownloader` functional option is what lets the `s3object.S3Object` under test use the `*s3manager.Downloader` provided by `testDownloader()`.
```go
package s3object_test

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"testing"

	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/awstesting/unit"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/mdb/s3object"
)

// testDownloader returns an *s3manager.Downloader whose S3 client never sends
// HTTP requests; instead, it serves objects from the local testdata directory.
func testDownloader() *s3manager.Downloader {
	var locker sync.Mutex

	svc := s3.New(unit.Session)

	// Remove the default Send handler (the one that performs the real HTTP
	// request) and replace it with one that builds a response from testdata.
	svc.Handlers.Send.Clear()
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		locker.Lock()
		defer locker.Unlock()

		r.HTTPResponse = &http.Response{
			Header: http.Header{},
		}

		// The bucket is encoded in the host name, so the URL path is the object
		// key: /foo/bar.txt is read from testdata/foo/bar.txt.
		f, err := os.ReadFile(fmt.Sprintf("testdata%s", r.HTTPRequest.URL.Path))
		switch {
		case err == nil:
			// If there's no error reading the file, return a 200 HTTP response
			// with the file contents as the response body.
			r.HTTPResponse.StatusCode = http.StatusOK
			r.HTTPResponse.Body = io.NopCloser(bytes.NewReader(f))
		case os.IsNotExist(err):
			// If the file doesn't exist, return a 404 HTTP response.
			r.HTTPResponse.StatusCode = http.StatusNotFound
			r.HTTPResponse.Body = io.NopCloser(strings.NewReader(err.Error()))
		default:
			// Otherwise, return a 500 HTTP response with the error as the body.
			r.HTTPResponse.StatusCode = http.StatusInternalServerError
			r.HTTPResponse.Body = io.NopCloser(strings.NewReader(err.Error()))
		}

		// The downloader sizes the object from Content-Length; with PartSize = 1
		// below, a value of 1 means it stops after a single GetObject call.
		r.HTTPResponse.Header.Set("Content-Length", "1")
	})

	return s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.Concurrency = 1
		d.PartSize = 1
	})
}

func TestDownload(t *testing.T) {
	tests := []struct {
		desc        string
		path        string
		expectedErr bool
	}{{
		desc:        "does not exist",
		path:        "testdata/does-not-exist/bim.txt",
		expectedErr: true,
	}, {
		desc:        "exists",
		path:        "testdata/foo/bar.txt",
		expectedErr: false,
	}}

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			// In s3://testdata/foo/bar.txt, the host ("testdata") is the bucket
			// and the path ("/foo/bar.txt") identifies the key.
			u, err := url.Parse(fmt.Sprintf("s3://%s", test.path))
			if err != nil {
				t.Fatal(err)
			}

			// Download writes to the current directory; remove the file afterward.
			newFileName := u.Path[strings.LastIndex(u.Path, "/")+1:]
			t.Cleanup(func() { os.Remove(newFileName) })

			s3o := s3object.New(u, s3object.WithDownloader(testDownloader()))

			err = s3o.Download(u)
			if err != nil && !test.expectedErr {
				t.Fatal(err)
			}
			if err == nil && test.expectedErr {
				t.Fatal("expected error")
			}
			if test.expectedErr {
				return
			}

			originalFileContent, err := os.ReadFile(test.path)
			if err != nil {
				t.Fatalf("unable to read file: %v", err)
			}

			newFileContent, err := os.ReadFile(newFileName)
			if err != nil {
				t.Fatalf("unable to read file: %v", err)
			}

			if !bytes.Equal(originalFileContent, newFileContent) {
				t.Errorf("expected %s contents to equal %s", newFileName, test.path)
			}
		})
	}
}
```
The project directory looks like the following; note the `testdata` directory, which serves as the fake S3 bucket used in tests:
```
.
├── go.mod
├── go.sum
├── s3object.go
├── s3object_test.go
└── testdata
    └── foo
        └── bar.txt

2 directories, 5 files
```