diff --git a/encryption.go b/encryption.go index b6fa9db..42aa9a8 100644 --- a/encryption.go +++ b/encryption.go @@ -42,6 +42,8 @@ import ( // the encrypted layer type EncryptLayerFinalizer func() (map[string]string, error) +const keyProviderSchemePrefix = "provider." + func init() { keyWrappers = make(map[string]keywrap.KeyWrapper) keyWrapperAnnotations = make(map[string]string) @@ -54,7 +56,7 @@ func init() { log.Error(err) } else if ic != nil { for provider, attrs := range ic.KeyProviderConfig { - RegisterKeyWrapper("provider."+provider, keyprovider.NewKeyWrapper(provider, attrs)) + RegisterKeyWrapper(keyProviderSchemePrefix+provider, keyprovider.NewKeyWrapper(provider, attrs)) } } } @@ -213,6 +215,7 @@ func DecryptLayer(dc *config.DecryptConfig, encLayerReader io.Reader, desc ocisp func decryptLayerKeyOptsData(dc *config.DecryptConfig, desc ocispec.Descriptor) ([]byte, error) { privKeyGiven := false + keyproviderTried := false errs := "" if len(keyWrapperAnnotations) == 0 { return nil, errors.New("missing Annotations needed for decryption") @@ -226,6 +229,11 @@ func decryptLayerKeyOptsData(dc *config.DecryptConfig, desc ocispec.Descriptor) continue } + isKeyprovider := strings.HasPrefix(scheme, keyProviderSchemePrefix) + if isKeyprovider { + keyproviderTried = true + } + if len(keywrapper.GetPrivateKeys(dc.Parameters)) > 0 { privKeyGiven = true } @@ -242,7 +250,7 @@ func decryptLayerKeyOptsData(dc *config.DecryptConfig, desc ocispec.Descriptor) return optsData, nil } } - if !privKeyGiven { + if !privKeyGiven && !keyproviderTried { return nil, fmt.Errorf("missing private key needed for decryption:\n%s", errs) } return nil, fmt.Errorf("no suitable key unwrapper found or none of the private keys could be used for decryption:\n%s", errs) diff --git a/encryption_test.go b/encryption_test.go index f76de5d..d26f644 100644 --- a/encryption_test.go +++ b/encryption_test.go @@ -141,3 +141,55 @@ func TestEncryptLayer(t *testing.T) { t.Fatalf("Expected %v, got %v", data, decLayer) } } + +func TestWasmMediaTypeEncryption(t *testing.T) { + data := []byte("This is WASM module data!") + desc := ocispec.Descriptor{ + Digest: digest.FromBytes(data), + Size: int64(len(data)), + MediaType: "application/vnd.wasm.content.layer.v1+wasm", + } + + dataReader := bytes.NewReader(data) + + encLayerReader, encLayerFinalizer, err := EncryptLayer(ec, dataReader, desc) + if err != nil { + t.Fatal(err) + } + + encLayer := make([]byte, 1024) + encsize, err := encLayerReader.Read(encLayer) + if err != io.EOF { + t.Fatal("Expected EOF") + } + encLayerReaderAt := bytes.NewReader(encLayer[:encsize]) + + annotations, err := encLayerFinalizer() + if err != nil { + t.Fatal(err) + } + + if len(annotations) == 0 { + t.Fatal("No keys created for annotations") + } + + newDesc := ocispec.Descriptor{ + Annotations: annotations, + MediaType: "application/vnd.wasm.content.layer.v1+wasm+encrypted", + } + + decLayerReader, _, err := DecryptLayer(dc, encLayerReaderAt, newDesc, false) + if err != nil { + t.Fatal(err) + } + + decLayer := make([]byte, 1024) + decsize, err := decLayerReader.Read(decLayer) + if err != nil && err != io.EOF { + t.Fatal(err) + } + + if !reflect.DeepEqual(decLayer[:decsize], data) { + t.Fatalf("Expected %v, got %v", data, decLayer) + } +} diff --git a/spec/spec.go b/spec/spec.go index c0c1718..e8b49e2 100644 --- a/spec/spec.go +++ b/spec/spec.go @@ -17,4 +17,8 @@ const ( // // Deprecated: Use [MediaTypeLayerNonDistributableZstdEnc]. MediaTypeLayerNonDistributableZsdtEnc = MediaTypeLayerNonDistributableZstdEnc + // MediaTypeWasmLayer is MIME type used for WASM layers. + MediaTypeWasmLayer = "application/vnd.wasm.content.layer.v1+wasm" + // MediaTypeWasmEnc is MIME type used for encrypted WASM layers. + MediaTypeWasmEnc = MediaTypeWasmLayer + "+encrypted" )