| // Copyright 2018 The Bazel Authors. All rights reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package shard |
| |
| import ( |
| "archive/zip" |
| "bytes" |
| "errors" |
| "fmt" |
| "reflect" |
| "sort" |
| "strings" |
| "testing" |
| ) |
| |
| func TestSepSharder(t *testing.T) { |
| tcs := []struct { |
| name string |
| sep string |
| wantName string |
| }{ |
| { |
| name: "Hello", |
| sep: "/", |
| wantName: "Hello", |
| }, |
| { |
| name: "foo/bar/baz", |
| sep: "/", |
| wantName: "foo/bar", |
| }, |
| { |
| name: "com@google@Foo.dex", |
| sep: "@", |
| wantName: "com@google", |
| }, |
| } |
| |
| for _, tc := range tcs { |
| checkShard := func(name string, sc int) int { |
| if name != tc.wantName { |
| t.Errorf("makeSepSharder(%s).Shard(%s, 1): got name: %s wanted: %s", tc.sep, tc.name, name, tc.wantName) |
| } |
| return 0 |
| } |
| |
| s := MakeSepFunc(tc.sep, Func(checkShard)) |
| s(tc.name, 1) |
| } |
| |
| } |
| |
| func TestBadSharder(t *testing.T) { |
| srcZip, err := makeZip(map[string]string{"hello.dex": "world"}) |
| if err != nil { |
| t.Fatalf("Could not make initial zip: %v", err) |
| } |
| |
| for _, shardVal := range []int{-1, -244, 123} { |
| zr, err := zip.NewReader(bytes.NewReader(srcZip), int64(len(srcZip))) |
| if err != nil { |
| t.Fatalf("could not read initial zip: %v", err) |
| } |
| zws := []*zip.Writer{zip.NewWriter(&bytes.Buffer{})} |
| |
| s := Func(func(name string, sc int) int { |
| return shardVal |
| }) |
| err = ZipShard(zr, zws, s) |
| if err == nil || !strings.Contains(err.Error(), "invalid shard index") { |
| t.Errorf("Returning shard value: %d gave: %v wanted an error with invalid shard index", shardVal, err) |
| } |
| } |
| } |
| |
| func TestZipShard(t *testing.T) { |
| tcs := []struct { |
| name string |
| contents map[string]string |
| shardCount int |
| want map[int][]string |
| zipShardErr error |
| }{ |
| { |
| name: "Vanilla", |
| contents: map[string]string{ |
| "foo/hello.dex": "world", |
| "bar/something.dex": "stuff", |
| "blah/nothing.dex": "here", |
| "blah/everything.dex": "nowhere", |
| "hello/everything.dex": "nowhere", |
| }, |
| shardCount: 5, |
| want: map[int][]string{ |
| 0: {"hello/everything.dex"}, |
| 3: {"foo/hello.dex", "bar/something.dex"}, |
| 4: {"blah/nothing.dex", "blah/everything.dex"}, |
| }, |
| }, |
| { |
| name: "no output shards", |
| contents: map[string]string{"something.dex": "something"}, |
| shardCount: 0, |
| zipShardErr: errors.New("no output writers"), |
| }, |
| { |
| name: "empty input zip", |
| contents: map[string]string{}, |
| shardCount: 5, |
| want: map[int][]string{}, |
| }, |
| } |
| |
| for _, tc := range tcs { |
| srcZip, err := makeZip(tc.contents) |
| if err != nil { |
| t.Errorf("%s: could not create initial zip: %v", tc.name, err) |
| } |
| zr, err := zip.NewReader(bytes.NewReader(srcZip), int64(len(srcZip))) |
| |
| if err != nil { |
| t.Errorf("%s: could not read initial zip: %v", tc.name, err) |
| continue |
| } |
| bufs := make([]*bytes.Buffer, tc.shardCount) |
| zws := make([]*zip.Writer, tc.shardCount) |
| for i := range zws { |
| bufs[i] = new(bytes.Buffer) |
| zws[i] = zip.NewWriter(bufs[i]) |
| } |
| s := MakeSepFunc("/", Func(func(name string, sc int) int { |
| return len(name) % sc |
| })) |
| err = ZipShard(zr, zws, s) |
| if !reflect.DeepEqual(err, tc.zipShardErr) { |
| t.Errorf("%s: got zipshard error: %v wanted: %v", tc.name, err, tc.zipShardErr) |
| continue |
| } |
| for i, s := range bufs { |
| if err := zws[i].Close(); err != nil { |
| t.Errorf("%s: shard: %d cannot close zip writer: %v", tc.name, tc.shardCount, err) |
| continue |
| } |
| z, err := zip.NewReader(bytes.NewReader(s.Bytes()), int64(s.Len())) |
| if err != nil { |
| t.Errorf("%s: shard: %d cannot create zip reader: %v", tc.name, tc.shardCount, err) |
| continue |
| } |
| var fileNames []string |
| for _, f := range z.File { |
| fileNames = append(fileNames, f.Name) |
| } |
| sort.Strings(fileNames) |
| want, _ := tc.want[i] |
| sort.Strings(want) |
| if !reflect.DeepEqual(want, fileNames) { |
| t.Errorf("%s: shard: %d got: %s wanted: %s", tc.name, i, fileNames, want) |
| } |
| } |
| } |
| } |
| |
| func makeZip(contents map[string]string) ([]byte, error) { |
| var zin bytes.Buffer |
| zw := zip.NewWriter(&zin) |
| for name, body := range contents { |
| f, err := zw.Create(name) |
| if err != nil { |
| return nil, fmt.Errorf("%s: could not create: %v", name, err) |
| } |
| _, err = f.Write([]byte(body)) |
| if err != nil { |
| return nil, fmt.Errorf("%s: could not write: %s due to %v: ", name, body, err) |
| } |
| |
| } |
| if err := zw.Close(); err != nil { |
| return nil, err |
| } |
| return zin.Bytes(), nil |
| } |