SHA256
1
0
forked from pool/aws-crt-cpp
aws-crt-cpp/acc_adapt-new-input-stream-api.patch

226 lines
8.4 KiB
Diff

From 8adb8490fd4f1d1fe65aad01b0a7dda0e52ac596 Mon Sep 17 00:00:00 2001
From: Dengke Tang <dengket@amazon.com>
Date: Wed, 27 Apr 2022 13:07:36 -0700
Subject: [PATCH] adapt new input stream api (#341)
---
include/aws/crt/RefCounted.h | 66 ++++++++++++++++++++++++++++++++++++
include/aws/crt/io/Stream.h | 7 ++--
source/io/Stream.cpp | 16 ++++++---
tests/CMakeLists.txt | 1 +
tests/StreamTest.cpp | 42 +++++++++++++++++++++++
5 files changed, 125 insertions(+), 7 deletions(-)
create mode 100644 include/aws/crt/RefCounted.h
diff --git a/include/aws/crt/RefCounted.h b/include/aws/crt/RefCounted.h
new file mode 100644
index 0000000..811ccc0
--- /dev/null
+++ b/include/aws/crt/RefCounted.h
@@ -0,0 +1,66 @@
+#pragma once
+/**
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0.
+ */
+
+#include <memory>
+#include <mutex>
+
+namespace Aws
+{
+    namespace Crt
+    {
+        /**
+         * Inherit from RefCounted to allow reference-counting from C code,
+         * which will keep your C++ object alive as long as the count is non-zero.
+         *
+         * A class must inherit from RefCounted and std::enable_shared_from_this.
+         * Your class must always be placed inside a shared_ptr (do not create on
+         * the stack, or keep on the heap as a raw pointer); otherwise AcquireRef()
+         * fails (shared_from_this() throws std::bad_weak_ptr since C++17).
+         *
+         * Whenever the reference count goes from 0 to 1 a shared_ptr is created
+         * internally to keep this object alive. Whenever the reference count
+         * goes from 1 to 0 the internal shared_ptr is reset, allowing this object
+         * to be destroyed. A release while the count is already 0 is ignored.
+         */
+        template <class T> class RefCounted
+        {
+          protected:
+            RefCounted() {}
+            ~RefCounted() {}
+
+            void AcquireRef()
+            {
+                // lock_guard releases the mutex even if shared_from_this() throws.
+                std::lock_guard<std::mutex> lock(m_mutex);
+                if (m_count == 0)
+                {
+                    m_strongPtr = static_cast<T *>(this)->shared_from_this();
+                }
+                ++m_count; // counted only once the strong reference is actually held
+            }
+
+            void ReleaseRef()
+            {
+                // Declared before the lock so it is destroyed after the mutex is released:
+                // dropping the last strong reference may destroy this object (and m_mutex).
+                std::shared_ptr<T> tmpStrongPtr;
+                std::lock_guard<std::mutex> lock(m_mutex);
+                if (m_count > 0 && --m_count == 0)
+                {
+                    std::swap(m_strongPtr, tmpStrongPtr);
+                }
+            }
+
+          private:
+            RefCounted(const RefCounted &) = delete;
+            RefCounted &operator=(const RefCounted &) = delete;
+
+            size_t m_count = 0;
+            std::shared_ptr<T> m_strongPtr;
+            std::mutex m_mutex;
+        };
+    } // namespace Crt
+} // namespace Aws
diff --git a/include/aws/crt/io/Stream.h b/include/aws/crt/io/Stream.h
index 323747b..697461e 100644
--- a/include/aws/crt/io/Stream.h
+++ b/include/aws/crt/io/Stream.h
@@ -5,6 +5,7 @@
*/
#include <aws/crt/Exports.h>
+#include <aws/crt/RefCounted.h>
#include <aws/crt/Types.h>
#include <aws/io/stream.h>
@@ -35,7 +36,8 @@ namespace Aws
* aws_input_stream interface. To use, create a subclass of InputStream and define the abstract
* functions.
*/
- class AWS_CRT_CPP_API InputStream
+ class AWS_CRT_CPP_API InputStream : public std::enable_shared_from_this<InputStream>,
+ public RefCounted<InputStream>
{
public:
virtual ~InputStream();
@@ -136,7 +138,8 @@ namespace Aws
static int s_Read(aws_input_stream *stream, aws_byte_buf *dest);
static int s_GetStatus(aws_input_stream *stream, aws_stream_status *status);
static int s_GetLength(struct aws_input_stream *stream, int64_t *out_length);
- static void s_Destroy(struct aws_input_stream *stream);
+ static void s_Acquire(aws_input_stream *stream);
+ static void s_Release(aws_input_stream *stream);
static aws_input_stream_vtable s_vtable;
};
diff --git a/source/io/Stream.cpp b/source/io/Stream.cpp
index 1d6a3a4..a3beb28 100644
--- a/source/io/Stream.cpp
+++ b/source/io/Stream.cpp
@@ -68,10 +68,16 @@ namespace Aws
return AWS_OP_ERR;
}
- void InputStream::s_Destroy(struct aws_input_stream *stream)
+ void InputStream::s_Acquire(aws_input_stream *stream)
{
- (void)stream;
- // DO NOTHING, let the C++ destructor handle it.
+ auto impl = static_cast<InputStream *>(stream->impl);
+ impl->AcquireRef();
+ }
+
+ void InputStream::s_Release(aws_input_stream *stream)
+ {
+ auto impl = static_cast<InputStream *>(stream->impl);
+ impl->ReleaseRef();
}
aws_input_stream_vtable InputStream::s_vtable = {
@@ -79,7 +85,8 @@ namespace Aws
InputStream::s_Read,
InputStream::s_GetStatus,
InputStream::s_GetLength,
- InputStream::s_Destroy,
+ InputStream::s_Acquire,
+ InputStream::s_Release,
};
InputStream::InputStream(Aws::Crt::Allocator *allocator)
@@ -88,7 +95,6 @@ namespace Aws
AWS_ZERO_STRUCT(m_underlying_stream);
m_underlying_stream.impl = this;
- m_underlying_stream.allocator = m_allocator;
m_underlying_stream.vtable = &s_vtable;
}
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 613f487..057eaeb 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -50,6 +50,7 @@ add_test_case(StreamTestRead)
add_test_case(StreamTestReadEmpty)
add_test_case(StreamTestSeekBegin)
add_test_case(StreamTestSeekEnd)
+add_test_case(StreamTestRefcount)
add_test_case(TestCredentialsConstruction)
add_test_case(TestProviderStaticGet)
add_test_case(TestProviderEnvironmentGet)
diff --git a/tests/StreamTest.cpp b/tests/StreamTest.cpp
index a95b090..4befc8f 100644
--- a/tests/StreamTest.cpp
+++ b/tests/StreamTest.cpp
@@ -168,3 +168,45 @@ static int s_StreamTestSeekEnd(struct aws_allocator *allocator, void *ctx)
}
AWS_TEST_CASE(StreamTestSeekEnd, s_StreamTestSeekEnd)
+
+/* Test that a reference held by either the C or the C++ side keeps the stream alive */
+static int s_StreamTestRefcount(struct aws_allocator *allocator, void *ctx)
+{
+    (void)ctx;
+    {
+        Aws::Crt::ApiHandle apiHandle(allocator);
+        aws_input_stream *c_stream = nullptr;
+        {
+            auto stringStream = Aws::Crt::MakeShared<Aws::Crt::StringStream>(allocator, STREAM_CONTENTS);
+            /* The stream must be owned by a shared_ptr, because the C-side acquire/release calls rely on
+             * shared_from_this() to keep the C++ object alive. */
+            std::shared_ptr<Aws::Crt::Io::StdIOStreamInputStream> wrappedStream =
+                Aws::Crt::MakeShared<Aws::Crt::Io::StdIOStreamInputStream>(allocator, stringStream, allocator);
+
+            /* The C side takes a reference. */
+            aws_input_stream_acquire(wrappedStream->GetUnderlyingStream());
+            /* The C side drops its reference, so the C refcount returns to zero; the C++ shared_ptr still
+             * owns the object, so it must remain usable. */
+            aws_input_stream_release(wrappedStream->GetUnderlyingStream());
+            /* Verify the stream still works. */
+            int64_t length = 0;
+            ASSERT_SUCCESS(aws_input_stream_get_length(wrappedStream->GetUnderlyingStream(), &length));
+            ASSERT_TRUE(length == static_cast<int64_t>(strlen(STREAM_CONTENTS)));
+
+            /* The C side takes a reference again and keeps it beyond the C++ scope. */
+            aws_input_stream_acquire(wrappedStream->GetUnderlyingStream());
+            c_stream = wrappedStream->GetUnderlyingStream();
+        }
+        /* The C++ shared_ptr is now out of scope, but the C side still holds a reference, so the stream is
+         * still available to be invoked from C. */
+        int64_t length = 0;
+        ASSERT_SUCCESS(aws_input_stream_get_length(c_stream, &length));
+        ASSERT_TRUE(length == static_cast<int64_t>(strlen(STREAM_CONTENTS)));
+        /* Drop the last (C-side) reference so the stream is destroyed without a leak. */
+        aws_input_stream_release(c_stream);
+    }
+
+    return AWS_OP_SUCCESS;
+}
+
+AWS_TEST_CASE(StreamTestRefcount, s_StreamTestRefcount)
--
2.36.0