@@ -2497,22 +2497,9 @@ struct memory : public handle<dnnl_memory_t> {
24972497 /// A memory descriptor.
24982498 struct desc {
24992499 struct sparse_desc {
2500- sparse_desc(dnnl_sparse_encoding_t encoding, const dims &dims_order,
2501- dim nnze, const std::vector<data_type> &metadata_types,
2502- const dims &entry_dims, const dims &structure_dims,
2503- const dims &structure_nnz, bool allow_empty = false) {
2504- std::vector<dnnl_data_type_t> c_metadata_types(
2505- metadata_types.size());
2506- for (size_t i = 0; i < c_metadata_types.size(); i++) {
2507- c_metadata_types[i] = convert_to_c(metadata_types[i]);
2508- }
2500+ sparse_desc(dnnl_sparse_encoding_t encoding, bool allow_empty = false) {
25092501 // TODO: check structure_dims.size() == structure_nnz.size();
2510- dnnl_status_t status = dnnl_sparse_desc_init(&data, encoding,
2511- (int)dims_order.size(), dims_order.data(), nnze,
2512- (int)c_metadata_types.size(), c_metadata_types.data(),
2513- (int)entry_dims.size(), entry_dims.data(),
2514- (int)structure_dims.size(), structure_dims.data(),
2515- structure_nnz.data());
2502+ dnnl_status_t status = dnnl_sparse_desc_init(&data, encoding);
25162503 if (!allow_empty)
25172504 error::wrap_c_api(
25182505 status, "could not construct a sparse descriptor");
@@ -2602,41 +2589,8 @@ struct memory : public handle<dnnl_memory_t> {
26022589 "sparse descriptor");
26032590 }
26042591
2605- /// Function for creating CSR sparse desc with unstructured sparsity.
2606- static sparse_desc csr(dim nnze, data_type index_type,
2607- data_type pointer_type, bool allow_empty = false) {
2608- return sparse_desc(dnnl_sparse_encoding_csr, {0, 1}, nnze,
2609- {index_type, pointer_type}, {}, {}, {}, allow_empty);
2610- }
2611-
2612- /// Function for creating CSC sparse desc with unstructured sparsity.
2613- static sparse_desc csc(dim nnze, data_type index_type,
2614- data_type pointer_type, bool allow_empty = false) {
2615- return sparse_desc(dnnl_sparse_encoding_csc, {1, 0}, nnze,
2616- {index_type, pointer_type}, {}, {}, {}, allow_empty);
2617- }
2618-
2619- /// Function for creating BCSR sparse desc with unstructured sparsity.
2620- static sparse_desc bcsr(dim nnze, data_type index_type,
2621- data_type pointer_type, const dims &block_dims,
2622- bool allow_empty = false) {
2623- return sparse_desc(dnnl_sparse_encoding_bcsr, {0, 1}, nnze,
2624- {index_type, pointer_type}, block_dims, {}, {},
2625- allow_empty);
2626- }
2627-
2628- /// Function for creating BCSC sparse desc unstructured sparsity.
2629- static sparse_desc bcsc(dim nnze, data_type index_type,
2630- data_type pointer_type, const dims &block_dims,
2631- bool allow_empty = false) {
2632- return sparse_desc(dnnl_sparse_encoding_bcsc, {1, 0}, nnze,
2633- {index_type, pointer_type}, block_dims, {}, {},
2634- allow_empty);
2635- }
2636-
2637- static sparse_desc packed(dim nnze, bool allow_empty = false) {
2638- return sparse_desc(dnnl_sparse_encoding_packed, {}, nnze, {}, {},
2639- {}, {}, allow_empty);
2592+ static sparse_desc packed(bool allow_empty = false) {
2593+ return sparse_desc(dnnl_sparse_encoding_packed, allow_empty);
26402594 }
26412595
26422596 /// Constructs a memory descriptor for a region inside an area
@@ -2786,18 +2740,6 @@ struct memory : public handle<dnnl_memory_t> {
27862740 /// including the padding area.
27872741 size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
27882742
2789- /// Returns the size of a values and metadata for a particular sparse
2790- /// encoding.
2791- ///
2792- /// @param index Index that correspondes to values or metadata.
2793- /// Each sparse encoding defines index interpretation.
2794- ///
2795- /// @returns The number of bytes required for values or metadata for a
2796- /// particular sparse encoding described by a memory descriptor.
2797- size_t get_size(int index) const {
2798- return dnnl_memory_desc_get_size_sparse(&data, index);
2799- }
2800-
28012743 /// Checks whether the memory descriptor is zero (empty).
28022744 /// @returns @c true if the memory descriptor describes an empty
28032745 /// memory and @c false otherwise.
@@ -2858,44 +2800,12 @@ struct memory : public handle<dnnl_memory_t> {
28582800
28592801 /// Constructs a memory object.
28602802 ///
2861- /// The underlying buffer(s) for the memory will be allocated by the
2862- /// library.
2803+ /// The underlying buffer for the memory will be allocated by the library.
28632804 ///
28642805 /// @param md Memory descriptor.
28652806 /// @param aengine Engine to store the data on.
2866- memory(const desc &md, const engine &aengine) {
2867- dnnl_status_t status;
2868- dnnl_memory_t result;
2869- const bool is_sparse_md = md.data.format_kind == dnnl_format_sparse;
2870- if (is_sparse_md) {
2871- // Deduce number of handles.
2872- dim nhandles = 0;
2873- switch (md.data.format_desc.sparse_desc.encoding) {
2874- case dnnl_sparse_encoding_csr:
2875- case dnnl_sparse_encoding_csc:
2876- case dnnl_sparse_encoding_bcsr:
2877- case dnnl_sparse_encoding_bcsc: nhandles = 3; break;
2878- case dnnl_sparse_encoding_packed: nhandles = 1; break;
2879- default: nhandles = 0;
2880- }
2881- std::vector<void *> handles(nhandles, DNNL_MEMORY_ALLOCATE);
2882- status = dnnl_memory_create_sparse(&result, &md.data, aengine.get(),
2883- (dim)handles.size(), handles.data());
2884- } else {
2885- status = dnnl_memory_create(
2886- &result, &md.data, aengine.get(), DNNL_MEMORY_ALLOCATE);
2887- }
2888- error::wrap_c_api(status, "could not create a memory object");
2889- reset(result);
2890- }
2891-
2892- memory(const desc &md, const engine &aengine, std::vector<void *> handles) {
2893- dnnl_memory_t result;
2894- dnnl_status_t status = dnnl_memory_create_sparse(&result, &md.data,
2895- aengine.get(), (dim)handles.size(), handles.data());
2896- error::wrap_c_api(status, "could not create a memory object");
2897- reset(result);
2898- }
2807+ memory(const desc &md, const engine &aengine)
2808+ : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
28992809
29002810 /// Returns the associated memory descriptor.
29012811 desc get_desc() const {
@@ -2924,28 +2834,6 @@ struct memory : public handle<dnnl_memory_t> {
29242834 return handle;
29252835 }
29262836
2927- // TODO: add documentation.
2928- std::vector<void *> get_data_handles() const {
2929- dim nhandles;
2930- error::wrap_c_api(
2931- dnnl_memory_get_data_handles(get(), &nhandles, nullptr),
2932- "could not get a number of native handles from a memory "
2933- "object");
2934- std::vector<void *> handles(nhandles);
2935- error::wrap_c_api(
2936- dnnl_memory_get_data_handles(get(), &nhandles, handles.data()),
2937- "could not get native handles from a memory object");
2938- return handles;
2939- }
2940-
2941- // TODO: add documentation.
2942- void set_data_handles(std::vector<void *> handles) {
2943- dim nhandles = handles.size();
2944- error::wrap_c_api(
2945- dnnl_memory_set_data_handles(get(), nhandles, handles.data()),
2946- "could not set native handles of a memory object");
2947- }
2948-
29492837 /// Sets the underlying memory buffer.
29502838 ///
29512839 /// This function may write zero values to the memory specified by the @p
@@ -3031,23 +2919,6 @@ struct memory : public handle<dnnl_memory_t> {
30312919 return static_cast<T *>(mapped_ptr);
30322920 }
30332921
3034- // TODO: add documentation.
3035- template <typename T = void>
3036- T *map_data(int index) const {
3037- void *mapped_ptr;
3038- error::wrap_c_api(
3039- dnnl_memory_map_data_sparse(get(), index, &mapped_ptr),
3040- "could not map memory object data");
3041- return static_cast<T *>(mapped_ptr);
3042- }
3043-
3044- // TODO: add documentation.
3045- void unmap_data(int index, void *mapped_ptr) const {
3046- error::wrap_c_api(
3047- dnnl_memory_unmap_data_sparse(get(), index, mapped_ptr),
3048- "could not unmap memory object data");
3049- }
3050-
30512922 /// Unmaps a memory object and writes back any changes made to the
30522923 /// previously mapped memory buffer.
30532924 ///
0 commit comments