21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/FormatVariadic.h"
30 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
41 static Attribute getScalarOrSplatAttr(
Type type, int64_t value) {
43 if (
auto intTy = dyn_cast<IntegerType>(type))
49 static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter,
Value lhs,
51 Value rhs,
bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy = lhs.getType();
67 Value cstLowMask = ConstantOp::create(
68 rewriter, loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](
Value val) {
70 return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
73 Value cst16 = ConstantOp::create(rewriter, loc, lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](
Value val) {
76 return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](
Value val) {
85 ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
88 Value cst0 = ConstantOp::create(rewriter, loc, lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
91 Value lhsLow = getLowDigit(lhs);
92 Value lhsHigh = getHighDigit(lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94 Value rhsLow = getLowDigit(rhs);
95 Value rhsHigh = getHighDigit(rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
104 if (i +
j >= resultDigits.size())
107 if (lhsDigit == cst0 || rhsDigit == cst0)
110 Value &thisResDigit = resultDigits[i +
j];
111 Value mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113 thisResDigit = getLowDigit(current);
115 if (i +
j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i +
j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
124 auto combineDigits = [loc, cst16, &rewriter](
Value low,
Value high) {
125 Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126 return BitwiseOrOp::create(rewriter, loc, low, highBits);
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
131 return CompositeConstructOp::create(rewriter, loc,
132 mulOp->getResultTypes().front(),
140 template <
typename MulExtendedOp,
bool SignExtendArguments>
141 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter)
const override {
146 Location loc = op->getLoc();
147 Value lhs = op.getOperand1();
148 Value rhs = op.getOperand2();
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
156 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
158 Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op, mul);
// Expansion of SMulExtendedOp: SignExtendArguments = true is forwarded to
// lowerExtendedMultiplication, so the two upper 16-bit digits of each operand
// come from getSignDigit (i.e. the operands are treated as signed).
165 using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
// Expansion of UMulExtendedOp: SignExtendArguments = false, so the two upper
// 16-bit digits of each operand are the constant 0 (zero extension).
167 using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
170 struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
173 LogicalResult matchAndRewrite(IAddCarryOp op,
174 PatternRewriter &rewriter)
const override {
175 Location loc = op->getLoc();
176 Value lhs = op.getOperand1();
177 Value rhs = op.getOperand2();
181 Type argTy = lhs.getType();
183 if (elemTy.getIntOrFloatBitWidth() != 32)
184 return rewriter.notifyMatchFailure(
186 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
188 Value one = ConstantOp::create(rewriter, loc, argTy,
189 getScalarOrSplatAttr(argTy, 1));
190 Value zero = ConstantOp::create(rewriter, loc, argTy,
191 getScalarOrSplatAttr(argTy, 0));
194 Value out = IAddOp::create(rewriter, loc, lhs, rhs);
195 Value cmp = ULessThanOp::create(rewriter, loc, out, lhs);
196 Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
198 Value add = CompositeConstructOp::create(rewriter, loc,
199 op->getResultTypes().front(),
202 rewriter.replaceOp(op,
add);
207 struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
210 LogicalResult matchAndRewrite(IsInfOp op,
211 PatternRewriter &rewriter)
const override {
213 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
214 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
219 struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
222 LogicalResult matchAndRewrite(IsNanOp op,
223 PatternRewriter &rewriter)
const override {
225 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
226 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
234 struct WebGPUPreparePass final
235 : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
236 void runOnOperation()
override {
254 patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255 ExpandAddCarryPattern>(
patterns.getContext());
262 patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(
patterns.getContext());
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Type
An inlay hint for a type annotation.
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.