|
14 | 14 | #define ORC_RT_WRAPPERFUNCTION_H
|
15 | 15 |
|
16 | 16 | #include "orc-rt-c/WrapperFunction.h"
|
| 17 | +#include "orc-rt/Error.h" |
| 18 | +#include "orc-rt/bind.h" |
17 | 19 |
|
18 | 20 | #include <utility>
|
19 | 21 |
|
@@ -98,6 +100,164 @@ class WrapperFunctionBuffer {
|
98 | 100 | orc_rt_WrapperFunctionBuffer B;
|
99 | 101 | };
|
100 | 102 |
|
| 103 | +namespace detail { |
| 104 | + |
| 105 | +template <typename C> |
| 106 | +struct WFCallableTraits |
| 107 | + : public WFCallableTraits< |
| 108 | + decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> { |
| 109 | +}; |
| 110 | + |
| 111 | +template <typename RetT> struct WFCallableTraits<RetT()> { |
| 112 | + typedef void HeadArgType; |
| 113 | +}; |
| 114 | + |
| 115 | +template <typename RetT, typename ArgT, typename... ArgTs> |
| 116 | +struct WFCallableTraits<RetT(ArgT, ArgTs...)> { |
| 117 | + typedef ArgT HeadArgType; |
| 118 | + typedef std::tuple<ArgTs...> TailArgTuple; |
| 119 | +}; |
| 120 | + |
| 121 | +template <typename ClassT, typename RetT, typename... ArgTs> |
| 122 | +struct WFCallableTraits<RetT (ClassT::*)(ArgTs...)> |
| 123 | + : public WFCallableTraits<RetT(ArgTs...)> {}; |
| 124 | + |
| 125 | +template <typename ClassT, typename RetT, typename... ArgTs> |
| 126 | +struct WFCallableTraits<RetT (ClassT::*)(ArgTs...) const> |
| 127 | + : public WFCallableTraits<RetT(ArgTs...)> {}; |
| 128 | + |
| 129 | +template <typename Serializer> class StructuredYieldBase { |
| 130 | +public: |
| 131 | + StructuredYieldBase(orc_rt_SessionRef Session, void *CallCtx, |
| 132 | + orc_rt_WrapperFunctionReturn Return, Serializer &&S) |
| 133 | + : Session(Session), CallCtx(CallCtx), Return(Return), |
| 134 | + S(std::forward<Serializer>(S)) {} |
| 135 | + |
| 136 | +protected: |
| 137 | + orc_rt_SessionRef Session; |
| 138 | + void *CallCtx; |
| 139 | + orc_rt_WrapperFunctionReturn Return; |
| 140 | + std::decay_t<Serializer> S; |
| 141 | +}; |
| 142 | + |
| 143 | +template <typename RetT, typename Serializer> |
| 144 | +class StructuredYield : public StructuredYieldBase<Serializer> { |
| 145 | +public: |
| 146 | + using StructuredYieldBase<Serializer>::StructuredYieldBase; |
| 147 | + void operator()(RetT &&R) { |
| 148 | + if (auto ResultBytes = this->S.resultSerializer()(std::forward<RetT>(R))) |
| 149 | + this->Return(this->Session, this->CallCtx, ResultBytes->release()); |
| 150 | + else |
| 151 | + this->Return(this->Session, this->CallCtx, |
| 152 | + WrapperFunctionBuffer::createOutOfBandError( |
| 153 | + "Could not serialize wrapper function result data") |
| 154 | + .release()); |
| 155 | + } |
| 156 | +}; |
| 157 | + |
| 158 | +template <typename Serializer> |
| 159 | +class StructuredYield<void, Serializer> |
| 160 | + : public StructuredYieldBase<Serializer> { |
| 161 | +public: |
| 162 | + using StructuredYieldBase<Serializer>::StructuredYieldBase; |
| 163 | + void operator()() { |
| 164 | + this->Return(this->Session, this->CallCtx, |
| 165 | + WrapperFunctionBuffer().release()); |
| 166 | + } |
| 167 | +}; |
| 168 | + |
| 169 | +template <typename T, typename Serializer> struct ResultDeserializer; |
| 170 | + |
| 171 | +template <typename T, typename Serializer> |
| 172 | +struct ResultDeserializer<Expected<T>, Serializer> { |
| 173 | + static Expected<T> deserialize(WrapperFunctionBuffer ResultBytes, |
| 174 | + Serializer &S) { |
| 175 | + T Val; |
| 176 | + if (S.resultDeserializer()(ResultBytes, Val)) |
| 177 | + return std::move(Val); |
| 178 | + else |
| 179 | + return make_error<StringError>("Could not deserialize result"); |
| 180 | + } |
| 181 | +}; |
| 182 | + |
| 183 | +template <typename Serializer> struct ResultDeserializer<Error, Serializer> { |
| 184 | + static Error deserialize(WrapperFunctionBuffer ResultBytes, Serializer &S) { |
| 185 | + assert(ResultBytes.empty()); |
| 186 | + return Error::success(); |
| 187 | + } |
| 188 | +}; |
| 189 | + |
| 190 | +} // namespace detail |
| 191 | + |
| 192 | +/// Provides call and handle utilities to simplify writing and invocation of |
| 193 | +/// wrapper functions in C++. |
| 194 | +struct WrapperFunction { |
| 195 | + |
| 196 | + /// Make a call to a wrapper function. |
| 197 | + /// |
| 198 | + /// This utility serializes and deserializes arguments and return values |
| 199 | + /// (using the given Serializer), and calls the wrapper function via the |
| 200 | + /// given Caller object. |
| 201 | + template <typename Caller, typename Serializer, typename ResultHandler, |
| 202 | + typename... ArgTs> |
| 203 | + static void call(Caller &&C, Serializer &&S, ResultHandler &&RH, |
| 204 | + ArgTs &&...Args) { |
| 205 | + typedef detail::WFCallableTraits<ResultHandler> ResultHandlerTraits; |
| 206 | + static_assert( |
| 207 | + std::tuple_size_v<typename ResultHandlerTraits::TailArgTuple> == 0, |
| 208 | + "Expected one argument to result-handler"); |
| 209 | + typedef typename ResultHandlerTraits::HeadArgType ResultType; |
| 210 | + |
| 211 | + if (auto ArgBytes = S.argumentSerializer()(std::forward<ArgTs>(Args)...)) { |
| 212 | + C( |
| 213 | + [RH = std::move(RH), |
| 214 | + S = std::move(S)](orc_rt_SessionRef Session, |
| 215 | + WrapperFunctionBuffer ResultBytes) mutable { |
| 216 | + if (const char *ErrMsg = ResultBytes.getOutOfBandError()) |
| 217 | + RH(make_error<StringError>(ErrMsg)); |
| 218 | + else |
| 219 | + RH(detail::ResultDeserializer< |
| 220 | + ResultType, Serializer>::deserialize(std::move(ResultBytes), |
| 221 | + S)); |
| 222 | + }, |
| 223 | + std::move(*ArgBytes)); |
| 224 | + } else |
| 225 | + RH(make_error<StringError>( |
| 226 | + "Could not serialize wrapper function call arguments")); |
| 227 | + } |
| 228 | + |
| 229 | + /// Simplifies implementation of wrapper functions in C++. |
| 230 | + /// |
| 231 | + /// This utility deserializes and serializes arguments and return values |
| 232 | + /// (using the given Serializer), and calls the given handler. |
| 233 | + template <typename Serializer, typename Handler> |
| 234 | + static void handle(orc_rt_SessionRef Session, void *CallCtx, |
| 235 | + orc_rt_WrapperFunctionReturn Return, |
| 236 | + WrapperFunctionBuffer ArgBytes, Serializer &&S, |
| 237 | + Handler &&H) { |
| 238 | + typedef detail::WFCallableTraits<Handler> HandlerTraits; |
| 239 | + typedef typename HandlerTraits::HeadArgType Yield; |
| 240 | + typedef typename HandlerTraits::TailArgTuple ArgTuple; |
| 241 | + typedef typename detail::WFCallableTraits<Yield>::HeadArgType RetType; |
| 242 | + |
| 243 | + if (ArgBytes.getOutOfBandError()) |
| 244 | + return Return(Session, CallCtx, ArgBytes.release()); |
| 245 | + |
| 246 | + ArgTuple Args; |
| 247 | + if (std::apply(bind_front(S.argumentDeserializer(), std::move(ArgBytes)), |
| 248 | + Args)) |
| 249 | + std::apply(bind_front(std::forward<Handler>(H), |
| 250 | + detail::StructuredYield<RetType, Serializer>( |
| 251 | + Session, CallCtx, Return, std::move(S))), |
| 252 | + std::move(Args)); |
| 253 | + else |
| 254 | + Return(Session, CallCtx, |
| 255 | + WrapperFunctionBuffer::createOutOfBandError( |
| 256 | + "Could not deserialize wrapper function arg data") |
| 257 | + .release()); |
| 258 | + } |
| 259 | +}; |
| 260 | + |
101 | 261 | } // namespace orc_rt
|
102 | 262 |
|
103 | 263 | #endif // ORC_RT_WRAPPERFUNCTION_H
|
0 commit comments