MLIR  20.0.0git
SPIRVWebGPUTransforms.cpp
Go to the documentation of this file.
1 //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements SPIR-V transforms used when targeting WebGPU.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Location.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #include <array>
26 #include <cstdint>
27 
28 namespace mlir {
29 namespace spirv {
30 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32 } // namespace spirv
33 } // namespace mlir
34 
35 namespace mlir {
36 namespace spirv {
37 namespace {
38 //===----------------------------------------------------------------------===//
39 // Helpers
40 //===----------------------------------------------------------------------===//
41 static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
43  if (auto intTy = dyn_cast<IntegerType>(type))
44  return IntegerAttr::get(intTy, sizedValue);
45 
46  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47 }
48 
/// Lowers the extended (full-width) multiplication `mulOp` of `lhs` and `rhs`
/// into plain 32-bit SPIR-V arithmetic and returns the replacement value.
///
/// \param mulOp  The extended multiplication being replaced; its location is
///               reused and its first result type is the type of the returned
///               composite.
/// \param rewriter Rewriter used to create all new ops.
/// \param lhs, rhs The operands (scalar or vector; the callers in this file
///               only invoke this for 32-bit integer element types).
/// \param signExtendArguments If true, operands are treated as signed
///               (sign-extended to 64 bits); otherwise as unsigned.
/// \returns A `CompositeConstructOp` built from {low 32 bits, high 32 bits}
///          of the 64-bit product.
static Value lowerExtendedMultiplication(Operation *mulOp,
                                         PatternRewriter &rewriter, Value lhs,
                                         Value rhs, bool signExtendArguments) {
  Location loc = mulOp->getLoc();
  Type argTy = lhs.getType();
  // Emulate 64-bit multiplication by splitting each input element of type i32
  // into 2 16-bit digits of type i32. This is so that the intermediate
  // multiplications and additions do not overflow. We extract these 16-bit
  // digits from i32 vector elements by masking (low digit) and shifting right
  // (high digit).
  //
  // The multiplication algorithm used is the standard (long) multiplication.
  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
  // digits.
  // - With zero-extended arguments, we end up emitting only 4 multiplications
  //   and 4 additions after constant folding.
  // - With sign-extended arguments, we end up emitting 8 multiplications and
  //   12 additions after CSE.
  Value cstLowMask = rewriter.create<ConstantOp>(
      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
  // Low 16-bit digit: `val & 0xFFFF`.
  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
  };

  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                            getScalarOrSplatAttr(argTy, 16));
  // High 16-bit digit: `val >> 16` (logical shift, so the result is < 2^16).
  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
  };

  // Sign-extension digit: 0xFFFF when `val` is negative, 0 otherwise. The
  // arithmetic shift fills the upper 16 bits with copies of the sign bit and
  // the following logical shift keeps exactly those bits.
  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
    // We only need to shift arithmetically by 15, but the extra
    // sign-extension bit will be truncated by the logical shift, so this is
    // fine. We do not have to introduce an extra constant since any
    // value in [15, 32) would do.
    return getHighDigit(
        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
  };

  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                           getScalarOrSplatAttr(argTy, 0));

  Value lhsLow = getLowDigit(lhs);
  Value lhsHigh = getHighDigit(lhs);
  // Digits 2 and 3 of the 64-bit extension of each operand: all-ones/zero for
  // sign extension, the zero constant for zero extension.
  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
  Value rhsLow = getLowDigit(rhs);
  Value rhsHigh = getHighDigit(rhs);
  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;

  // Little-endian base-2^16 digits of the two 64-bit-extended operands and of
  // the (truncated to 64 bits) product.
  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};

  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
      // Partial products at digit position >= 4 only affect bits above 64.
      if (i + j >= resultDigits.size())
        continue;

      // Skip products with the known-zero digit (zero-extended case) so no
      // dead multiplications are emitted.
      if (lhsDigit == cst0 || rhsDigit == cst0)
        continue;

      Value &thisResDigit = resultDigits[i + j];
      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
      // createOrFold presumably folds away additions with the zero constant
      // (see "after constant folding" above).
      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
      // Keep the low 16 bits in this digit ...
      thisResDigit = getLowDigit(current);

      // ... and propagate the high 16 bits as a carry into the next digit,
      // except from the top digit, whose carry lies beyond bit 64.
      if (i + j + 1 != resultDigits.size()) {
        Value &nextResDigit = resultDigits[i + j + 1];
        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
                                                    getHighDigit(current));
        nextResDigit = carry;
      }
    }
  }

  // Reassembles a 32-bit word from two 16-bit digits: `low | (high << 16)`.
  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
  };
  Value low = combineDigits(resultDigits[0], resultDigits[1]);
  Value high = combineDigits(resultDigits[2], resultDigits[3]);

  return rewriter.create<CompositeConstructOp>(
      loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}
134 
135 //===----------------------------------------------------------------------===//
136 // Rewrite Patterns
137 //===----------------------------------------------------------------------===//
138 
139 template <typename MulExtendedOp, bool SignExtendArguments>
140 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
142 
143  LogicalResult matchAndRewrite(MulExtendedOp op,
144  PatternRewriter &rewriter) const override {
145  Location loc = op->getLoc();
146  Value lhs = op.getOperand1();
147  Value rhs = op.getOperand2();
148 
149  // Currently, WGSL only supports 32-bit integer types. Any other integer
150  // types should already have been promoted/demoted to i32.
151  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
152  if (elemTy.getIntOrFloatBitWidth() != 32)
153  return rewriter.notifyMatchFailure(
154  loc,
155  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
156 
157  Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
158  SignExtendArguments);
159  rewriter.replaceOp(op, mul);
160  return success();
161  }
162 };
163 
164 using ExpandSMulExtendedPattern =
165  ExpandMulExtendedPattern<SMulExtendedOp, true>;
166 using ExpandUMulExtendedPattern =
167  ExpandMulExtendedPattern<UMulExtendedOp, false>;
168 
169 struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
171 
172  LogicalResult matchAndRewrite(IAddCarryOp op,
173  PatternRewriter &rewriter) const override {
174  Location loc = op->getLoc();
175  Value lhs = op.getOperand1();
176  Value rhs = op.getOperand2();
177 
178  // Currently, WGSL only supports 32-bit integer types. Any other integer
179  // types should already have been promoted/demoted to i32.
180  Type argTy = lhs.getType();
181  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
182  if (elemTy.getIntOrFloatBitWidth() != 32)
183  return rewriter.notifyMatchFailure(
184  loc,
185  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
186 
187  Value one =
188  rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
189  Value zero =
190  rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
191 
192  // Calculate the carry by checking if the addition resulted in an overflow.
193  Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
194  Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
195  Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
196 
197  Value add = rewriter.create<CompositeConstructOp>(
198  loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
199 
200  rewriter.replaceOp(op, add);
201  return success();
202  }
203 };
204 
205 struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
207 
208  LogicalResult matchAndRewrite(IsInfOp op,
209  PatternRewriter &rewriter) const override {
210  // We assume values to be finite and turn `IsInf` info `false`.
211  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
212  op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
213  return success();
214  }
215 };
216 
217 struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
219 
220  LogicalResult matchAndRewrite(IsNanOp op,
221  PatternRewriter &rewriter) const override {
222  // We assume values to be finite and turn `IsNan` info `false`.
223  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
224  op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
225  return success();
226  }
227 };
228 
229 //===----------------------------------------------------------------------===//
230 // Passes
231 //===----------------------------------------------------------------------===//
232 struct WebGPUPreparePass final
233  : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
234  void runOnOperation() override {
235  RewritePatternSet patterns(&getContext());
238 
239  if (failed(
240  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
241  signalPassFailure();
242  }
243 };
244 } // namespace
245 
246 //===----------------------------------------------------------------------===//
247 // Public Interface
248 //===----------------------------------------------------------------------===//
250  RewritePatternSet &patterns) {
251  // WGSL currently does not support extended multiplication ops, see:
252  // https://github.com/gpuweb/gpuweb/issues/1565.
253  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
254  ExpandAddCarryPattern>(patterns.getContext());
255 }
256 
258  RewritePatternSet &patterns) {
259  // WGSL currently does not support `isInf` and `isNan`, see:
260  // https://github.com/gpuweb/gpuweb/pull/2311.
261  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
262 }
263 
264 } // namespace spirv
265 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Type
An inlay hint for a type annotation.
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.