21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/FormatVariadic.h"
30#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
// Builds the attribute used to materialize the integer `value` as a constant
// of `type`.
//
// When `type` is an IntegerType, the result is an IntegerAttr of that type
// holding `sizedValue`.
//
// NOTE(review): the definition of `sizedValue` (original line 42) and the path
// taken for non-integer types (original lines 45-47) are not visible in this
// excerpt; judging by the function name the latter presumably produces a splat
// attribute -- confirm against the full file.
41static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
// Scalar case: `type` itself is an integer type.
43 if (
auto intTy = dyn_cast<IntegerType>(type))
44 return IntegerAttr::get(intTy, sizedValue);
// Emits, through `rewriter` and at the location of `mulOp`, the ops that
// compute the extended product of `lhs` and `rhs`, and returns a composite
// value {low, high} whose type is the first result type of `mulOp`.
//
// Each operand is split into 16-bit digits (mask with (1 << 16) - 1 for the
// low digit, logical shift right by 16 for the high digit). The digits are
// multiplied pairwise and accumulated into four result digits, with the upper
// part of every partial sum carried into the next digit. The four result
// digits are finally recombined pairwise into the `low` and `high` words.
//
// `signExtendArguments` selects the two upper digits of each operand: the sign
// digit when true, the zero constant when false.
//
// The visible caller in this file only invokes this helper after rejecting
// element types that are not 32 bits wide.
//
// NOTE(review): this excerpt omits several original lines (closing braces, the
// statements guarded by the two `if` checks inside the loop nest, and the
// beginning of the return expression in getSignDigit). The comments below
// describe only what is visible.
49static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter, Value
lhs,
51 Value
rhs,
bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
// All constants below are created with the operand type.
53 Type argTy =
lhs.getType();
// Constant (1 << 16) - 1: and-ing with it keeps the low 16 bits of a value.
67 Value cstLowMask = ConstantOp::create(
68 rewriter, loc,
lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70 return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
// Constant 16: the shift amount that separates the low and high digits.
73 Value cst16 = ConstantOp::create(rewriter, loc,
lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76 return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
// Sign digit: built from an arithmetic shift right by 16.
// NOTE(review): the start of the return expression is elided; given the
// captured `getHighDigit` and the extra closing parenthesis, the shift result
// is presumably passed through getHighDigit -- confirm.
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
85 ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
// Zero constant: the initial value of every result digit, and the stand-in for
// the upper operand digits when the arguments are not sign-extended.
88 Value cst0 = ConstantOp::create(rewriter, loc,
lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
// Decompose both operands into low / high / extension digits.
91 Value lhsLow = getLowDigit(
lhs);
92 Value lhsHigh = getHighDigit(
lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(
lhs) : cst0;
94 Value rhsLow = getLowDigit(
rhs);
95 Value rhsHigh = getHighDigit(
rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(
rhs) : cst0;
// Digits are stored least-significant first; the extension digit fills both
// upper positions. Result digits start out as the zero constant.
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
// Visit every (lhs digit i, rhs digit j) pair; its product contributes to
// result digit i + j.
102 for (
auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103 for (
auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
// Products that would land at or beyond the fourth result digit are checked
// for here (the guarded statement is elided in this excerpt).
104 if (i + j >= resultDigits.size())
// Checks for a digit that is the very same Value as the zero constant (the
// guarded statement is elided in this excerpt).
107 if (lhsDigit == cst0 || rhsDigit == cst0)
// Add the digit product to the current result digit and keep only the low 16
// bits of the sum in that digit.
110 Value &thisResDigit = resultDigits[i + j];
111 Value
mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit,
mul);
113 thisResDigit = getLowDigit(current);
// Unless this was the last result digit, carry the upper part of the sum into
// the next digit.
115 if (i + j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i + j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
// Recombine two digits into one word: low | (high << 16).
124 auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125 Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126 return BitwiseOrOp::create(rewriter, loc, low, highBits);
// Digits 0-1 form the low word, digits 2-3 the high word.
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
// Pack {low, high} into a composite of the original op's first result type.
131 return CompositeConstructOp::create(rewriter, loc,
132 mulOp->getResultTypes().front(),
133 llvm::ArrayRef({low, high}));
// Rewrite pattern that replaces a `MulExtendedOp` with the op sequence emitted
// by lowerExtendedMultiplication.
//
// Template parameters:
//   MulExtendedOp       - the op class this pattern matches.
//   SignExtendArguments - forwarded to lowerExtendedMultiplication as its
//                         `signExtendArguments` argument.
//
// NOTE(review): the definition of `elemTy`, part of the notifyMatchFailure
// call, and the final return statement are elided in this excerpt.
140template <
typename MulExtendedOp,
bool SignExtendArguments>
141struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
// Inherit the base-class constructors.
142 using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter)
const override {
146 Location loc = op->getLoc();
147 Value
lhs = op.getOperand1();
148 Value
rhs = op.getOperand2();
// Report a match failure unless the element type is exactly 32 bits wide.
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
156 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
// Emit the expanded multiplication and substitute it for the original op.
158 Value
mul = lowerExtendedMultiplication(op, rewriter,
lhs,
rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op,
mul);
// Expansion pattern for SMulExtendedOp: operands are sign-extended
// (SignExtendArguments = true).
165using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
// Expansion pattern for UMulExtendedOp: operands are not sign-extended
// (SignExtendArguments = false).
167using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
// Rewrite pattern that replaces an IAddCarryOp with a plain add plus an
// explicitly computed carry, packed as the composite {sum, carry}.
//
// NOTE(review): the definition of `elemTy`, part of the notifyMatchFailure
// call, and the final return statement are elided in this excerpt.
170struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
173 LogicalResult matchAndRewrite(IAddCarryOp op,
174 PatternRewriter &rewriter)
const override {
175 Location loc = op->getLoc();
176 Value
lhs = op.getOperand1();
177 Value
rhs = op.getOperand2();
181 Type argTy =
lhs.getType();
// Report a match failure unless the element type is exactly 32 bits wide.
183 if (elemTy.getIntOrFloatBitWidth() != 32)
184 return rewriter.notifyMatchFailure(
186 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
// Constants 1 and 0 of the operand type: the two possible carry values.
188 Value one = ConstantOp::create(rewriter, loc, argTy,
189 getScalarOrSplatAttr(argTy, 1));
190 Value zero = ConstantOp::create(rewriter, loc, argTy,
191 getScalarOrSplatAttr(argTy, 0));
// The carry is 1 exactly when the sum compares unsigned-less-than `lhs`,
// i.e. when the addition wrapped around; otherwise it is 0.
194 Value out = IAddOp::create(rewriter, loc,
lhs,
rhs);
195 Value cmp = ULessThanOp::create(rewriter, loc, out,
lhs);
196 Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
// Pack {sum, carry} into a composite of the op's first result type and
// substitute it for the original op.
198 Value
add = CompositeConstructOp::create(rewriter, loc,
199 op->getResultTypes().front(),
200 llvm::ArrayRef({out, carry}));
202 rewriter.replaceOp(op,
add);
// Rewrite pattern that replaces every IsInfOp with a constant of the op's
// result type built from the value 0 (presumably "false" for every element --
// confirm against the result type of IsInfOp).
//
// NOTE(review): the remainder of matchAndRewrite (its return statement and the
// closing braces) is elided in this excerpt.
207struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
210 LogicalResult matchAndRewrite(IsInfOp op,
211 PatternRewriter &rewriter)
const override {
// The operand is not inspected: the op is replaced unconditionally.
213 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
214 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
// Rewrite pattern that replaces every IsNanOp with a constant of the op's
// result type built from the value 0 (presumably "false" for every element --
// confirm against the result type of IsNanOp).
//
// NOTE(review): the remainder of matchAndRewrite (its return statement and the
// closing braces) is elided in this excerpt.
219struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
222 LogicalResult matchAndRewrite(IsNanOp op,
223 PatternRewriter &rewriter)
const override {
// The operand is not inspected: the op is replaced unconditionally.
225 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
226 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
// Pass type for the WebGPU preparation transforms.
//
// NOTE(review): the base-class list (original line 235) and the body of
// runOnOperation (original lines 237 onward) are elided in this excerpt, so
// what the pass does when run cannot be stated from here.
234struct WebGPUPreparePass final
// Overrides the base-class entry point.
236 void runOnOperation()
override {
// Registers the signed/unsigned extended-multiplication and add-carry
// expansion patterns in `patterns`, using the context owned by the set.
// NOTE(review): the enclosing function header is not visible in this excerpt
// (original line numbers jump from 236 to 254); presumably this is the body of
// populateSPIRVExpandExtendedMultiplicationPatterns -- confirm.
254 patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255 ExpandAddCarryPattern>(
patterns.getContext());
// Registers the IsInf and IsNan expansion patterns in `patterns`, using the
// context owned by the set.
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the body of
// populateSPIRVExpandNonFiniteArithmeticPatterns -- confirm.
262 patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(
patterns.getContext());
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns