MLIR  18.0.0git
SPIRVWebGPUTransforms.cpp
Go to the documentation of this file.
1 //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements SPIR-V transforms used when targeting WebGPU.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Location.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #include <array>
27 #include <cstdint>
28 
29 namespace mlir {
30 namespace spirv {
31 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
32 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
33 } // namespace spirv
34 } // namespace mlir
35 
36 namespace mlir {
37 namespace spirv {
38 namespace {
39 //===----------------------------------------------------------------------===//
40 // Helpers
41 //===----------------------------------------------------------------------===//
42 Attribute getScalarOrSplatAttr(Type type, int64_t value) {
43  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
44  if (auto intTy = dyn_cast<IntegerType>(type))
45  return IntegerAttr::get(intTy, sizedValue);
46 
47  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
48 }
49 
50 Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
51  Value lhs, Value rhs,
52  bool signExtendArguments) {
53  Location loc = mulOp->getLoc();
54  Type argTy = lhs.getType();
55  // Emulate 64-bit multiplication by splitting each input element of type i32
56  // into 2 16-bit digits of type i32. This is so that the intermediate
57  // multiplications and additions do not overflow. We extract these 16-bit
58  // digits from i32 vector elements by masking (low digit) and shifting right
59  // (high digit).
60  //
61  // The multiplication algorithm used is the standard (long) multiplication.
62  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
63  // digits.
64  // - With zero-extended arguments, we end up emitting only 4 multiplications
65  // and 4 additions after constant folding.
66  // - With sign-extended arguments, we end up emitting 8 multiplications and
67  // and 12 additions after CSE.
68  Value cstLowMask = rewriter.create<ConstantOp>(
69  loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
70  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
71  return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
72  };
73 
74  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
75  getScalarOrSplatAttr(argTy, 16));
76  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
77  return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
78  };
79 
80  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
81  // We only need to shift arithmetically by 15, but the extra
82  // sign-extension bit will be truncated by the logical shift, so this is
83  // fine. We do not have to introduce an extra constant since any
84  // value in [15, 32) would do.
85  return getHighDigit(
86  rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
87  };
88 
89  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
90  getScalarOrSplatAttr(argTy, 0));
91 
92  Value lhsLow = getLowDigit(lhs);
93  Value lhsHigh = getHighDigit(lhs);
94  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
95  Value rhsLow = getLowDigit(rhs);
96  Value rhsHigh = getHighDigit(rhs);
97  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
98 
99  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
100  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
101  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
102 
103  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
104  for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
105  if (i + j >= resultDigits.size())
106  continue;
107 
108  if (lhsDigit == cst0 || rhsDigit == cst0)
109  continue;
110 
111  Value &thisResDigit = resultDigits[i + j];
112  Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
113  Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
114  thisResDigit = getLowDigit(current);
115 
116  if (i + j + 1 != resultDigits.size()) {
117  Value &nextResDigit = resultDigits[i + j + 1];
118  Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
119  getHighDigit(current));
120  nextResDigit = carry;
121  }
122  }
123  }
124 
125  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
126  Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
127  return rewriter.create<BitwiseOrOp>(loc, low, highBits);
128  };
129  Value low = combineDigits(resultDigits[0], resultDigits[1]);
130  Value high = combineDigits(resultDigits[2], resultDigits[3]);
131 
132  return rewriter.create<CompositeConstructOp>(
133  loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Rewrite Patterns
138 //===----------------------------------------------------------------------===//
139 
140 template <typename MulExtendedOp, bool SignExtendArguments>
141 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
143 
144  LogicalResult matchAndRewrite(MulExtendedOp op,
145  PatternRewriter &rewriter) const override {
146  Location loc = op->getLoc();
147  Value lhs = op.getOperand1();
148  Value rhs = op.getOperand2();
149 
150  // Currently, WGSL only supports 32-bit integer types. Any other integer
151  // types should already have been promoted/demoted to i32.
152  auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
153  if (elemTy.getIntOrFloatBitWidth() != 32)
154  return rewriter.notifyMatchFailure(
155  loc,
156  llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
157 
158  Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159  SignExtendArguments);
160  rewriter.replaceOp(op, mul);
161  return success();
162  }
163 };
164 
165 using ExpandSMulExtendedPattern =
166  ExpandMulExtendedPattern<SMulExtendedOp, true>;
167 using ExpandUMulExtendedPattern =
168  ExpandMulExtendedPattern<UMulExtendedOp, false>;
169 
170 //===----------------------------------------------------------------------===//
171 // Passes
172 //===----------------------------------------------------------------------===//
173 class WebGPUPreparePass
174  : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
175 public:
176  void runOnOperation() override {
177  RewritePatternSet patterns(&getContext());
179 
180  if (failed(
181  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
182  signalPassFailure();
183  }
184 };
185 } // namespace
186 
187 //===----------------------------------------------------------------------===//
188 // Public Interface
189 //===----------------------------------------------------------------------===//
191  RewritePatternSet &patterns) {
192  // WGSL currently does not support extended multiplication ops, see:
193  // https://github.com/gpuweb/gpuweb/issues/1565.
194  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
195  patterns.getContext());
196 }
197 } // namespace spirv
198 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Type
An inlay hint that for a type annotation.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends to a pattern list additional patterns to expand extended multiplication ops into regular arit...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.