diff --git a/binding.gyp b/binding.gyp index e5116af..563dce5 100644 --- a/binding.gyp +++ b/binding.gyp @@ -41,7 +41,8 @@ "src/ConsumerConfig.cc", "src/Reader.cc", "src/ReaderConfig.cc", - "src/ThreadSafeDeferred.cc" + "src/ThreadSafeDeferred.cc", + "src/CryptoKeyReader.cc" ], 'conditions': [ ['OS=="mac"', { diff --git a/index.d.ts b/index.d.ts index 46ac69a..b85b623 100644 --- a/index.d.ts +++ b/index.d.ts @@ -65,6 +65,8 @@ export interface ProducerConfig { properties?: { [key: string]: string }; publicKeyPath?: string; encryptionKey?: string; + encryptionKeys?: string[]; + cryptoKeyReader?: CryptoKeyReader; cryptoFailureAction?: ProducerCryptoFailureAction; chunkingEnabled?: boolean; schema?: SchemaInfo; @@ -99,6 +101,7 @@ export interface ConsumerConfig { listener?: (message: Message, consumer: Consumer) => void; readCompacted?: boolean; privateKeyPath?: string; + cryptoKeyReader?: CryptoKeyReader; cryptoFailureAction?: ConsumerCryptoFailureAction; maxPendingChunkedMessage?: number; autoAckOldestChunkedMessageOnQueueFull?: number; @@ -171,6 +174,7 @@ export class Message { getPartitionKey(): string; getOrderingKey(): string; getProducerName(): string; + getEncryptionContext(): EncryptionContext | null; } export class MessageId { @@ -198,6 +202,22 @@ export interface TopicMetadata { */ export type MessageRouter = (message: Message, topicMetadata: TopicMetadata) => number; +export interface EncryptionKey { + key: string; + value: Buffer; + metadata: { [key: string]: string }; +} + +export interface EncryptionContext { + keys: EncryptionKey[]; + param: Buffer; + algorithm: string; + compressionType: CompressionType; + uncompressedMessageSize: number; + batchSize: number; + isDecryptionFailed: boolean; +} + export interface SchemaInfo { schemaType: SchemaType; name?: string; @@ -285,6 +305,16 @@ export class AuthenticationBasic { }); } +export interface EncryptionKeyInfo { + key: Buffer; + metadata: { [key: string]: string }; +} + +export class CryptoKeyReader { + getPublicKey(keyName: string, metadata: { [key: string]: string }): EncryptionKeyInfo; + getPrivateKey(keyName: string, metadata: { [key: string]: string }): EncryptionKeyInfo; +} + export enum LogLevel { DEBUG = 0, INFO = 1, @@ -303,6 +333,7 @@ export type HashingScheme = 'JavaStringHash'; export type CompressionType = + 'None' | 'Zlib' | 'LZ4' | 'ZSTD' | diff --git a/index.js b/index.js index ddbb997..d909251 100644 --- a/index.js +++ b/index.js @@ -37,6 +37,7 @@ const Pulsar = { Client, Message: PulsarBinding.Message, MessageId: PulsarBinding.MessageId, + CryptoKeyReader: PulsarBinding.CryptoKeyReader, AuthenticationTls, AuthenticationAthenz, AuthenticationToken, diff --git a/src/ConsumerConfig.cc b/src/ConsumerConfig.cc index 7b2b61c..e7419c6 100644 --- a/src/ConsumerConfig.cc +++ b/src/ConsumerConfig.cc @@ -20,6 +20,7 @@ #include "ConsumerConfig.h" #include "Consumer.h" #include "SchemaInfo.h" +#include "CryptoKeyReader.h" #include "Message.h" #include "pulsar/ConsumerConfiguration.h" #include @@ -60,6 +61,7 @@ static const std::string CFG_KEY_SHARED_POLICY = "keySharedPolicy"; static const std::string CFG_KEY_SHARED_POLICY_MODE = "keyShareMode"; static const std::string CFG_KEY_SHARED_POLICY_ALLOW_OUT_OF_ORDER = "allowOutOfOrderDelivery"; static const std::string CFG_KEY_SHARED_POLICY_STICKY_RANGES = "stickyRanges"; +static const std::string CFG_CRYPTO_KEY_READER = "cryptoKeyReader"; static const std::map SUBSCRIPTION_TYPE = { {"Exclusive", pulsar_ConsumerExclusive}, @@ -249,13 +251,21 @@ void ConsumerConfig::InitConfig(std::shared_ptr deferred, std::string privateKeyPath = consumerConfig.Get(CFG_PRIVATE_KEY_PATH).ToString().Utf8Value(); pulsar_consumer_configuration_set_default_crypto_key_reader( this->cConsumerConfig.get(), publicKeyPath.c_str(), privateKeyPath.c_str()); - if (consumerConfig.Has(CFG_CRYPTO_FAILURE_ACTION) && - consumerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).IsString()) { - std::string cryptoFailureAction = consumerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).ToString().Utf8Value(); - if (CONSUMER_CRYPTO_FAILURE_ACTION.count(cryptoFailureAction)) { - pulsar_consumer_configuration_set_crypto_failure_action( - this->cConsumerConfig.get(), CONSUMER_CRYPTO_FAILURE_ACTION.at(cryptoFailureAction)); - } + } + + if (consumerConfig.Has(CFG_CRYPTO_KEY_READER) && consumerConfig.Get(CFG_CRYPTO_KEY_READER).IsObject()) { + Napi::Object cryptoKeyReaderObj = consumerConfig.Get(CFG_CRYPTO_KEY_READER).As(); + CryptoKeyReader *cryptoKeyReader = Napi::ObjectWrap::Unwrap(cryptoKeyReaderObj); + this->cConsumerConfig.get()->consumerConfiguration.setCryptoKeyReader( + cryptoKeyReader->GetCCryptoKeyReader()); + } + + if (consumerConfig.Has(CFG_CRYPTO_FAILURE_ACTION) && + consumerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).IsString()) { + std::string cryptoFailureAction = consumerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).ToString().Utf8Value(); + if (CONSUMER_CRYPTO_FAILURE_ACTION.count(cryptoFailureAction)) { + pulsar_consumer_configuration_set_crypto_failure_action( + this->cConsumerConfig.get(), CONSUMER_CRYPTO_FAILURE_ACTION.at(cryptoFailureAction)); } } diff --git a/src/CryptoKeyReader.cc b/src/CryptoKeyReader.cc new file mode 100644 index 0000000..4cf09ad --- /dev/null +++ b/src/CryptoKeyReader.cc @@ -0,0 +1,151 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "CryptoKeyReader.h" +#include +#include +#include +#include + +class CryptoKeyReaderWrapper : public pulsar::CryptoKeyReader { + public: + CryptoKeyReaderWrapper(const Napi::Object& jsObject) : mainThreadId_(std::this_thread::get_id()) { + jsObject_.Reset(jsObject, 1); + tsfn_ = Napi::ThreadSafeFunction::New( + jsObject.Env(), Napi::Function::New(jsObject.Env(), [](const Napi::CallbackInfo& info) {}), jsObject, + "CryptoKeyReader", 0, 1); + } + + ~CryptoKeyReaderWrapper() { tsfn_.Release(); } + + pulsar::Result getPublicKey(const std::string& keyName, std::map& metadata, + pulsar::EncryptionKeyInfo& encKeyInfo) const override { + return executeCallback("getPublicKey", keyName, metadata, encKeyInfo); + } + + pulsar::Result getPrivateKey(const std::string& keyName, std::map& metadata, + pulsar::EncryptionKeyInfo& encKeyInfo) const override { + return executeCallback("getPrivateKey", keyName, metadata, encKeyInfo); + } + + private: + Napi::ObjectReference jsObject_; + Napi::ThreadSafeFunction tsfn_; + std::thread::id mainThreadId_; + + static void parseEncryptionKeyInfo(const Napi::Object& obj, pulsar::EncryptionKeyInfo& info) { + if (obj.Has("key") && obj.Get("key").IsBuffer()) { + Napi::Buffer keyBuf = obj.Get("key").As>(); + info.setKey(std::string(keyBuf.Data(), keyBuf.Length())); + } + if (obj.Has("metadata") && obj.Get("metadata").IsObject()) { + std::map metadata; + Napi::Object metaObj = obj.Get("metadata").As(); + Napi::Array keys = metaObj.GetPropertyNames(); + for (uint32_t i = 0; i < keys.Length(); i++) { + std::string k = keys.Get(i).ToString().Utf8Value(); + std::string v = metaObj.Get(k).ToString().Utf8Value(); + metadata[k] = v; + } + info.setMetadata(metadata); + } + } + + pulsar::Result callJsMethod(Napi::Env env, const std::string& method, const std::string& keyName, + const std::map& metadata, + pulsar::EncryptionKeyInfo& encKeyInfo) const { + Napi::HandleScope scope(env); + + if (jsObject_.IsEmpty()) { + return pulsar::Result::ResultCryptoError; + } + Napi::Object obj = jsObject_.Value(); + + if (!obj.Has(method)) { + return pulsar::Result::ResultCryptoError; + } + Napi::Value funcVal = obj.Get(method); + if (!funcVal.IsFunction()) { + return pulsar::Result::ResultCryptoError; + } + Napi::Function func = funcVal.As(); + + Napi::Object metadataObj = Napi::Object::New(env); + for (const auto& kv : metadata) { + metadataObj.Set(kv.first, kv.second); + } + + try { + Napi::Value result = func.Call(obj, {Napi::String::New(env, keyName), metadataObj}); + if (result.IsObject()) { + parseEncryptionKeyInfo(result.As(), encKeyInfo); + return pulsar::Result::ResultOk; + } + } catch (const Napi::Error& e) { + return pulsar::Result::ResultCryptoError; + } catch (...) { + return pulsar::Result::ResultCryptoError; + } + return pulsar::Result::ResultCryptoError; + } + + pulsar::Result executeCallback(const std::string& method, const std::string& keyName, + std::map& metadata, + pulsar::EncryptionKeyInfo& encKeyInfo) const { + if (std::this_thread::get_id() == mainThreadId_) { + return callJsMethod(jsObject_.Env(), method, keyName, metadata, encKeyInfo); + } else { + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + napi_status status = tsfn_.BlockingCall([this, promise, &method, &keyName, &metadata, &encKeyInfo]( + Napi::Env env, Napi::Function jsCallback) { + promise->set_value(callJsMethod(env, method, keyName, metadata, encKeyInfo)); + }); + + if (status != napi_ok) { + return pulsar::Result::ResultCryptoError; + } + + future.wait(); + return future.get(); + } + } +}; + +Napi::FunctionReference CryptoKeyReader::constructor; + +void CryptoKeyReader::Init(Napi::Env env, Napi::Object exports) { + Napi::HandleScope scope(env); + + Napi::Function func = DefineClass(env, "CryptoKeyReader", {}); + + constructor = Napi::Persistent(func); + constructor.SuppressDestruct(); + + exports.Set("CryptoKeyReader", func); +} + +CryptoKeyReader::CryptoKeyReader(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) {} + +CryptoKeyReader::~CryptoKeyReader() {} + +std::shared_ptr CryptoKeyReader::GetCCryptoKeyReader() { + return std::make_shared(Value()); +} diff --git a/src/CryptoKeyReader.h b/src/CryptoKeyReader.h new file mode 100644 index 0000000..19d9155 --- /dev/null +++ b/src/CryptoKeyReader.h @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef CRYPTO_KEY_READER_H +#define CRYPTO_KEY_READER_H + +#include +#include + +class CryptoKeyReader : public Napi::ObjectWrap { + public: + static void Init(Napi::Env env, Napi::Object exports); + static Napi::Object NewInstance(const Napi::CallbackInfo &info); + CryptoKeyReader(const Napi::CallbackInfo &info); + ~CryptoKeyReader(); + std::shared_ptr GetCCryptoKeyReader(); + + private: + static Napi::FunctionReference constructor; +}; + +#endif diff --git a/src/Message.cc b/src/Message.cc index 8b1d081..58909bd 100644 --- a/src/Message.cc +++ b/src/Message.cc @@ -20,6 +20,15 @@ #include "Message.h" #include "MessageId.h" #include +#include +#include +#include +#include + +struct _pulsar_message { + pulsar::MessageBuilder builder; + pulsar::Message message; +}; static const std::string CFG_DATA = "data"; static const std::string CFG_PROPS = "properties"; @@ -32,6 +41,12 @@ static const std::string CFG_DELIVER_AT = "deliverAt"; static const std::string CFG_DISABLE_REPLICATION = "disableReplication"; static const std::string CFG_ORDERING_KEY = "orderingKey"; +static const std::map COMPRESSION_TYPE_MAP = { + {pulsar::CompressionNone, "None"}, {pulsar::CompressionLZ4, "LZ4"}, + {pulsar::CompressionZLib, "Zlib"}, {pulsar::CompressionZSTD, "ZSTD"}, + {pulsar::CompressionSNAPPY, "SNAPPY"}, +}; + Napi::FunctionReference Message::constructor; Napi::Object Message::Init(Napi::Env env, Napi::Object exports) { @@ -47,7 +62,8 @@ Napi::Object Message::Init(Napi::Env env, Napi::Object exports) { InstanceMethod("getRedeliveryCount", &Message::GetRedeliveryCount), InstanceMethod("getPartitionKey", &Message::GetPartitionKey), InstanceMethod("getOrderingKey", &Message::GetOrderingKey), - InstanceMethod("getProducerName", &Message::GetProducerName)}); + InstanceMethod("getProducerName", &Message::GetProducerName), + InstanceMethod("getEncryptionContext", &Message::GetEncryptionContext)}); constructor = Napi::Persistent(func); constructor.SuppressDestruct(); @@ -156,6 +172,54 @@ Napi::Value Message::GetProducerName(const Napi::CallbackInfo &info) { return Napi::String::New(env, pulsar_message_get_producer_name(this->cMessage.get())); } +Napi::Value Message::GetEncryptionContext(const Napi::CallbackInfo &info) { + Napi::Env env = info.Env(); + if (!ValidateCMessage(env)) { + return env.Null(); + } + + auto encCtxOpt = this->cMessage.get()->message.getEncryptionContext(); + if (!encCtxOpt) { + return env.Null(); + } + + // getEncryptionContext returns std::optional + const pulsar::EncryptionContext *encCtxPtr = *encCtxOpt; + if (!encCtxPtr) { + return env.Null(); + } + const pulsar::EncryptionContext &encCtx = *encCtxPtr; + + Napi::Object obj = Napi::Object::New(env); + Napi::Array keys = Napi::Array::New(env); + int i = 0; + for (const auto &keyInfo : encCtx.keys()) { + Napi::Object keyObj = Napi::Object::New(env); + keyObj.Set("key", Napi::String::New(env, keyInfo.key)); + keyObj.Set("value", Napi::Buffer::Copy(env, keyInfo.value.c_str(), keyInfo.value.length())); + + Napi::Object metadataObj = Napi::Object::New(env); + for (const auto &meta : keyInfo.metadata) { + metadataObj.Set(meta.first, Napi::String::New(env, meta.second)); + } + keyObj.Set("metadata", metadataObj); + + keys.Set(i++, keyObj); + } + obj.Set("keys", keys); + + obj.Set("param", Napi::Buffer::Copy(env, encCtx.param().c_str(), encCtx.param().length())); + obj.Set("algorithm", Napi::String::New(env, encCtx.algorithm())); + const auto it = COMPRESSION_TYPE_MAP.find(encCtx.compressionType()); + std::string compressionTypeStr = (it != COMPRESSION_TYPE_MAP.end()) ? it->second : "None"; + obj.Set("compressionType", Napi::String::New(env, compressionTypeStr)); + obj.Set("uncompressedMessageSize", Napi::Number::New(env, encCtx.uncompressedMessageSize())); + obj.Set("batchSize", Napi::Number::New(env, encCtx.batchSize())); + obj.Set("isDecryptionFailed", Napi::Boolean::New(env, encCtx.isDecryptionFailed())); + + return obj; +} + bool Message::ValidateCMessage(Napi::Env env) { if (this->cMessage.get()) { return true; diff --git a/src/Message.h b/src/Message.h index 417de92..4d8c4aa 100644 --- a/src/Message.h +++ b/src/Message.h @@ -47,6 +47,7 @@ class Message : public Napi::ObjectWrap { Napi::Value GetOrderingKey(const Napi::CallbackInfo &info); Napi::Value GetProducerName(const Napi::CallbackInfo &info); Napi::Value GetRedeliveryCount(const Napi::CallbackInfo &info); + Napi::Value GetEncryptionContext(const Napi::CallbackInfo &info); bool ValidateCMessage(Napi::Env env); static char **NewStringArray(int size) { return (char **)calloc(sizeof(char *), size); } diff --git a/src/ProducerConfig.cc b/src/ProducerConfig.cc index 83afb9c..eca62a7 100644 --- a/src/ProducerConfig.cc +++ b/src/ProducerConfig.cc @@ -18,6 +18,7 @@ */ #include "SchemaInfo.h" #include "ProducerConfig.h" +#include "CryptoKeyReader.h" #include "Message.h" #include #include @@ -45,6 +46,8 @@ static const std::string CFG_SCHEMA = "schema"; static const std::string CFG_PROPS = "properties"; static const std::string CFG_PUBLIC_KEY_PATH = "publicKeyPath"; static const std::string CFG_ENCRYPTION_KEY = "encryptionKey"; +static const std::string CFG_ENCRYPTION_KEYS = "encryptionKeys"; +static const std::string CFG_CRYPTO_KEY_READER = "cryptoKeyReader"; static const std::string CFG_CRYPTO_FAILURE_ACTION = "cryptoFailureAction"; static const std::string CFG_CHUNK_ENABLED = "chunkingEnabled"; static const std::string CFG_ACCESS_MODE = "accessMode"; @@ -67,10 +70,8 @@ static const std::map HASHING_SCHEME = { }; static std::map COMPRESSION_TYPE = { - {"Zlib", pulsar_CompressionZLib}, - {"LZ4", pulsar_CompressionLZ4}, - {"ZSTD", pulsar_CompressionZSTD}, - {"SNAPPY", pulsar_CompressionSNAPPY}, + {"None", pulsar_CompressionNone}, {"Zlib", pulsar_CompressionZLib}, {"LZ4", pulsar_CompressionLZ4}, + {"ZSTD", pulsar_CompressionZSTD}, {"SNAPPY", pulsar_CompressionSNAPPY}, }; static std::map PRODUCER_CRYPTO_FAILURE_ACTION = { @@ -239,15 +240,32 @@ ProducerConfig::ProducerConfig(const Napi::Object& producerConfig) : topic("") { std::string encryptionKey = producerConfig.Get(CFG_ENCRYPTION_KEY).ToString().Utf8Value(); pulsar_producer_configuration_set_encryption_key(this->cProducerConfig.get(), encryptionKey.c_str()); } - if (producerConfig.Has(CFG_CRYPTO_FAILURE_ACTION) && - producerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).IsString()) { - std::string cryptoFailureAction = producerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).ToString().Utf8Value(); - if (PRODUCER_CRYPTO_FAILURE_ACTION.count(cryptoFailureAction)) - pulsar_producer_configuration_set_crypto_failure_action( - this->cProducerConfig.get(), PRODUCER_CRYPTO_FAILURE_ACTION.at(cryptoFailureAction)); + } + + if (producerConfig.Has(CFG_ENCRYPTION_KEYS) && producerConfig.Get(CFG_ENCRYPTION_KEYS).IsArray()) { + Napi::Array keys = producerConfig.Get(CFG_ENCRYPTION_KEYS).As(); + for (uint32_t i = 0; i < keys.Length(); i++) { + if (keys.Get(i).IsString()) { + std::string key = keys.Get(i).ToString().Utf8Value(); + this->cProducerConfig.get()->conf.addEncryptionKey(key); + } } } + if (producerConfig.Has(CFG_CRYPTO_KEY_READER) && producerConfig.Get(CFG_CRYPTO_KEY_READER).IsObject()) { + Napi::Object cryptoKeyReaderObj = producerConfig.Get(CFG_CRYPTO_KEY_READER).As(); + CryptoKeyReader* cryptoKeyReader = Napi::ObjectWrap::Unwrap(cryptoKeyReaderObj); + this->cProducerConfig.get()->conf.setCryptoKeyReader(cryptoKeyReader->GetCCryptoKeyReader()); + } + + if (producerConfig.Has(CFG_CRYPTO_FAILURE_ACTION) && + producerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).IsString()) { + std::string cryptoFailureAction = producerConfig.Get(CFG_CRYPTO_FAILURE_ACTION).ToString().Utf8Value(); + if (PRODUCER_CRYPTO_FAILURE_ACTION.count(cryptoFailureAction)) + pulsar_producer_configuration_set_crypto_failure_action( + this->cProducerConfig.get(), PRODUCER_CRYPTO_FAILURE_ACTION.at(cryptoFailureAction)); + } + if (producerConfig.Has(CFG_CHUNK_ENABLED) && producerConfig.Get(CFG_CHUNK_ENABLED).IsBoolean()) { bool chunkingEnabled = producerConfig.Get(CFG_CHUNK_ENABLED).ToBoolean().Value(); pulsar_producer_configuration_set_chunking_enabled(this->cProducerConfig.get(), chunkingEnabled); diff --git a/src/addon.cc b/src/addon.cc index fa26ae0..3025bce 100644 --- a/src/addon.cc +++ b/src/addon.cc @@ -24,6 +24,7 @@ #include "Consumer.h" #include "Client.h" #include "Reader.h" +#include "CryptoKeyReader.h" #include Napi::Object InitAll(Napi::Env env, Napi::Object exports) { @@ -33,6 +34,7 @@ Napi::Object InitAll(Napi::Env env, Napi::Object exports) { Producer::Init(env, exports); Consumer::Init(env, exports); Reader::Init(env, exports); + CryptoKeyReader::Init(env, exports); return Client::Init(env, exports); } diff --git a/tests/encryption.test.js b/tests/encryption.test.js new file mode 100644 index 0000000..c6aedf8 --- /dev/null +++ b/tests/encryption.test.js @@ -0,0 +1,223 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +const path = require('path'); +const fs = require('fs'); +const Pulsar = require('../index'); + +class MyCryptoKeyReader extends Pulsar.CryptoKeyReader { + constructor(publicKeys, privateKeys) { + super(); + this.publicKeys = publicKeys; + this.privateKeys = privateKeys; + } + + getPublicKey(keyName, _metadata) { + const keyPath = this.publicKeys[keyName]; + if (keyPath) { + try { + const key = fs.readFileSync(keyPath); + return { key, _metadata }; + } catch (e) { + return null; + } + } + return null; + } + + getPrivateKey(keyName, _metadata) { + const keyPath = this.privateKeys[keyName]; + if (keyPath) { + try { + const key = fs.readFileSync(keyPath); + return { key, _metadata }; + } catch (e) { + return null; + } + } + return null; + } +} + +(() => { + describe('Encryption', () => { + let client; + const publicKeyPath = path.join(__dirname, 'certificate/public-key.client-rsa.pem'); + const privateKeyPath = path.join(__dirname, 'certificate/private-key.client-rsa.pem'); + + beforeAll(() => { + client = new Pulsar.Client({ + serviceUrl: 'pulsar://localhost:6650', + operationTimeoutSeconds: 30, + }); + }); + + afterAll(async () => { + await client.close(); + }); + + test('End-to-End Encryption', async () => { + const topic = `persistent://public/default/test-encryption-${Date.now()}`; + + const cryptoKeyReader = new MyCryptoKeyReader( + { 'my-key': publicKeyPath }, + { 'my-key': privateKeyPath }, + ); + + const producer = await client.createProducer({ + topic, + encryptionKeys: ['my-key'], + cryptoKeyReader, + cryptoFailureAction: 'FAIL', + }); + + const consumer = await client.subscribe({ + topic, + subscription: 'sub-encryption', + cryptoKeyReader, + cryptoFailureAction: 'CONSUME', + subscriptionInitialPosition: 'Earliest', + }); + + const msgContent = 'my-secret-message'; + await producer.send({ + data: Buffer.from(msgContent), + }); + + const msg = await consumer.receive(); + expect(msg.getData().toString()).toBe(msgContent); + const encCtx = msg.getEncryptionContext(); + expect(encCtx).not.toBeNull(); + expect(encCtx.isDecryptionFailed).toBe(false); + expect(encCtx.keys).toBeDefined(); + expect(encCtx.keys.length).toBeGreaterThan(0); + expect(encCtx.keys[0].value).toBeInstanceOf(Buffer); + expect(encCtx.param).toBeInstanceOf(Buffer); + expect(encCtx.algorithm).toBe(''); + expect(encCtx.compressionType).toBe('None'); + expect(encCtx.uncompressedMessageSize).toBe(0); + expect(encCtx.batchSize).toBe(1); + + await consumer.acknowledge(msg); + await producer.close(); + await consumer.close(); + }); + + test('End-to-End Encryption with Batching and Compression', async () => { + const topic = `persistent://public/default/test-encryption-batch-compress-${Date.now()}`; + + const cryptoKeyReader = new MyCryptoKeyReader( + { 'my-key': publicKeyPath }, + { 'my-key': privateKeyPath }, + ); + + const producer = await client.createProducer({ + topic, + encryptionKeys: ['my-key'], + cryptoKeyReader, + cryptoFailureAction: 'FAIL', + batchingEnabled: true, + batchingMaxMessages: 10, + batchingMaxPublishDelayMs: 100, + compressionType: 'Zlib', + }); + + const consumer = await client.subscribe({ + topic, + subscription: 'sub-encryption-batch-compress', + cryptoKeyReader, + cryptoFailureAction: 'CONSUME', + subscriptionInitialPosition: 'Earliest', + }); + + const numMessages = 10; + const sendPromises = []; + for (let i = 0; i < numMessages; i += 1) { + sendPromises.push(producer.send({ + data: Buffer.from(`message-${i}`), + })); + } + await Promise.all(sendPromises); + + for (let i = 0; i < numMessages; i += 1) { + const msg = await consumer.receive(); + expect(msg.getData().toString()).toBe(`message-${i}`); + const encCtx = msg.getEncryptionContext(); + expect(encCtx).not.toBeNull(); + expect(encCtx.isDecryptionFailed).toBe(false); + expect(encCtx.keys).toBeDefined(); + expect(encCtx.keys.length).toBeGreaterThan(0); + expect(encCtx.keys[0].value).toBeInstanceOf(Buffer); + expect(encCtx.param).toBeInstanceOf(Buffer); + expect(encCtx.algorithm).toBe(''); + expect(encCtx.compressionType).toBe('Zlib'); + expect(encCtx.uncompressedMessageSize).toBeGreaterThan(0); + expect(encCtx.batchSize).toBe(numMessages); + + await consumer.acknowledge(msg); + } + + await producer.close(); + await consumer.close(); + }); + + test('Decryption Failure', async () => { + const topic = `persistent://public/default/test-decryption-failure-${Date.now()}`; + + const cryptoKeyReader = new MyCryptoKeyReader( + { 'my-key': publicKeyPath }, + { 'my-key': privateKeyPath }, + ); + + const producer = await client.createProducer({ + topic, + encryptionKeys: ['my-key'], + cryptoKeyReader, + cryptoFailureAction: 'FAIL', + }); + + const consumer = await client.subscribe({ + topic, + subscription: 'sub-decryption-failure', + cryptoFailureAction: 'CONSUME', + subscriptionInitialPosition: 'Earliest', + }); + + const msgContent = 'my-secret-message'; + await producer.send({ + data: Buffer.from(msgContent), + }); + + const msg = await consumer.receive(); + expect(msg.getData().toString()).not.toBe(msgContent); + + const encCtx = msg.getEncryptionContext(); + expect(encCtx).not.toBeNull(); + expect(encCtx.isDecryptionFailed).toBe(true); + expect(encCtx.keys).toBeDefined(); + expect(encCtx.keys.length).toBeGreaterThan(0); + expect(encCtx.keys[0].value).toBeInstanceOf(Buffer); + expect(encCtx.param).toBeInstanceOf(Buffer); + + await consumer.acknowledge(msg); + await producer.close(); + await consumer.close(); + }); + }); +})();