Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/clj/coffi/ffi.clj
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,20 @@
`(or ~sym mem/null)

(mem/primitive-type type)
`(mem/serialize* ~sym ~type-sym ~arena)
(if (get-method mem/serialize* type)
`(~(get-method mem/serialize* type) ~sym ~type-sym ~arena)
`(mem/serialize* ~sym ~type-sym ~arena))

;; type is a valid ::mem/type at macroexpansion time, which should make it safe to use it in size-of and align-of. however, the spec is not perfect, and c-layout might not be defined for type at macroexpansion time so we'll conditionally try to obtain size and align of the type to be safe
:else
`(let [alloc# (mem/alloc-instance ~type-sym)]
(mem/serialize-into ~sym ~type-sym alloc# ~arena)
alloc#))
(let [alloc-sym (with-meta (gensym "alloc") {:tag 'java.lang.foreign.MemorySegment})
size-of (if-some [v (try (mem/size-of type) (catch Exception _ nil))] v `(mem/size-of ~type))
align-of (if-some [v (try (mem/align-of type) (catch Exception _ nil))] v `(mem/align-of ~type))]
`(let [~alloc-sym (mem/alloc ~size-of ~align-of ~arena)]
~(if (get-method mem/generate-serialize type)
(mem/generate-serialize type sym 0 alloc-sym)
`(mem/serialize-into ~sym ~type ~alloc-sym ~arena))
~alloc-sym)))
(list sym)))

arg-serializers
Expand Down Expand Up @@ -359,9 +367,13 @@
wrap-serialize))

deserialize-prim (fn [expr]
`(mem/deserialize* ~expr ~ret-type-sym))
(if (get-method mem/deserialize* ret-type)
`(~(get-method mem/deserialize* ret-type) ~expr ~ret-type-sym)
`(mem/deserialize* ~expr ~ret-type-sym)))
deserialize-segment (fn [expr]
`(mem/deserialize-from ~expr ~ret-type-sym))
(if (get-method mem/deserialize-from ret-type)
`(~(get-method mem/deserialize-from ret-type) ~expr ~ret-type-sym)
`(mem/deserialize-from ~expr ~ret-type-sym)))
deserialize-ret (fn [expr]
(cond
(and (or (mem/primitive? ret-type)
Expand All @@ -372,11 +384,16 @@
(mem/primitive-type ret-type)
(deserialize-prim expr)

(get-method mem/generate-deserialize ret-type)
(let [return-segment (with-meta (gensym "return-segment") {:tag 'java.lang.foreign.MemorySegment})]
`(let [~return-segment ~expr]
~(mem/generate-deserialize ret-type 0 return-segment)))

:else
(deserialize-segment expr)))

wrap-arena (fn [expr]
`(with-open [~arena (mem/confined-arena)]
`(with-open [~arena (mem/thread-local-arena)]
~expr))
wrap-fn (fn [call needs-arena?]
`(fn [~@(if const-args? arg-syms ['& args-sym])]
Expand Down Expand Up @@ -412,12 +429,12 @@
[downcall arg-types ret-type]
(if (mem/primitive-type ret-type)
(fn native-fn [& args]
(with-open [arena (mem/confined-arena)]
(with-open [arena (mem/thread-local-arena)]
(mem/deserialize*
(apply downcall (map #(mem/serialize %1 %2 arena) args arg-types))
ret-type)))
(fn native-fn [& args]
(with-open [arena (mem/confined-arena)]
(with-open [arena (mem/thread-local-arena)]
(mem/deserialize-from
(apply downcall (mem/arena-allocator arena)
(map #(mem/serialize %1 %2 arena) args arg-types))
Expand Down
176 changes: 152 additions & 24 deletions src/clj/coffi/mem.clj
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,32 @@
ValueLayout$OfDouble)
(java.lang.ref Cleaner)
(java.util.function Consumer)
(java.nio ByteOrder))
(java.nio ByteOrder)
(clojure.lang
IDeref
Settable))
(:refer-clojure :exclude [defstruct]))

(set! *warn-on-reflection* true)

(def ^:private primitive-tag?
'#{byte bytes short shorts int ints long longs
float floats double doubles
bool bools char chars})

(defmacro once-only
{:style/indent [:defn]
:private true}
[[& names] & body]
(let [gensyms (repeatedly (count names) gensym)]
`(let [~@(interleave gensyms (repeat (count names) `(gensym)))]
`(let [~~@(mapcat #(-> (if (primitive-tag? (:tag (meta %2)))
[%1 ``(~'~(:tag (meta %2)) ~~%2)]
[`(with-meta ~%1 {:tag '~(:tag (meta %2))}) %2]))
gensyms names)]
~(let [~@(mapcat #(-> [(with-meta %1 {}) %2]) names gensyms)]
~@body)))))

(defn confined-arena
"Constructs a new arena for use only in this thread.

Expand Down Expand Up @@ -93,10 +114,135 @@
(^MemorySegment allocate [_this ^long byte-size ^long byte-alignment]
(.allocate arena ^long byte-size ^long byte-alignment))))

(defn thread-local* [init-fn]
(let [local (proxy [ThreadLocal] [] (initialValue [] (init-fn)))]
(reify
IDeref
(deref [this]
(.get local))
Settable
(doSet [this v]
(.set local v)))))

(def ^:private terminating-thread-local-registry (atom {}))

(defn terminating-thread-local* [init cleanup]
(let [current-thread (Thread/currentThread)
local (proxy [ThreadLocal] []
(initialValue []
(let [value (init)]
(swap! terminating-thread-local-registry assoc current-thread value)
value)))
cleanup-thread (proxy [Thread] []
(run []
(.join ^Thread current-thread)
(let [last-value (@terminating-thread-local-registry current-thread)]
(swap! terminating-thread-local-registry dissoc current-thread)
(cleanup last-value))))]
(.start ^Thread cleanup-thread)
(reify
IDeref
(deref [this]
(.get local))
Settable
(doSet [this v]
(swap! terminating-thread-local-registry assoc (Thread/currentThread) v)
(.set local v)))))

(defmacro thread-local
"Takes a body of expressions, and returns a java.lang.ThreadLocal object.
(see http://download.oracle.com/javase/6/docs/api/java/lang/ThreadLocal.html).

To get the current value of the thread-local binding, you must deref (@) the
thread-local object. The body of expressions will be executed once per thread
and future derefs will be cached."
[& body]
`(thread-local* (fn [] ~@body)))

(defmacro terminating-thread-local
"Takes a body of expressions, and a cleanup function as the last argument,
and returns a java.lang.ThreadLocal object, similar to `thread-local`.

Unlike `thread-local`, the cleanup function will be attached and run in a
different (unique) Thread once a Thread that initialized this thread-local
exits, mimicking the behavior of jdk.internal.misc.TerminatingThreadLocal

The main use case is to clean up thread local resources. example:

(def my-thread-local-arena (terminating-thread-local (mem/confined-arena) (fn [arena] (.close arena))))"
[& args]
`(terminating-thread-local* (fn [] ~@(drop-last args)) ~(last args)))

; thread local allocation, inspired by https://bugs.openjdk.org/browse/JDK-8348189

(def ^:private thread-local-buffer-initial-size 512)
(def ^:private thread-local-backing-arena (terminating-thread-local (confined-arena) (fn [arena] (.close ^Arena arena))))
(defprotocol IArenaGet (^Arena arena-get [this]))
(deftype ThreadLocalConfinedArena
[^:unsynchronized-mutable ^SegmentAllocator allocator
^:unsynchronized-mutable ^Arena arena
^:unsynchronized-mutable ^MemorySegment buffer
^:unsynchronized-mutable ^long allocCount
^:unsynchronized-mutable ^int nestingLevels
^:unsynchronized-mutable ^boolean isBufferTooSmall]
Arena
(^MemorySegment allocate [this ^long byteSize ^long byteAlignment]
(try
(set! allocCount (unchecked-add allocCount byteSize))
(.allocate ^SegmentAllocator allocator ^long byteSize ^long byteAlignment)
(catch IndexOutOfBoundsException _
(let [new-size (* thread-local-buffer-initial-size (unchecked-inc-int (unchecked-divide-int allocCount thread-local-buffer-initial-size)))
new-buffer (.allocate ^Arena arena ^long new-size)
new-allocator (SegmentAllocator/slicingAllocator new-buffer)]
(set! isBufferTooSmall (boolean true))
(set! buffer new-buffer)
(set! allocator new-allocator)
(.allocate ^SegmentAllocator new-allocator ^long byteSize ^long byteAlignment)))))
(scope [this]
(.scope arena))
(close [this]
(set! nestingLevels (unchecked-add-int nestingLevels -1))
(if (= 0 nestingLevels)
(do
(if isBufferTooSmall
(let [new-size (* 2 thread-local-buffer-initial-size (unchecked-inc-int (unchecked-divide-int allocCount thread-local-buffer-initial-size)))
new-arena (confined-arena)
new-buffer (.allocate ^Arena new-arena ^long new-size)]
(.doSet ^Settable thread-local-backing-arena new-arena)
(.close ^Arena arena)
(set! arena new-arena)
(set! buffer new-buffer)
(set! isBufferTooSmall (boolean false))))
(set! allocCount 0)
(set! allocator (SegmentAllocator/slicingAllocator buffer)))))
IArenaGet
(arena-get [this]
(set! nestingLevels (unchecked-inc-int nestingLevels))
this))

(def ^:private thread-local-confined-arena
(thread-local (let [arena @thread-local-backing-arena
buffer (.allocate ^Arena arena ^long thread-local-buffer-initial-size)]
(ThreadLocalConfinedArena. (SegmentAllocator/slicingAllocator buffer) arena buffer 0 0 false))))

(defn ^Arena thread-local-arena []
(arena-get @thread-local-confined-arena))

(defn alloc
"Allocates `size` bytes.

If an `arena` is provided, the allocation will be reclaimed when it is closed."
{:inline
(fn alloc-inline
([size]
(once-only [^long size]
`(.allocate ^Arena (Arena/ofAuto) ~size)))
([size arena]
(once-only [^long size ^Arena arena]
`(.allocate ~arena ~size)))
([size alignment arena]
(once-only [^long size ^long alignment ^Arena arena]
`(.allocate ~arena ~size ~alignment))))}
(^MemorySegment [size] (alloc size (auto-arena)))
(^MemorySegment [size arena] (.allocate ^Arena arena (long size)))
(^MemorySegment [size alignment arena] (.allocate ^Arena arena (long size) (long alignment))))
Expand Down Expand Up @@ -287,24 +433,6 @@
"The alignment in bytes of a c-sized pointer."
(.byteAlignment pointer-layout))

(def ^:private primitive-tag?
'#{byte bytes short shorts int ints long longs
float floats double doubles
bool bools char chars})

(defmacro once-only
{:style/indent [:defn]
:private true}
[[& names] & body]
(let [gensyms (repeatedly (count names) gensym)]
`(let [~@(interleave gensyms (repeat (count names) `(gensym)))]
`(let [~~@(mapcat #(-> (if (primitive-tag? (:tag (meta %2)))
[%1 ``(~'~(:tag (meta %2)) ~~%2)]
[`(with-meta ~%1 {:tag '~(:tag (meta %2))}) %2]))
gensyms names)]
~(let [~@(mapcat #(-> [(with-meta %1 {}) %2]) names gensyms)]
~@body)))))

(defn read-byte
"Reads a [[byte]] from the `segment`, at an optional `offset`."
{:inline
Expand Down Expand Up @@ -1176,8 +1304,7 @@
^Class [type]
(java-prim-layout (or (primitive-type type) type) MemorySegment))

(defn size-of
"The size in bytes of the given `type`."
(defn size-of "The size in bytes of the given `type`."
^long [type]
(let [t (cond-> type
(not (instance? MemoryLayout type)) c-layout)]
Expand Down Expand Up @@ -1818,7 +1945,7 @@
(defmethod generate-deserialize :coffi.mem/double [_type offset segment-source-form] `(read-double ~segment-source-form ~offset))
(defmethod generate-deserialize :coffi.mem/pointer [_type offset segment-source-form] `(read-address ~segment-source-form ~offset))
(defmethod generate-deserialize :coffi.mem/c-string [_type offset segment-source-form]
`(.getString (.reinterpret (.get ~(with-meta segment-source-form {:tag 'java.lang.foreign.MemorySegment}) pointer-layout ~offset) Integer/MAX_VALUE) 0))
`(.getString ~(with-meta `(.reinterpret ~(with-meta segment-source-form {:tag 'java.lang.foreign.MemorySegment}) Integer/MAX_VALUE) {:tag 'java.lang.foreign.MemorySegment}) ~offset))

(defn- generate-deserialize-array-as-array-bulk [array-type n offset segment-source-form]
(list (coffitype->array-read-fn array-type) segment-source-form n offset))
Expand Down Expand Up @@ -1954,7 +2081,7 @@
(->> (typelist fields)
(map-indexed
(fn [index [offset [_ field-type]]]
(generate-serialize field-type (list (symbol (str "." (name (nth fieldnames index)))) 'source-obj) (if (number? global-offset) (+ global-offset offset) `(+ ~global-offset ~offset)) segment-source-form)))
(generate-serialize field-type (list (symbol (str "." (name (nth fieldnames index)))) (with-meta 'source-obj {:tag (symbol (str (namespace typename) "." (name typename)))})) (if (number? global-offset) (+ global-offset offset) `(+ ~global-offset ~offset)) segment-source-form)))
(concat [`let ['source-obj source-form]])))))

(gen-interface
Expand Down Expand Up @@ -2240,7 +2367,8 @@
(register-new-struct-serialization coffi-typename struct-layout)
`(do
~(generate-struct-type typename typed-symbols)
(defmethod c-layout ~coffi-typename [~'_] (c-layout ((requiring-resolve 'coffi.layout/with-c-layout) ~struct-layout-raw)))
(let [memory-layout# (c-layout ((requiring-resolve 'coffi.layout/with-c-layout) ~struct-layout-raw))]
(defmethod c-layout ~coffi-typename [~'_] memory-layout#))
(register-new-struct-deserialization ~coffi-typename ((requiring-resolve 'coffi.layout/with-c-layout) ~struct-layout-raw))
(register-new-struct-serialization ~coffi-typename ((requiring-resolve 'coffi.layout/with-c-layout) ~struct-layout-raw))
(defmethod deserialize-from ~coffi-typename ~[segment-form '_type]
Expand Down