MLIR 22.0.0git
SPIRVWebGPUTransforms.cpp
Go to the documentation of this file.
1//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements SPIR-V transforms used when targeting WebGPU.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/Location.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/FormatVariadic.h"
24
25#include <array>
26#include <cstdint>
27
28namespace mlir {
29namespace spirv {
30#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32} // namespace spirv
33} // namespace mlir
34
35namespace mlir {
36namespace spirv {
37namespace {
38//===----------------------------------------------------------------------===//
39// Helpers
40//===----------------------------------------------------------------------===//
41static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42 APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
43 if (auto intTy = dyn_cast<IntegerType>(type))
44 return IntegerAttr::get(intTy, sizedValue);
45
46 return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47}
48
49static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter, Value lhs,
51 Value rhs, bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy = lhs.getType();
54 // Emulate 64-bit multiplication by splitting each input element of type i32
55 // into 2 16-bit digits of type i32. This is so that the intermediate
56 // multiplications and additions do not overflow. We extract these 16-bit
57 // digits from i32 vector elements by masking (low digit) and shifting right
58 // (high digit).
59 //
60 // The multiplication algorithm used is the standard (long) multiplication.
61 // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
62 // digits.
63 // - With zero-extended arguments, we end up emitting only 4 multiplications
64 // and 4 additions after constant folding.
65 // - With sign-extended arguments, we end up emitting 8 multiplications and
66 // and 12 additions after CSE.
67 Value cstLowMask = ConstantOp::create(
68 rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70 return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
71 };
72
73 Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76 return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
77 };
78
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
80 // We only need to shift arithmetically by 15, but the extra
81 // sign-extension bit will be truncated by the logical shift, so this is
82 // fine. We do not have to introduce an extra constant since any
83 // value in [15, 32) would do.
84 return getHighDigit(
85 ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
86 };
87
88 Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
90
91 Value lhsLow = getLowDigit(lhs);
92 Value lhsHigh = getHighDigit(lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94 Value rhsLow = getLowDigit(rhs);
95 Value rhsHigh = getHighDigit(rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
97
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
101
102 for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103 for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
104 if (i + j >= resultDigits.size())
105 continue;
106
107 if (lhsDigit == cst0 || rhsDigit == cst0)
108 continue;
109
110 Value &thisResDigit = resultDigits[i + j];
111 Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113 thisResDigit = getLowDigit(current);
114
115 if (i + j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i + j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
120 }
121 }
122 }
123
124 auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125 Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126 return BitwiseOrOp::create(rewriter, loc, low, highBits);
127 };
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
130
131 return CompositeConstructOp::create(rewriter, loc,
132 mulOp->getResultTypes().front(),
133 llvm::ArrayRef({low, high}));
134}
135
136//===----------------------------------------------------------------------===//
137// Rewrite Patterns
138//===----------------------------------------------------------------------===//
139
140template <typename MulExtendedOp, bool SignExtendArguments>
141struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
142 using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
143
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter) const override {
146 Location loc = op->getLoc();
147 Value lhs = op.getOperand1();
148 Value rhs = op.getOperand2();
149
150 // Currently, WGSL only supports 32-bit integer types. Any other integer
151 // types should already have been promoted/demoted to i32.
152 auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
155 loc,
156 llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
157
158 Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op, mul);
161 return success();
162 }
163};
164
165using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
167using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
169
170struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
171 using Base::Base;
172
173 LogicalResult matchAndRewrite(IAddCarryOp op,
174 PatternRewriter &rewriter) const override {
175 Location loc = op->getLoc();
176 Value lhs = op.getOperand1();
177 Value rhs = op.getOperand2();
178
179 // Currently, WGSL only supports 32-bit integer types. Any other integer
180 // types should already have been promoted/demoted to i32.
181 Type argTy = lhs.getType();
182 auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
183 if (elemTy.getIntOrFloatBitWidth() != 32)
184 return rewriter.notifyMatchFailure(
185 loc,
186 llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
187
188 Value one = ConstantOp::create(rewriter, loc, argTy,
189 getScalarOrSplatAttr(argTy, 1));
190 Value zero = ConstantOp::create(rewriter, loc, argTy,
191 getScalarOrSplatAttr(argTy, 0));
192
193 // Calculate the carry by checking if the addition resulted in an overflow.
194 Value out = IAddOp::create(rewriter, loc, lhs, rhs);
195 Value cmp = ULessThanOp::create(rewriter, loc, out, lhs);
196 Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
197
198 Value add = CompositeConstructOp::create(rewriter, loc,
199 op->getResultTypes().front(),
200 llvm::ArrayRef({out, carry}));
201
202 rewriter.replaceOp(op, add);
203 return success();
204 }
205};
206
207struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
208 using Base::Base;
209
210 LogicalResult matchAndRewrite(IsInfOp op,
211 PatternRewriter &rewriter) const override {
212 // We assume values to be finite and turn `IsInf` info `false`.
213 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
214 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
215 return success();
216 }
217};
218
219struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
220 using Base::Base;
221
222 LogicalResult matchAndRewrite(IsNanOp op,
223 PatternRewriter &rewriter) const override {
224 // We assume values to be finite and turn `IsNan` info `false`.
225 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
226 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
227 return success();
228 }
229};
230
231//===----------------------------------------------------------------------===//
232// Passes
233//===----------------------------------------------------------------------===//
234struct WebGPUPreparePass final
235 : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
236 void runOnOperation() override {
237 RewritePatternSet patterns(&getContext());
240
241 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
242 signalPassFailure();
243 }
244};
245} // namespace
246
247//===----------------------------------------------------------------------===//
248// Public Interface
249//===----------------------------------------------------------------------===//
252 // WGSL currently does not support extended multiplication ops, see:
253 // https://github.com/gpuweb/gpuweb/issues/1565.
254 patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255 ExpandAddCarryPattern>(patterns.getContext());
256}
257
260 // WGSL currently does not support `isInf` and `isNan`, see:
261 // https://github.com/gpuweb/gpuweb/pull/2311.
262 patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
263}
264
265} // namespace spirv
266} // namespace mlir
return success()
lhs
b getContext())
#define mul(a, b)
#define add(a, b)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns