forked from pool/aws-crt-cpp
226 lines
8.4 KiB
Diff
226 lines
8.4 KiB
Diff
|
From 8adb8490fd4f1d1fe65aad01b0a7dda0e52ac596 Mon Sep 17 00:00:00 2001
|
||
|
From: Dengke Tang <dengket@amazon.com>
|
||
|
Date: Wed, 27 Apr 2022 13:07:36 -0700
|
||
|
Subject: [PATCH] adapt new input stream api (#341)
|
||
|
|
||
|
---
|
||
|
include/aws/crt/RefCounted.h | 66 ++++++++++++++++++++++++++++++++++++
|
||
|
include/aws/crt/io/Stream.h | 7 ++--
|
||
|
source/io/Stream.cpp | 16 ++++++---
|
||
|
tests/CMakeLists.txt | 1 +
|
||
|
tests/StreamTest.cpp | 42 +++++++++++++++++++++++
|
||
|
5 files changed, 125 insertions(+), 7 deletions(-)
|
||
|
create mode 100644 include/aws/crt/RefCounted.h
|
||
|
|
||
|
diff --git a/include/aws/crt/RefCounted.h b/include/aws/crt/RefCounted.h
|
||
|
new file mode 100644
|
||
|
index 0000000..811ccc0
|
||
|
--- /dev/null
|
||
|
+++ b/include/aws/crt/RefCounted.h
|
||
|
@@ -0,0 +1,66 @@
|
||
|
+#pragma once
|
||
|
+/**
|
||
|
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||
|
+ * SPDX-License-Identifier: Apache-2.0.
|
||
|
+ */
|
||
|
+
|
||
|
+#include <memory>
|
||
|
+#include <mutex>
|
||
|
+
|
||
|
+namespace Aws
|
||
|
+{
|
||
|
+ namespace Crt
|
||
|
+ {
|
||
|
+ /**
|
||
|
+ * Inherit from RefCounted to allow reference-counting from C code,
|
||
|
+ * which will keep your C++ object alive as long as the count is non-zero.
|
||
|
+ *
|
||
|
+ * A class must inherit from RefCounted and std::enable_shared_from_this.
|
||
|
+ * Your class must always be placed inside a shared_ptr (do not create on
|
||
|
+ * the stack, or keep on the heap as a raw pointer).
|
||
|
+ *
|
||
|
+ * Whenever the reference count goes from 0 to 1 a shared_ptr is created
|
||
|
+ * internally to keep this object alive. Whenever the reference count
|
||
|
+ * goes from 1 to 0 the internal shared_ptr is reset, allowing this object
|
||
|
+ * to be destroyed.
|
||
|
+ */
|
||
|
+ template <class T> class RefCounted
|
||
|
+ {
|
||
|
+ protected:
|
||
|
+ RefCounted() {}
|
||
|
+ ~RefCounted() {}
|
||
|
+
|
||
|
+ void AcquireRef()
|
||
|
+ {
|
||
|
+ m_mutex.lock();
|
||
|
+ if (m_count++ == 0)
|
||
|
+ {
|
||
|
+ m_strongPtr = static_cast<T *>(this)->shared_from_this();
|
||
|
+ }
|
||
|
+ m_mutex.unlock();
|
||
|
+ }
|
||
|
+
|
||
|
+ void ReleaseRef()
|
||
|
+ {
|
||
|
+ // Move contents of m_strongPtr to a temp so that this
|
||
|
+ // object can't be destroyed until the function exits.
|
||
|
+ std::shared_ptr<T> tmpStrongPtr;
|
||
|
+
|
||
|
+ m_mutex.lock();
|
||
|
+ if (m_count-- == 1)
|
||
|
+ {
|
||
|
+ std::swap(m_strongPtr, tmpStrongPtr);
|
||
|
+ }
|
||
|
+ m_mutex.unlock();
|
||
|
+ }
|
||
|
+
|
||
|
+ private:
|
||
|
+ RefCounted(const RefCounted &) = delete;
|
||
|
+ RefCounted &operator=(const RefCounted &) = delete;
|
||
|
+
|
||
|
+ size_t m_count = 0;
|
||
|
+ std::shared_ptr<T> m_strongPtr;
|
||
|
+ std::mutex m_mutex;
|
||
|
+ };
|
||
|
+ } // namespace Crt
|
||
|
+} // namespace Aws
|
||
|
diff --git a/include/aws/crt/io/Stream.h b/include/aws/crt/io/Stream.h
|
||
|
index 323747b..697461e 100644
|
||
|
--- a/include/aws/crt/io/Stream.h
|
||
|
+++ b/include/aws/crt/io/Stream.h
|
||
|
@@ -5,6 +5,7 @@
|
||
|
*/
|
||
|
|
||
|
#include <aws/crt/Exports.h>
|
||
|
+#include <aws/crt/RefCounted.h>
|
||
|
#include <aws/crt/Types.h>
|
||
|
#include <aws/io/stream.h>
|
||
|
|
||
|
@@ -35,7 +36,8 @@ namespace Aws
|
||
|
* aws_input_stream interface. To use, create a subclass of InputStream and define the abstract
|
||
|
* functions.
|
||
|
*/
|
||
|
- class AWS_CRT_CPP_API InputStream
|
||
|
+ class AWS_CRT_CPP_API InputStream : public std::enable_shared_from_this<InputStream>,
|
||
|
+ public RefCounted<InputStream>
|
||
|
{
|
||
|
public:
|
||
|
virtual ~InputStream();
|
||
|
@@ -136,7 +138,8 @@ namespace Aws
|
||
|
static int s_Read(aws_input_stream *stream, aws_byte_buf *dest);
|
||
|
static int s_GetStatus(aws_input_stream *stream, aws_stream_status *status);
|
||
|
static int s_GetLength(struct aws_input_stream *stream, int64_t *out_length);
|
||
|
- static void s_Destroy(struct aws_input_stream *stream);
|
||
|
+ static void s_Acquire(aws_input_stream *stream);
|
||
|
+ static void s_Release(aws_input_stream *stream);
|
||
|
|
||
|
static aws_input_stream_vtable s_vtable;
|
||
|
};
|
||
|
diff --git a/source/io/Stream.cpp b/source/io/Stream.cpp
|
||
|
index 1d6a3a4..a3beb28 100644
|
||
|
--- a/source/io/Stream.cpp
|
||
|
+++ b/source/io/Stream.cpp
|
||
|
@@ -68,10 +68,16 @@ namespace Aws
|
||
|
return AWS_OP_ERR;
|
||
|
}
|
||
|
|
||
|
- void InputStream::s_Destroy(struct aws_input_stream *stream)
|
||
|
+ void InputStream::s_Acquire(aws_input_stream *stream)
|
||
|
{
|
||
|
- (void)stream;
|
||
|
- // DO NOTHING, let the C++ destructor handle it.
|
||
|
+ auto impl = static_cast<InputStream *>(stream->impl);
|
||
|
+ impl->AcquireRef();
|
||
|
+ }
|
||
|
+
|
||
|
+ void InputStream::s_Release(aws_input_stream *stream)
|
||
|
+ {
|
||
|
+ auto impl = static_cast<InputStream *>(stream->impl);
|
||
|
+ impl->ReleaseRef();
|
||
|
}
|
||
|
|
||
|
aws_input_stream_vtable InputStream::s_vtable = {
|
||
|
@@ -79,7 +85,8 @@ namespace Aws
|
||
|
InputStream::s_Read,
|
||
|
InputStream::s_GetStatus,
|
||
|
InputStream::s_GetLength,
|
||
|
- InputStream::s_Destroy,
|
||
|
+ InputStream::s_Acquire,
|
||
|
+ InputStream::s_Release,
|
||
|
};
|
||
|
|
||
|
InputStream::InputStream(Aws::Crt::Allocator *allocator)
|
||
|
@@ -88,7 +95,6 @@ namespace Aws
|
||
|
AWS_ZERO_STRUCT(m_underlying_stream);
|
||
|
|
||
|
m_underlying_stream.impl = this;
|
||
|
- m_underlying_stream.allocator = m_allocator;
|
||
|
m_underlying_stream.vtable = &s_vtable;
|
||
|
}
|
||
|
|
||
|
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
|
||
|
index 613f487..057eaeb 100644
|
||
|
--- a/tests/CMakeLists.txt
|
||
|
+++ b/tests/CMakeLists.txt
|
||
|
@@ -50,6 +50,7 @@ add_test_case(StreamTestRead)
|
||
|
add_test_case(StreamTestReadEmpty)
|
||
|
add_test_case(StreamTestSeekBegin)
|
||
|
add_test_case(StreamTestSeekEnd)
|
||
|
+add_test_case(StreamTestRefcount)
|
||
|
add_test_case(TestCredentialsConstruction)
|
||
|
add_test_case(TestProviderStaticGet)
|
||
|
add_test_case(TestProviderEnvironmentGet)
|
||
|
diff --git a/tests/StreamTest.cpp b/tests/StreamTest.cpp
|
||
|
index a95b090..4befc8f 100644
|
||
|
--- a/tests/StreamTest.cpp
|
||
|
+++ b/tests/StreamTest.cpp
|
||
|
@@ -168,3 +168,45 @@ static int s_StreamTestSeekEnd(struct aws_allocator *allocator, void *ctx)
|
||
|
}
|
||
|
|
||
|
AWS_TEST_CASE(StreamTestSeekEnd, s_StreamTestSeekEnd)
|
||
|
+
|
||
|
+/* Test that a refcount held on the stream by C/C++ will keep the object alive */
|
||
|
+static int s_StreamTestRefcount(struct aws_allocator *allocator, void *ctx)
|
||
|
+{
|
||
|
+ (void)ctx;
|
||
|
+ {
|
||
|
+ Aws::Crt::ApiHandle apiHandle(allocator);
|
||
|
+ aws_input_stream *c_stream = NULL;
|
||
|
+ {
|
||
|
+ auto stringStream = Aws::Crt::MakeShared<Aws::Crt::StringStream>(allocator, STREAM_CONTENTS);
|
||
|
+ /* Make a shared pointer for stream as the C side will ONLY interact with the shared pointer initialized
|
||
|
+ * stream */
|
||
|
+ std::shared_ptr<Aws::Crt::Io::StdIOStreamInputStream> wrappedStream =
|
||
|
+ Aws::Crt::MakeShared<Aws::Crt::Io::StdIOStreamInputStream>(allocator, stringStream, allocator);
|
||
|
+
|
||
|
+ /* C side keeps a reference on it. */
|
||
|
+ aws_input_stream_acquire(wrappedStream->GetUnderlyingStream());
|
||
|
+ /* C side releases a reference on it, so that it drops to zero from the C point of view, but as C++ is still
|
||
|
+ * holding it, it's still valid to be used */
|
||
|
+ aws_input_stream_release(wrappedStream->GetUnderlyingStream());
|
||
|
+ /* Test that you can still use it correctly */
|
||
|
+ int64_t length = 0;
|
||
|
+ ASSERT_SUCCESS(aws_input_stream_get_length(wrappedStream->GetUnderlyingStream(), &length));
|
||
|
+ ASSERT_TRUE(length == strlen(STREAM_CONTENTS));
|
||
|
+
|
||
|
+ /* C side keeps a reference on it. */
|
||
|
+ aws_input_stream_acquire(wrappedStream->GetUnderlyingStream());
|
||
|
+ c_stream = wrappedStream->GetUnderlyingStream();
|
||
|
+ }
|
||
|
+ /* C++ object is now out of scope, but as C side still holds a reference to it, it's still available to be
|
||
|
+ * invoked from C */
|
||
|
+ int64_t length = 0;
|
||
|
+ ASSERT_SUCCESS(aws_input_stream_get_length(c_stream, &length));
|
||
|
+ ASSERT_TRUE(length == strlen(STREAM_CONTENTS));
|
||
|
+ /* Release the refcount from C to clean up resource without leak */
|
||
|
+ aws_input_stream_release(c_stream);
|
||
|
+ }
|
||
|
+
|
||
|
+ return AWS_OP_SUCCESS;
|
||
|
+}
|
||
|
+
|
||
|
+AWS_TEST_CASE(StreamTestRefcount, s_StreamTestRefcount)
|
||
|
--
|
||
|
2.36.0
|
||
|
|