MLIR 23.0.0git
SPIRVWebGPUTransforms.cpp
Go to the documentation of this file.
1//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements SPIR-V transforms used when targeting WebGPU.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/Location.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/FormatVariadic.h"
24
25#include <array>
26#include <cstdint>
27
28namespace mlir {
29namespace spirv {
30#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32} // namespace spirv
33} // namespace mlir
34
35namespace mlir {
36namespace spirv {
37namespace {
38//===----------------------------------------------------------------------===//
39// Helpers
40//===----------------------------------------------------------------------===//
41static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42 APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
43 if (auto intTy = dyn_cast<IntegerType>(type))
44 return IntegerAttr::get(intTy, sizedValue);
45
46 return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47}
48
49static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter, Value lhs,
51 Value rhs, bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy = lhs.getType();
54 // Emulate 64-bit multiplication by splitting each input element of type i32
55 // into 2 16-bit digits of type i32. This is so that the intermediate
56 // multiplications and additions do not overflow. We extract these 16-bit
57 // digits from i32 vector elements by masking (low digit) and shifting right
58 // (high digit).
59 //
60 // The multiplication algorithm used is the standard (long) multiplication.
61 // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
62 // digits.
63 // - With zero-extended arguments, we end up emitting only 4 multiplications
64 // and 4 additions after constant folding.
65 // - With sign-extended arguments, we end up emitting 8 multiplications and
66 // and 12 additions after CSE.
67 Value cstLowMask = ConstantOp::create(
68 rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70 return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
71 };
72
73 Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76 return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
77 };
78
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
80 // We only need to shift arithmetically by 15, but the extra
81 // sign-extension bit will be truncated by the logical shift, so this is
82 // fine. We do not have to introduce an extra constant since any
83 // value in [15, 32) would do.
84 return getHighDigit(
85 ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
86 };
87
88 Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
90
91 Value lhsLow = getLowDigit(lhs);
92 Value lhsHigh = getHighDigit(lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94 Value rhsLow = getLowDigit(rhs);
95 Value rhsHigh = getHighDigit(rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
97
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
101
102 for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103 for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
104 if (i + j >= resultDigits.size())
105 continue;
106
107 if (lhsDigit == cst0 || rhsDigit == cst0)
108 continue;
109
110 Value &thisResDigit = resultDigits[i + j];
111 Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113 thisResDigit = getLowDigit(current);
114
115 if (i + j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i + j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
120 }
121 }
122 }
123
124 auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125 Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126 return BitwiseOrOp::create(rewriter, loc, low, highBits);
127 };
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
130
131 return CompositeConstructOp::create(rewriter, loc,
132 mulOp->getResultTypes().front(),
133 llvm::ArrayRef({low, high}));
134}
135
136//===----------------------------------------------------------------------===//
137// Rewrite Patterns
138//===----------------------------------------------------------------------===//
139
140template <typename MulExtendedOp, bool SignExtendArguments>
141struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
142 using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
143
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter) const override {
146 Location loc = op->getLoc();
147 Value lhs = op.getOperand1();
148 Value rhs = op.getOperand2();
149
150 // Currently, WGSL only supports 32-bit integer types. Any other integer
151 // types should already have been promoted/demoted to i32.
152 auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
155 loc,
156 llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
157
158 Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op, mul);
161 return success();
162 }
163};
164
165using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
167using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
169
170struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
171 using Base::Base;
172
173 LogicalResult matchAndRewrite(IAddCarryOp op,
174 PatternRewriter &rewriter) const override {
175 Location loc = op->getLoc();
176 Value lhs = op.getOperand1();
177 Value rhs = op.getOperand2();
178
179 // Currently, WGSL only supports 32-bit integer types. Any other integer
180 // types should already have been promoted/demoted to i32.
181 Type argTy = lhs.getType();
182 auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
183 if (elemTy.getIntOrFloatBitWidth() != 32)
184 return rewriter.notifyMatchFailure(
185 loc,
186 llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
187
188 Value one = ConstantOp::create(rewriter, loc, argTy,
189 getScalarOrSplatAttr(argTy, 1));
190 Value zero = ConstantOp::create(rewriter, loc, argTy,
191 getScalarOrSplatAttr(argTy, 0));
192
193 // Calculate the carry by checking if the addition resulted in an overflow.
194 Value out = IAddOp::create(rewriter, loc, lhs, rhs);
195 Value cmp = ULessThanOp::create(rewriter, loc, out, lhs);
196 Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
197
198 Value add = CompositeConstructOp::create(rewriter, loc,
199 op->getResultTypes().front(),
200 llvm::ArrayRef({out, carry}));
201
202 rewriter.replaceOp(op, add);
203 return success();
204 }
205};
206
207struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
208 using Base::Base;
209
210 LogicalResult matchAndRewrite(IsInfOp op,
211 PatternRewriter &rewriter) const override {
212 // We assume values to be finite and turn `IsInf` info `false`.
213 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
214 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
215 return success();
216 }
217};
218
219struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
220 using Base::Base;
221
222 LogicalResult matchAndRewrite(IsNanOp op,
223 PatternRewriter &rewriter) const override {
224 // We assume values to be finite and turn `IsNan` info `false`.
225 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
226 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
227 return success();
228 }
229};
230
231//===----------------------------------------------------------------------===//
232// Passes
233//===----------------------------------------------------------------------===//
234struct WebGPUPreparePass final
235 : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
236 void runOnOperation() override {
237 RewritePatternSet patterns(&getContext());
240
241 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
242 signalPassFailure();
243 }
244};
245} // namespace
246
247//===----------------------------------------------------------------------===//
248// Public Interface
249//===----------------------------------------------------------------------===//
251 RewritePatternSet &patterns) {
252 // WGSL currently does not support extended multiplication ops, see:
253 // https://github.com/gpuweb/gpuweb/issues/1565.
254 patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255 ExpandAddCarryPattern>(patterns.getContext());
256}
257
259 RewritePatternSet &patterns) {
260 // WGSL currently does not support `isInf` and `isNan`, see:
261 // https://github.com/gpuweb/gpuweb/pull/2311.
262 patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
263}
264
265} // namespace spirv
266} // namespace mlir
return success()
lhs
b getContext())
#define mul(a, b)
#define add(a, b)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.