From 8adb8490fd4f1d1fe65aad01b0a7dda0e52ac596 Mon Sep 17 00:00:00 2001
From: Dengke Tang
Date: Wed, 27 Apr 2022 13:07:36 -0700
Subject: [PATCH] adapt new input stream api (#341)

Local changes relative to upstream 8adb849 (the index lines below still name
the upstream blobs):
- RefCounted::AcquireRef/ReleaseRef lock with std::lock_guard, and AcquireRef
  bumps the count only after shared_from_this() succeeds, so a throw
  (std::bad_weak_ptr) can no longer leave the mutex locked or the count wrong.
- StreamTestRefcount: comment typos fixed; length comparisons are no longer
  signed/unsigned.
---
 include/aws/crt/RefCounted.h | 66 ++++++++++++++++++++++++++++++++++++
 include/aws/crt/io/Stream.h  |  7 ++--
 source/io/Stream.cpp         | 16 ++++++---
 tests/CMakeLists.txt         |  1 +
 tests/StreamTest.cpp         | 42 +++++++++++++++++++++++
 5 files changed, 125 insertions(+), 7 deletions(-)
 create mode 100644 include/aws/crt/RefCounted.h

diff --git a/include/aws/crt/RefCounted.h b/include/aws/crt/RefCounted.h
new file mode 100644
index 0000000..811ccc0
--- /dev/null
+++ b/include/aws/crt/RefCounted.h
@@ -0,0 +1,66 @@
+#pragma once
+/**
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0.
+ */
+
+#include <memory>
+#include <mutex>
+
+namespace Aws
+{
+    namespace Crt
+    {
+        /**
+         * Inherit from RefCounted to allow reference-counting from C code,
+         * which will keep your C++ object alive as long as the count is non-zero.
+         *
+         * A class must inherit from RefCounted and std::enable_shared_from_this.
+         * Your class must always be placed inside a shared_ptr (do not create on
+         * the stack, or keep on the heap as a raw pointer): AcquireRef() calls shared_from_this().
+         *
+         * Whenever the reference count goes from 0 to 1 a shared_ptr is created
+         * internally to keep this object alive. Whenever the reference count
+         * goes from 1 to 0 the internal shared_ptr is reset, allowing this object
+         * to be destroyed.
+         */
+        template <class T> class RefCounted
+        {
+          protected:
+            RefCounted() {}
+            ~RefCounted() {}
+
+            void AcquireRef()
+            {
+                std::lock_guard<std::mutex> lock(m_mutex);
+                if (m_count == 0)
+                {
+                    m_strongPtr = static_cast<T *>(this)->shared_from_this();
+                }
+                ++m_count; // only after shared_from_this() succeeded, so a throw leaves the count unchanged
+            }
+
+            void ReleaseRef()
+            {
+                // Move contents of m_strongPtr to a temp so that this object can't be
+                // destroyed until the function exits. It is declared before the lock so
+                // it is destroyed after m_mutex (a member of this object) is unlocked.
+                std::shared_ptr<T> tmpStrongPtr;
+
+                std::lock_guard<std::mutex> lock(m_mutex);
+                if (m_count-- == 1)
+                {
+                    std::swap(m_strongPtr, tmpStrongPtr);
+                }
+            }
+
+          private:
+            RefCounted(const RefCounted &) = delete;
+            RefCounted &operator=(const RefCounted &) = delete;
+
+            size_t m_count = 0;
+            std::shared_ptr<T> m_strongPtr;
+            std::mutex m_mutex;
+        };
+    } // namespace Crt
+} // namespace Aws
diff --git a/include/aws/crt/io/Stream.h b/include/aws/crt/io/Stream.h
index 323747b..697461e 100644
--- a/include/aws/crt/io/Stream.h
+++ b/include/aws/crt/io/Stream.h
@@ -5,6 +5,7 @@
  */
 
 #include <aws/crt/Exports.h>
+#include <aws/crt/RefCounted.h>
 #include <aws/crt/Types.h>
 #include <aws/io/stream.h>
 
@@ -35,7 +36,8 @@ namespace Aws
              * aws_input_stream interface. To use, create a subclass of InputStream and define the abstract
              * functions.
              */
-            class AWS_CRT_CPP_API InputStream
+            class AWS_CRT_CPP_API InputStream : public std::enable_shared_from_this<InputStream>,
+                                                public RefCounted<InputStream>
             {
               public:
                 virtual ~InputStream();
@@ -136,7 +138,8 @@ namespace Aws
                 static int s_Read(aws_input_stream *stream, aws_byte_buf *dest);
                 static int s_GetStatus(aws_input_stream *stream, aws_stream_status *status);
                 static int s_GetLength(struct aws_input_stream *stream, int64_t *out_length);
-                static void s_Destroy(struct aws_input_stream *stream);
+                static void s_Acquire(aws_input_stream *stream);
+                static void s_Release(aws_input_stream *stream);
 
                 static aws_input_stream_vtable s_vtable;
             };
diff --git a/source/io/Stream.cpp b/source/io/Stream.cpp
index 1d6a3a4..a3beb28 100644
--- a/source/io/Stream.cpp
+++ b/source/io/Stream.cpp
@@ -68,10 +68,16 @@ namespace Aws
                 return AWS_OP_ERR;
             }
 
-            void InputStream::s_Destroy(struct aws_input_stream *stream)
+            void InputStream::s_Acquire(aws_input_stream *stream)
             {
-                (void)stream;
-                // DO NOTHING, let the C++ destructor handle it.
+                auto impl = static_cast<InputStream *>(stream->impl);
+                impl->AcquireRef();
+            }
+
+            void InputStream::s_Release(aws_input_stream *stream)
+            {
+                auto impl = static_cast<InputStream *>(stream->impl);
+                impl->ReleaseRef();
             }
 
             aws_input_stream_vtable InputStream::s_vtable = {
@@ -79,7 +85,8 @@ namespace Aws
                 InputStream::s_Read,
                 InputStream::s_GetStatus,
                 InputStream::s_GetLength,
-                InputStream::s_Destroy,
+                InputStream::s_Acquire,
+                InputStream::s_Release,
             };
 
             InputStream::InputStream(Aws::Crt::Allocator *allocator)
@@ -88,7 +95,6 @@ namespace Aws
                 AWS_ZERO_STRUCT(m_underlying_stream);
 
                 m_underlying_stream.impl = this;
-                m_underlying_stream.allocator = m_allocator;
                 m_underlying_stream.vtable = &s_vtable;
             }
 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 613f487..057eaeb 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -50,6 +50,7 @@ add_test_case(StreamTestRead)
 add_test_case(StreamTestReadEmpty)
 add_test_case(StreamTestSeekBegin)
 add_test_case(StreamTestSeekEnd)
+add_test_case(StreamTestRefcount)
 add_test_case(TestCredentialsConstruction)
 add_test_case(TestProviderStaticGet)
 add_test_case(TestProviderEnvironmentGet)
diff --git a/tests/StreamTest.cpp b/tests/StreamTest.cpp
index a95b090..4befc8f 100644
--- a/tests/StreamTest.cpp
+++ b/tests/StreamTest.cpp
@@ -168,3 +168,45 @@ static int s_StreamTestSeekEnd(struct aws_allocator *allocator, void *ctx)
 }
 
 AWS_TEST_CASE(StreamTestSeekEnd, s_StreamTestSeekEnd)
+
+/* Test that a reference held by either the C or the C++ side keeps the stream alive */
+static int s_StreamTestRefcount(struct aws_allocator *allocator, void *ctx)
+{
+    (void)ctx;
+    {
+        Aws::Crt::ApiHandle apiHandle(allocator);
+        aws_input_stream *c_stream = nullptr;
+        {
+            auto stringStream = Aws::Crt::MakeShared<Aws::Crt::StringStream>(allocator, STREAM_CONTENTS);
+            /* The stream must be owned by a shared_ptr: the C side's acquire/release goes through
+             * RefCounted, which calls shared_from_this(). */
+            std::shared_ptr<Aws::Crt::Io::StdIOStreamInputStream> wrappedStream =
+                Aws::Crt::MakeShared<Aws::Crt::Io::StdIOStreamInputStream>(allocator, stringStream, allocator);
+
+            /* C side takes a reference on it. */
+            aws_input_stream_acquire(wrappedStream->GetUnderlyingStream());
+            /* C side releases its reference, so the count drops to zero from the C point of view; C++ still
+             * holds the shared_ptr, so the stream remains valid to use. */
+            aws_input_stream_release(wrappedStream->GetUnderlyingStream());
+            /* Check that the stream can still be used. */
+            int64_t length = 0;
+            ASSERT_SUCCESS(aws_input_stream_get_length(wrappedStream->GetUnderlyingStream(), &length));
+            ASSERT_TRUE(length == static_cast<int64_t>(strlen(STREAM_CONTENTS)));
+
+            /* C side takes a reference on it again. */
+            aws_input_stream_acquire(wrappedStream->GetUnderlyingStream());
+            c_stream = wrappedStream->GetUnderlyingStream();
+        }
+        /* The C++ shared_ptr is now out of scope, but the reference held by the C side keeps the stream
+         * alive, so it can still be used from C. */
+        int64_t length = 0;
+        ASSERT_SUCCESS(aws_input_stream_get_length(c_stream, &length));
+        ASSERT_TRUE(length == static_cast<int64_t>(strlen(STREAM_CONTENTS)));
+        /* Drop the last (C-side) reference so the stream is destroyed without a leak. */
+        aws_input_stream_release(c_stream);
+    }
+
+    return AWS_OP_SUCCESS;
+}
+
+AWS_TEST_CASE(StreamTestRefcount, s_StreamTestRefcount)
-- 
2.36.0