21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/FormatVariadic.h"
30 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
41 static Attribute getScalarOrSplatAttr(
Type type, int64_t value) {
43 if (
auto intTy = dyn_cast<IntegerType>(type))
49 static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter,
Value lhs,
51 Value rhs,
bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy = lhs.getType();
67 Value cstLowMask = rewriter.create<ConstantOp>(
68 loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](
Value val) {
70 return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
73 Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](
Value val) {
76 return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](
Value val) {
85 rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
88 Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
91 Value lhsLow = getLowDigit(lhs);
92 Value lhsHigh = getHighDigit(lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94 Value rhsLow = getLowDigit(rhs);
95 Value rhsHigh = getHighDigit(rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
104 if (i +
j >= resultDigits.size())
107 if (lhsDigit == cst0 || rhsDigit == cst0)
110 Value &thisResDigit = resultDigits[i +
j];
111 Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113 thisResDigit = getLowDigit(current);
115 if (i +
j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i +
j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
124 auto combineDigits = [loc, cst16, &rewriter](
Value low,
Value high) {
125 Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
126 return rewriter.create<BitwiseOrOp>(loc, low, highBits);
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
131 return rewriter.create<CompositeConstructOp>(
132 loc, mulOp->getResultTypes().front(),
llvm::ArrayRef({low, high}));
139 template <
typename MulExtendedOp,
bool SignExtendArguments>
140 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
143 LogicalResult matchAndRewrite(MulExtendedOp op,
144 PatternRewriter &rewriter)
const override {
145 Location loc = op->getLoc();
146 Value lhs = op.getOperand1();
147 Value rhs = op.getOperand2();
152 if (elemTy.getIntOrFloatBitWidth() != 32)
153 return rewriter.notifyMatchFailure(
155 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
157 Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
158 SignExtendArguments);
159 rewriter.replaceOp(op, mul);
164 using ExpandSMulExtendedPattern =
165 ExpandMulExtendedPattern<SMulExtendedOp, true>;
166 using ExpandUMulExtendedPattern =
167 ExpandMulExtendedPattern<UMulExtendedOp, false>;
169 struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
172 LogicalResult matchAndRewrite(IAddCarryOp op,
173 PatternRewriter &rewriter)
const override {
174 Location loc = op->getLoc();
175 Value lhs = op.getOperand1();
176 Value rhs = op.getOperand2();
180 Type argTy = lhs.getType();
182 if (elemTy.getIntOrFloatBitWidth() != 32)
183 return rewriter.notifyMatchFailure(
185 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
188 rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
190 rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
193 Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
194 Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
195 Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
197 Value add = rewriter.create<CompositeConstructOp>(
200 rewriter.replaceOp(op, add);
205 struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
208 LogicalResult matchAndRewrite(IsInfOp op,
209 PatternRewriter &rewriter)
const override {
211 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
212 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
217 struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
220 LogicalResult matchAndRewrite(IsNanOp op,
221 PatternRewriter &rewriter)
const override {
223 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
224 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
232 struct WebGPUPreparePass final
233 : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
234 void runOnOperation()
override {
253 patterns.
add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
254 ExpandAddCarryPattern>(patterns.
getContext());
261 patterns.
add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.
getContext());
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Type
An inlay hint that is for a type annotation.
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.