MLIR  19.0.0git
SPIRVWebGPUTransforms.cpp
Go to the documentation of this file.
1 //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements SPIR-V transforms used when targeting WebGPU.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Location.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #include <array>
27 #include <cstdint>
28 
29 namespace mlir {
30 namespace spirv {
31 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
32 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
33 } // namespace spirv
34 } // namespace mlir
35 
36 namespace mlir {
37 namespace spirv {
38 namespace {
39 //===----------------------------------------------------------------------===//
40 // Helpers
41 //===----------------------------------------------------------------------===//
42 static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
43  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
44  if (auto intTy = dyn_cast<IntegerType>(type))
45  return IntegerAttr::get(intTy, sizedValue);
46 
47  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
48 }
49 
/// Lowers an extended (32 x 32 -> 64-bit) multiplication into plain 32-bit
/// SPIR-V arithmetic and returns the replacement value.
///
/// \param mulOp The extended multiplication being lowered. Every new op gets
///        its location, and its first result type is used for the returned
///        `CompositeConstructOp`.
/// \param rewriter Rewriter used to create all new ops.
/// \param lhs First multiplicand. Its type is also used for every constant
///        created here. In this file, callers reach this helper only after
///        checking that the element type is 32 bits wide.
/// \param rhs Second multiplicand.
/// \param signExtendArguments If true, the operands are treated as signed and
///        their sign digits take part in the product. Otherwise the upper
///        digits are the constant 0 and the operands behave as unsigned.
/// \returns A composite built from {low 32 bits, high 32 bits} of the product.
static Value lowerExtendedMultiplication(Operation *mulOp,
                                         PatternRewriter &rewriter, Value lhs,
                                         Value rhs, bool signExtendArguments) {
  Location loc = mulOp->getLoc();
  Type argTy = lhs.getType();
  // Emulate 64-bit multiplication by splitting each input element of type i32
  // into 2 16-bit digits of type i32. This is so that the intermediate
  // multiplications and additions do not overflow. We extract these 16-bit
  // digits from i32 vector elements by masking (low digit) and shifting right
  // (high digit).
  //
  // The multiplication algorithm used is the standard (long) multiplication.
  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
  // digits.
  // - With zero-extended arguments, we end up emitting only 4 multiplications
  //   and 4 additions after constant folding.
  // - With sign-extended arguments, we end up emitting 8 multiplications and
  //   12 additions after CSE.
  Value cstLowMask = rewriter.create<ConstantOp>(
      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
  // Extracts bits [0, 16) of `val` by masking with 0xFFFF.
  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
  };

  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                            getScalarOrSplatAttr(argTy, 16));
  // Extracts bits [16, 32) of `val`. The shift is logical, so the upper bits
  // of the result are zero.
  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
  };

  // Produces the digit that sign-extends `val`: 0xFFFF when `val` is negative
  // and 0 otherwise. The arithmetic shift replicates the sign bit into the top
  // 16 bits, and the logical shift then moves those bits down.
  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
    // We only need to shift arithmetically by 15, but the extra
    // sign-extension bit will be truncated by the logical shift, so this is
    // fine. We do not have to introduce an extra constant since any
    // value in [15, 32) would do.
    return getHighDigit(
        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
  };

  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                           getScalarOrSplatAttr(argTy, 0));

  Value lhsLow = getLowDigit(lhs);
  Value lhsHigh = getHighDigit(lhs);
  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
  Value rhsLow = getLowDigit(rhs);
  Value rhsHigh = getHighDigit(rhs);
  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;

  // Each operand widened to 64 bits, as 4 digits ordered from least to most
  // significant. Digits 2 and 3 are both the extension digit (the sign digit
  // or the zero constant).
  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};

  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
      // Partial products at digit position >= 4 only affect bits above the
      // 64-bit result, so they are dropped.
      if (i + j >= resultDigits.size())
        continue;

      // A product with the zero constant contributes nothing. Skip it so
      // that no multiplication by zero is emitted.
      if (lhsDigit == cst0 || rhsDigit == cst0)
        continue;

      Value &thisResDigit = resultDigits[i + j];
      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
      // Keep only the low 16 bits at this position...
      thisResDigit = getLowDigit(current);

      // ...and carry the high 16 bits into the next position. The top digit
      // has no next position, so its carry is discarded.
      if (i + j + 1 != resultDigits.size()) {
        Value &nextResDigit = resultDigits[i + j + 1];
        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
                                                    getHighDigit(current));
        nextResDigit = carry;
      }
    }
  }

  // Packs two digits into one 32-bit word as `low | (high << 16)`. The left
  // shift discards any bits of `high` above 16. Because of the i-major loop
  // order, the last write to digits 0 and 2 always goes through getLowDigit,
  // so those digits are already 16 bits wide.
  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
  };
  Value low = combineDigits(resultDigits[0], resultDigits[1]);
  Value high = combineDigits(resultDigits[2], resultDigits[3]);

  return rewriter.create<CompositeConstructOp>(
      loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}
135 
136 //===----------------------------------------------------------------------===//
137 // Rewrite Patterns
138 //===----------------------------------------------------------------------===//
139 
140 template <typename MulExtendedOp, bool SignExtendArguments>
141 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
143 
144  LogicalResult matchAndRewrite(MulExtendedOp op,
145  PatternRewriter &rewriter) const override {
146  Location loc = op->getLoc();
147  Value lhs = op.getOperand1();
148  Value rhs = op.getOperand2();
149 
150  // Currently, WGSL only supports 32-bit integer types. Any other integer
151  // types should already have been promoted/demoted to i32.
152  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
153  if (elemTy.getIntOrFloatBitWidth() != 32)
154  return rewriter.notifyMatchFailure(
155  loc,
156  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
157 
158  Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159  SignExtendArguments);
160  rewriter.replaceOp(op, mul);
161  return success();
162  }
163 };
164 
165 using ExpandSMulExtendedPattern =
166  ExpandMulExtendedPattern<SMulExtendedOp, true>;
167 using ExpandUMulExtendedPattern =
168  ExpandMulExtendedPattern<UMulExtendedOp, false>;
169 
170 struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
172 
173  LogicalResult matchAndRewrite(IAddCarryOp op,
174  PatternRewriter &rewriter) const override {
175  Location loc = op->getLoc();
176  Value lhs = op.getOperand1();
177  Value rhs = op.getOperand2();
178 
179  // Currently, WGSL only supports 32-bit integer types. Any other integer
180  // types should already have been promoted/demoted to i32.
181  Type argTy = lhs.getType();
182  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
183  if (elemTy.getIntOrFloatBitWidth() != 32)
184  return rewriter.notifyMatchFailure(
185  loc,
186  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
187 
188  Value one =
189  rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
190  Value zero =
191  rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
192 
193  // Calculate the carry by checking if the addition resulted in an overflow.
194  Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
195  Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
196  Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
197 
198  Value add = rewriter.create<CompositeConstructOp>(
199  loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
200 
201  rewriter.replaceOp(op, add);
202  return success();
203  }
204 };
205 
206 struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
208 
209  LogicalResult matchAndRewrite(IsInfOp op,
210  PatternRewriter &rewriter) const override {
211  // We assume values to be finite and turn `IsInf` info `false`.
212  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
213  op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
214  return success();
215  }
216 };
217 
218 struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
220 
221  LogicalResult matchAndRewrite(IsNanOp op,
222  PatternRewriter &rewriter) const override {
223  // We assume values to be finite and turn `IsNan` info `false`.
224  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
225  op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
226  return success();
227  }
228 };
229 
230 //===----------------------------------------------------------------------===//
231 // Passes
232 //===----------------------------------------------------------------------===//
233 struct WebGPUPreparePass final
234  : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
235  void runOnOperation() override {
236  RewritePatternSet patterns(&getContext());
239 
240  if (failed(
241  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
242  signalPassFailure();
243  }
244 };
245 } // namespace
246 
247 //===----------------------------------------------------------------------===//
248 // Public Interface
249 //===----------------------------------------------------------------------===//
251  RewritePatternSet &patterns) {
252  // WGSL currently does not support extended multiplication ops, see:
253  // https://github.com/gpuweb/gpuweb/issues/1565.
254  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255  ExpandAddCarryPattern>(patterns.getContext());
256 }
257 
259  RewritePatternSet &patterns) {
260  // WGSL currently does not support `isInf` and `isNan`, see:
261  // https://github.com/gpuweb/gpuweb/pull/2311.
262  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
263 }
264 
265 } // namespace spirv
266 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Type
An inlay hint that for a type annotation.
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.