MLIR  22.0.0git
SPIRVWebGPUTransforms.cpp
Go to the documentation of this file.
1 //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements SPIR-V transforms used when targeting WebGPU.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Location.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #include <array>
26 #include <cstdint>
27 
28 namespace mlir {
29 namespace spirv {
30 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32 } // namespace spirv
33 } // namespace mlir
34 
35 namespace mlir {
36 namespace spirv {
37 namespace {
38 //===----------------------------------------------------------------------===//
39 // Helpers
40 //===----------------------------------------------------------------------===//
41 static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
43  if (auto intTy = dyn_cast<IntegerType>(type))
44  return IntegerAttr::get(intTy, sizedValue);
45 
46  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47 }
48 
49 static Value lowerExtendedMultiplication(Operation *mulOp,
50  PatternRewriter &rewriter, Value lhs,
51  Value rhs, bool signExtendArguments) {
52  Location loc = mulOp->getLoc();
53  Type argTy = lhs.getType();
54  // Emulate 64-bit multiplication by splitting each input element of type i32
55  // into 2 16-bit digits of type i32. This is so that the intermediate
56  // multiplications and additions do not overflow. We extract these 16-bit
57  // digits from i32 vector elements by masking (low digit) and shifting right
58  // (high digit).
59  //
60  // The multiplication algorithm used is the standard (long) multiplication.
61  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
62  // digits.
63  // - With zero-extended arguments, we end up emitting only 4 multiplications
64  // and 4 additions after constant folding.
65  // - With sign-extended arguments, we end up emitting 8 multiplications and
66  // and 12 additions after CSE.
67  Value cstLowMask = ConstantOp::create(
68  rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70  return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
71  };
72 
73  Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(),
74  getScalarOrSplatAttr(argTy, 16));
75  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76  return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
77  };
78 
79  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
80  // We only need to shift arithmetically by 15, but the extra
81  // sign-extension bit will be truncated by the logical shift, so this is
82  // fine. We do not have to introduce an extra constant since any
83  // value in [15, 32) would do.
84  return getHighDigit(
85  ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
86  };
87 
88  Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(),
89  getScalarOrSplatAttr(argTy, 0));
90 
91  Value lhsLow = getLowDigit(lhs);
92  Value lhsHigh = getHighDigit(lhs);
93  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94  Value rhsLow = getLowDigit(rhs);
95  Value rhsHigh = getHighDigit(rhs);
96  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
97 
98  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
101 
102  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103  for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
104  if (i + j >= resultDigits.size())
105  continue;
106 
107  if (lhsDigit == cst0 || rhsDigit == cst0)
108  continue;
109 
110  Value &thisResDigit = resultDigits[i + j];
111  Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112  Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113  thisResDigit = getLowDigit(current);
114 
115  if (i + j + 1 != resultDigits.size()) {
116  Value &nextResDigit = resultDigits[i + j + 1];
117  Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118  getHighDigit(current));
119  nextResDigit = carry;
120  }
121  }
122  }
123 
124  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125  Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126  return BitwiseOrOp::create(rewriter, loc, low, highBits);
127  };
128  Value low = combineDigits(resultDigits[0], resultDigits[1]);
129  Value high = combineDigits(resultDigits[2], resultDigits[3]);
130 
131  return CompositeConstructOp::create(rewriter, loc,
132  mulOp->getResultTypes().front(),
133  llvm::ArrayRef({low, high}));
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Rewrite Patterns
138 //===----------------------------------------------------------------------===//
139 
140 template <typename MulExtendedOp, bool SignExtendArguments>
141 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
143 
144  LogicalResult matchAndRewrite(MulExtendedOp op,
145  PatternRewriter &rewriter) const override {
146  Location loc = op->getLoc();
147  Value lhs = op.getOperand1();
148  Value rhs = op.getOperand2();
149 
150  // Currently, WGSL only supports 32-bit integer types. Any other integer
151  // types should already have been promoted/demoted to i32.
152  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
153  if (elemTy.getIntOrFloatBitWidth() != 32)
154  return rewriter.notifyMatchFailure(
155  loc,
156  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
157 
158  Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159  SignExtendArguments);
160  rewriter.replaceOp(op, mul);
161  return success();
162  }
163 };
164 
165 using ExpandSMulExtendedPattern =
166  ExpandMulExtendedPattern<SMulExtendedOp, true>;
167 using ExpandUMulExtendedPattern =
168  ExpandMulExtendedPattern<UMulExtendedOp, false>;
169 
170 struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
172 
173  LogicalResult matchAndRewrite(IAddCarryOp op,
174  PatternRewriter &rewriter) const override {
175  Location loc = op->getLoc();
176  Value lhs = op.getOperand1();
177  Value rhs = op.getOperand2();
178 
179  // Currently, WGSL only supports 32-bit integer types. Any other integer
180  // types should already have been promoted/demoted to i32.
181  Type argTy = lhs.getType();
182  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
183  if (elemTy.getIntOrFloatBitWidth() != 32)
184  return rewriter.notifyMatchFailure(
185  loc,
186  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
187 
188  Value one = ConstantOp::create(rewriter, loc, argTy,
189  getScalarOrSplatAttr(argTy, 1));
190  Value zero = ConstantOp::create(rewriter, loc, argTy,
191  getScalarOrSplatAttr(argTy, 0));
192 
193  // Calculate the carry by checking if the addition resulted in an overflow.
194  Value out = IAddOp::create(rewriter, loc, lhs, rhs);
195  Value cmp = ULessThanOp::create(rewriter, loc, out, lhs);
196  Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
197 
198  Value add = CompositeConstructOp::create(rewriter, loc,
199  op->getResultTypes().front(),
200  llvm::ArrayRef({out, carry}));
201 
202  rewriter.replaceOp(op, add);
203  return success();
204  }
205 };
206 
207 struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
209 
210  LogicalResult matchAndRewrite(IsInfOp op,
211  PatternRewriter &rewriter) const override {
212  // We assume values to be finite and turn `IsInf` info `false`.
213  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
214  op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
215  return success();
216  }
217 };
218 
219 struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
221 
222  LogicalResult matchAndRewrite(IsNanOp op,
223  PatternRewriter &rewriter) const override {
224  // We assume values to be finite and turn `IsNan` info `false`.
225  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
226  op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
227  return success();
228  }
229 };
230 
231 //===----------------------------------------------------------------------===//
232 // Passes
233 //===----------------------------------------------------------------------===//
234 struct WebGPUPreparePass final
235  : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
236  void runOnOperation() override {
237  RewritePatternSet patterns(&getContext());
240 
241  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
242  signalPassFailure();
243  }
244 };
245 } // namespace
246 
247 //===----------------------------------------------------------------------===//
248 // Public Interface
249 //===----------------------------------------------------------------------===//
252  // WGSL currently does not support extended multiplication ops, see:
253  // https://github.com/gpuweb/gpuweb/issues/1565.
254  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255  ExpandAddCarryPattern>(patterns.getContext());
256 }
257 
260  // WGSL currently does not support `isInf` and `isNan`, see:
261  // https://github.com/gpuweb/gpuweb/pull/2311.
262  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
263 }
264 
265 } // namespace spirv
266 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Type
An inlay hint that is for a type annotation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.