Skip to content

Commit 211bb50

Browse files
authored
[fusilli] Add pointwise attribute and node (#2416)
This change implements attribute and node for pointwise op types. This does not implement the ASM emitter for `PointwiseNode`. That will be done in a followup PR. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent dbf3409 commit 211bb50

File tree

9 files changed

+772
-0
lines changed

9 files changed

+772
-0
lines changed

sharkfuser/include/fusilli.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
// Attributes / Types:
2424
#include "fusilli/attributes/attributes.h"
2525
#include "fusilli/attributes/conv_attributes.h"
26+
#include "fusilli/attributes/pointwise_attributes.h"
2627
#include "fusilli/attributes/tensor_attributes.h"
2728
#include "fusilli/attributes/types.h"
2829

2930
// Nodes:
3031
#include "fusilli/node/conv_node.h"
3132
#include "fusilli/node/node.h"
33+
#include "fusilli/node/pointwise_node.h"
3234

3335
// Backend:
3436
#include "fusilli/backend/backend.h"
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2025 Advanced Micro Devices, Inc.
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains attributes (compile-time constant metadata) for
10+
// pointwise nodes.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef FUSILLI_ATTRIBUTES_POINTWISE_ATTRIBUTES_H
15+
#define FUSILLI_ATTRIBUTES_POINTWISE_ATTRIBUTES_H
16+
17+
#include "fusilli/attributes/attributes.h"
18+
#include "fusilli/attributes/tensor_attributes.h"
19+
20+
#include <memory>
21+
#include <string>
22+
#include <unordered_map>
23+
24+
namespace fusilli {
25+
26+
class PointwiseAttr : public AttributesCRTP<PointwiseAttr> {
27+
public:
28+
// Names for Tensor Inputs and Outputs. Pointwise can have a maximum of three
29+
// inputs.
30+
enum class InputNames { IN_0, IN_1, IN_2 };
31+
enum class OutputNames { OUT_0 };
32+
33+
enum class Mode {
34+
NOT_SET,
35+
ADD,
36+
RELU_FWD,
37+
};
38+
39+
std::unordered_map<InputNames, std::shared_ptr<TensorAttr>> inputs;
40+
std::unordered_map<OutputNames, std::shared_ptr<TensorAttr>> outputs;
41+
42+
// Setters:
43+
FUSILLI_GENERIC_INPUT_TENSOR_SETTER(PointwiseAttr, InputNames, IN_0)
44+
FUSILLI_GENERIC_INPUT_TENSOR_SETTER(PointwiseAttr, InputNames, IN_1)
45+
FUSILLI_GENERIC_INPUT_TENSOR_SETTER(PointwiseAttr, InputNames, IN_2)
46+
FUSILLI_GENERIC_OUTPUT_TENSOR_SETTER(PointwiseAttr, OutputNames, OUT_0)
47+
48+
PointwiseAttr &setMode(Mode mode) {
49+
mode_ = mode;
50+
return *this;
51+
}
52+
53+
// Getters:
54+
FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_0)
55+
FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_1)
56+
FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_2)
57+
FUSILLI_GENERIC_OUTPUT_TENSOR_GETTER(OutputNames, OUT_0)
58+
59+
Mode getMode() const { return mode_; }
60+
61+
// Utilities for pointwise modes.
62+
static const std::unordered_map<Mode, std::string> modeToStr;
63+
static const std::unordered_map<PointwiseAttr::Mode, int>
64+
modeToRequiredInputCount;
65+
66+
private:
67+
Mode mode_ = Mode::NOT_SET;
68+
};
69+
70+
inline const std::unordered_map<PointwiseAttr::Mode, std::string>
71+
PointwiseAttr::modeToStr = {
72+
{PointwiseAttr::Mode::NOT_SET, "NOT_SET"},
73+
{PointwiseAttr::Mode::RELU_FWD, "RELU_FWD"},
74+
{PointwiseAttr::Mode::ADD, "ADD"},
75+
};
76+
inline const std::unordered_map<PointwiseAttr::Mode, int>
77+
PointwiseAttr::modeToRequiredInputCount = {
78+
{PointwiseAttr::Mode::RELU_FWD, 1}, {PointwiseAttr::Mode::ADD, 2}};
79+
80+
} // namespace fusilli
81+
82+
#endif // FUSILLI_ATTRIBUTES_POINTWISE_ATTRIBUTES_H

sharkfuser/include/fusilli/attributes/tensor_attributes.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
#include "fusilli/graph/context.h"
100100
#include "fusilli/support/logging.h"
101101

102+
#include <algorithm>
102103
#include <cassert>
103104
#include <cstddef>
104105
#include <cstdint>
@@ -153,6 +154,38 @@ inline std::vector<size_t> getChannelsLastStrideOrder(size_t numDims) {
153154
return strideOrder;
154155
}
155156

157+
// Generates a stride order preserving the format of `inputStride`. When the
158+
// desired format has a larger size, the result is padded to be of size
159+
// `outputDimSize`.
160+
//
161+
// For example: an input of {10, 30, 20} would return a stride order of
162+
// {0, 2, 1}.
163+
inline std::vector<size_t>
164+
generateStrideOrderPreservingFormat(const std::vector<int64_t> inputStride,
165+
size_t outputDimSize) {
166+
std::vector<size_t> indices(inputStride.size());
167+
std::iota(indices.begin(), indices.end(), 0);
168+
169+
// Sort indices based on stride values in descending order
170+
std::sort(indices.begin(), indices.end(), [&inputStride](size_t i, size_t j) {
171+
return inputStride[i] < inputStride[j];
172+
});
173+
174+
// Create the stride order
175+
std::vector<size_t> strideOrder(inputStride.size());
176+
for (size_t i = 0; i < indices.size(); ++i) {
177+
strideOrder[indices[i]] = i;
178+
}
179+
180+
// If output_dim_size is larger, pad with remaining dimensions
181+
if (outputDimSize > inputStride.size()) {
182+
size_t start = strideOrder.size();
183+
strideOrder.resize(outputDimSize);
184+
std::iota(strideOrder.begin() + start, strideOrder.end(), start);
185+
}
186+
return strideOrder;
187+
}
188+
156189
inline std::vector<int64_t>
157190
generateStrideFromDim(const std::vector<int64_t> &dim,
158191
const std::vector<size_t> &strideOrder) {
@@ -222,6 +255,56 @@ getContiguousToChannelsLastPermuteOrder(size_t numDims) {
222255
return permuteOrder;
223256
}
224257

258+
// Takes a set of input shapes and computes a common shape that all inputs
259+
// shapes can be broadcast to. This implements Pytorch style broadcasting where
260+
// shapes are right-aligned. For example:
261+
//
262+
// Input shapes:
263+
// {64, 16, 1, 1}
264+
// { 1, 32, 1}
265+
//
266+
// Result:
267+
// {64, 16, 32, 1}
268+
inline ErrorOr<std::vector<int64_t>>
269+
computeBroadcastShape(const std::vector<std::vector<int64_t>> &shapes) {
270+
// Remove empty shapes.
271+
auto filteredShapes =
272+
shapes | std::views::filter([](const std::vector<int64_t> &shape) {
273+
return !shape.empty();
274+
});
275+
FUSILLI_RETURN_ERROR_IF(filteredShapes.empty(), ErrorCode::InvalidAttribute,
276+
"All input shapes are empty");
277+
278+
// Find the maximum rank in `shapes`.
279+
size_t maxSize =
280+
std::max_element(
281+
filteredShapes.begin(), filteredShapes.end(),
282+
[](const std::vector<int64_t> &lhs, const std::vector<int64_t> &rhs) {
283+
return lhs.size() < rhs.size();
284+
})
285+
->size();
286+
287+
std::vector<int64_t> commonShape(maxSize, 1);
288+
for (const std::vector<int64_t> &shape : filteredShapes) {
289+
// When broadcasting shapes of differing ranks, the dimensions are
290+
// right-aligned. Process from rightmost dimension to leftmost.
291+
for (size_t offset = 0; offset < shape.size(); ++offset) {
292+
size_t commonIdx = commonShape.size() - 1 - offset;
293+
size_t shapeIdx = shape.size() - 1 - offset;
294+
295+
if (commonShape[commonIdx] == 1) {
296+
commonShape[commonIdx] = shape[shapeIdx];
297+
}
298+
299+
FUSILLI_RETURN_ERROR_IF((shape[shapeIdx] != 1) &&
300+
(commonShape[commonIdx] != shape[shapeIdx]),
301+
ErrorCode::InvalidAttribute,
302+
"Cannot broadcast two non unit dimensions");
303+
}
304+
}
305+
return ok(std::move(commonShape));
306+
}
307+
225308
class TensorAttr {
226309
public:
227310
using scalar_t = std::variant<int64_t, int32_t, float, double>;

sharkfuser/include/fusilli/node/node.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class INode {
3030
enum class Type {
3131
Composite,
3232
Convolution,
33+
Pointwise,
3334
};
3435

3536
explicit INode(const Context &ctx) : context(ctx) {}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2025 Advanced Micro Devices, Inc.
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains definitions for the pointwise nodes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef FUSILLI_NODE_POINTWISE_NODE_H
14+
#define FUSILLI_NODE_POINTWISE_NODE_H
15+
16+
#include "fusilli/attributes/pointwise_attributes.h"
17+
#include "fusilli/attributes/tensor_attributes.h"
18+
#include "fusilli/graph/context.h"
19+
#include "fusilli/node/node.h"
20+
#include "fusilli/support/logging.h"
21+
22+
#include <string>
23+
#include <unordered_map>
24+
25+
namespace fusilli {
26+
27+
class PointwiseNode : public NodeCRTP<PointwiseNode> {
28+
public:
29+
PointwiseAttr pointwiseAttr;
30+
31+
PointwiseNode(PointwiseAttr &&attr, const Context &ctx)
32+
: NodeCRTP(ctx), pointwiseAttr(std::move(attr)) {}
33+
34+
const std::string &getName() const override final {
35+
return pointwiseAttr.getName();
36+
}
37+
Type getType() const override final { return Type::Pointwise; }
38+
39+
ErrorObject preValidateNode() const override final {
40+
FUSILLI_LOG_LABEL_ENDL("INFO: Pre-Validating PointwiseNode '"
41+
<< pointwiseAttr.getName() << "'");
42+
FUSILLI_RETURN_ERROR_IF(
43+
pointwiseAttr.getMode() == PointwiseAttr::Mode::NOT_SET,
44+
ErrorCode::AttributeNotSet, "Pointwise mode not set");
45+
46+
// Validate inputs based on mode
47+
PointwiseAttr::Mode mode = pointwiseAttr.getMode();
48+
int requiredCount = PointwiseAttr::modeToRequiredInputCount.at(mode);
49+
50+
// Validate input requirements (required inputs must exist, unnecessary ones
51+
// must not)
52+
constexpr int maxInputs = 3;
53+
for (int i = 0; i < maxInputs; ++i) {
54+
auto inputName = static_cast<PointwiseAttr::InputNames>(i);
55+
bool hasInput = pointwiseAttr.inputs.contains(inputName) &&
56+
pointwiseAttr.inputs.at(inputName) != nullptr;
57+
58+
if (i < requiredCount) {
59+
FUSILLI_RETURN_ERROR_IF(!hasInput, ErrorCode::AttributeNotSet,
60+
PointwiseAttr::modeToStr.at(mode) +
61+
" mode requires IN_" + std::to_string(i) +
62+
" input");
63+
} else {
64+
FUSILLI_RETURN_ERROR_IF(hasInput, ErrorCode::InvalidAttribute,
65+
PointwiseAttr::modeToStr.at(mode) +
66+
" mode should not have IN_" +
67+
std::to_string(i) + " input set");
68+
}
69+
}
70+
71+
// Validate output
72+
FUSILLI_RETURN_ERROR_IF(!pointwiseAttr.getOUT_0(),
73+
ErrorCode::AttributeNotSet,
74+
"Pointwise operation requires output");
75+
76+
return ok();
77+
}
78+
79+
ErrorObject inferPropertiesNode() override final {
80+
FUSILLI_LOG_LABEL_ENDL("INFO: Inferring properties for PointwiseNode '"
81+
<< pointwiseAttr.getName() << "'");
82+
83+
// Fill missing properties from context (including data types)
84+
pointwiseAttr.fillFromContext(context);
85+
86+
const auto &outTensor = pointwiseAttr.getOUT_0();
87+
if (outTensor->getDim().empty()) {
88+
// Collect all input shapes
89+
std::vector<std::vector<int64_t>> inputShapes;
90+
for (const auto &[_, inTensor] : pointwiseAttr.inputs)
91+
if (inTensor)
92+
inputShapes.push_back(inTensor->getDim());
93+
94+
outTensor->setDim(FUSILLI_TRY(computeBroadcastShape(inputShapes)));
95+
}
96+
97+
if (outTensor->getStride().empty()) {
98+
// Try to set the stride from an input shape that matches the output
99+
// shape.
100+
for (const auto &[_, inTensor] : pointwiseAttr.inputs) {
101+
if (!inTensor)
102+
continue;
103+
if (inTensor->getDim() != outTensor->getDim())
104+
continue;
105+
outTensor->setStride(inTensor->getStride());
106+
}
107+
108+
if (outTensor->getStride().empty() && outTensor->isVirtual()) {
109+
// If we haven't found the stride already and the output is virtual,
110+
// compute an output stride that has the same format as IN_0. This can
111+
// occur when all inputs are broadcasted.
112+
auto inputStride = pointwiseAttr.getIN_0()->getStride();
113+
std::vector<size_t> strideOrder = generateStrideOrderPreservingFormat(
114+
inputStride, outTensor->getDim().size());
115+
outTensor->setStride(
116+
generateStrideFromDim(outTensor->getDim(), strideOrder));
117+
}
118+
FUSILLI_RETURN_ERROR_IF(outTensor->getStride().empty(),
119+
ErrorCode::InvalidAttribute,
120+
"Pointwise output strides could not be computed");
121+
}
122+
123+
return ok();
124+
}
125+
};
126+
} // namespace fusilli
127+
128+
#endif // FUSILLI_NODE_POINTWISE_NODE_H

sharkfuser/tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_fusilli_test(
3131
test_attributes.cpp
3232
test_tensor_attributes.cpp
3333
test_conv_attributes.cpp
34+
test_pointwise_attributes.cpp
3435
DEPS
3536
libfusilli
3637
Catch2::Catch2WithMain
@@ -40,6 +41,7 @@ add_fusilli_test(
4041
NAME fusilli_node_tests
4142
SRCS
4243
test_conv_node.cpp
44+
test_pointwise_node.cpp
4345
DEPS
4446
libfusilli
4547
Catch2::Catch2WithMain

0 commit comments

Comments
 (0)