21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/FormatVariadic.h"
30#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
// Builds a constant attribute of `type` that holds `value`.
// When `type` is an IntegerType, the result is an IntegerAttr built from
// `sizedValue`.
// NOTE(review): the definition of `sizedValue` and the non-integer path are
// not visible in this excerpt. Per the function name, the latter presumably
// produces a splat attribute for shaped/vector types; confirm against the
// full source.
41static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
43 if (
auto intTy = dyn_cast<IntegerType>(type))
44 return IntegerAttr::get(intTy, sizedValue);
// Expands an extended multiplication of `lhs` and `rhs` into plain SPIR-V
// integer arithmetic (IMul/IAdd/shifts/bitwise ops).
//
// Each operand is split into four 16-bit "digits", least significant first:
// its low half, its high half, and two extension digits. Digit pairs are
// multiplied schoolbook-style into four 16-bit result digits, with explicit
// carry propagation. The result digits are then recombined into two words,
// `low` (digits 0-1) and `high` (digits 2-3).
// NOTE(review): this assumes 32-bit operands; the pattern in this file that
// calls this helper rejects element widths other than 32.
//
// Parameters:
//   mulOp              - the op being expanded; provides the location and
//                        the result type of the returned composite.
//   rewriter           - rewriter used to create all new ops.
//   lhs, rhs           - the multiplication operands; `lhs`'s type is used
//                        for every constant created here.
//   signExtendArguments - if true, the extension digits are sign digits
//                        (signed multiply); otherwise they are the zero
//                        constant (unsigned multiply).
// Returns: a CompositeConstructOp of {low, high} typed as
//   `mulOp`'s first result type.
49static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter, Value
lhs,
51 Value
rhs,
bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy =
lhs.getType();
// Mask 0xFFFF: extracts the low 16-bit digit of a value.
67 Value cstLowMask = ConstantOp::create(
68 rewriter, loc,
lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70 return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
// Shift amount 16: a logical right shift by 16 extracts the high digit.
73 Value cst16 = ConstantOp::create(rewriter, loc,
lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76 return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
// Sign digit: built from an arithmetic right shift by 16.
// NOTE(review): the middle of this lambda is not visible. It presumably
// applies getHighDigit to the arithmetically shifted value, giving 0x0000
// for non-negative and 0xFFFF for negative inputs; confirm.
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
85 ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
// Zero constant: initial value of every result digit, and the extension
// digit used for unsigned multiplication.
88 Value cst0 = ConstantOp::create(rewriter, loc,
lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
91 Value lhsLow = getLowDigit(
lhs);
92 Value lhsHigh = getHighDigit(
lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(
lhs) : cst0;
94 Value rhsLow = getLowDigit(
rhs);
95 Value rhsHigh = getHighDigit(
rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(
rhs) : cst0;
// Each operand widened to four digits. The extension digit appears twice
// because both upper digits of the widened value are copies of it.
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
// Schoolbook multiplication: the product of digits i and j contributes to
// result digit i + j.
102 for (
auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103 for (
auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
// Products landing at or beyond the fourth result digit are outside the
// kept result (the skip statement itself is not visible in this excerpt).
104 if (i + j >= resultDigits.size())
// Multiplying by the zero constant contributes nothing, so (presumably)
// these pairs are skipped to avoid emitting dead ops.
107 if (lhsDigit == cst0 || rhsDigit == cst0)
110 Value &thisResDigit = resultDigits[i + j];
111 Value
mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit,
mul);
// Keep the low 16 bits in this digit; the high part is the carry.
113 thisResDigit = getLowDigit(current);
// Propagate the carry into the next digit, except from the top digit
// (there is no fifth digit to receive it).
115 if (i + j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i + j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
// Packs two 16-bit digits into one word: low | (high << 16).
124 auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125 Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126 return BitwiseOrOp::create(rewriter, loc, low, highBits);
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
// Result is {low word, high word}, built with mulOp's first result type.
131 return CompositeConstructOp::create(rewriter, loc,
132 mulOp->getResultTypes().front(),
133 llvm::ArrayRef({low, high}));
// Rewrite pattern that replaces an extended-multiplication op with the
// expansion produced by lowerExtendedMultiplication.
//
// Template parameters:
//   MulExtendedOp       - the op kind being matched.
//   SignExtendArguments - forwarded to lowerExtendedMultiplication; selects
//                         signed (true) or unsigned (false) extension of
//                         the operands.
140template <
typename MulExtendedOp,
bool SignExtendArguments>
141struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
142 using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter)
const override {
146 Location loc = op->getLoc();
147 Value
lhs = op.getOperand1();
148 Value
rhs = op.getOperand2();
// Only 32-bit elements are expanded; any other width is reported as a
// match failure. NOTE(review): `elemTy` is defined on lines not visible
// here; presumably it is the operands' element type.
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
156 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
// Build the digit-wise expansion and replace the original op with it.
158 Value
mul = lowerExtendedMultiplication(op, rewriter,
lhs,
rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op,
mul);
// Signed variant: operands are sign-extended before multiplying.
165using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
// Unsigned variant: operands are zero-extended before multiplying.
167using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
// Rewrite pattern that expands an add-with-carry or subtract-with-borrow op
// into a plain add/sub, an unsigned comparison, and a select.
//
// Template parameters:
//   Op      - the op kind being matched (IAddCarryOp or, per the aliases in
//             this file, ISubBorrowOp).
//   ArithOp - the wrapping arithmetic op used to compute the primary result.
//
// The replacement is a CompositeConstructOp of {out, flag}. Here `out` is
// ArithOp(lhs, rhs), and `flag` is 1 when a carry/borrow occurred, else 0.
170template <
typename Op,
typename ArithOp>
171struct ExpandAddCarryOrSubBorrowPattern final : OpRewritePattern<Op> {
172 using OpRewritePattern<
Op>::OpRewritePattern;
174 LogicalResult matchAndRewrite(
Op op,
175 PatternRewriter &rewriter)
const override {
176 Location loc = op->getLoc();
177 Value
lhs = op.getOperand1();
178 Value
rhs = op.getOperand2();
182 Type argTy =
lhs.getType();
// Only 32-bit elements are expanded; any other width is reported as a
// match failure. NOTE(review): `elemTy` is defined on lines not visible
// here.
184 if (elemTy.getIntOrFloatBitWidth() != 32)
185 return rewriter.notifyMatchFailure(
187 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
// The carry/borrow flag is materialized as 1 or 0 in the operand type.
189 Value one = ConstantOp::create(rewriter, loc, argTy,
190 getScalarOrSplatAttr(argTy, 1));
191 Value zero = ConstantOp::create(rewriter, loc, argTy,
192 getScalarOrSplatAttr(argTy, 0));
194 Value out = ArithOp::create(rewriter, loc,
lhs,
rhs);
// Addition carries iff the wrapped sum is unsigned-less-than lhs.
// Otherwise (subtraction) a borrow occurs iff lhs is unsigned-less-than rhs.
// NOTE(review): the declaration of `cmp` and the `else` keyword are on
// lines not visible in this excerpt.
198 if constexpr (std::is_same_v<Op, IAddCarryOp>)
199 cmp = ULessThanOp::create(rewriter, loc, out,
lhs);
201 cmp = ULessThanOp::create(rewriter, loc,
lhs,
rhs);
202 Value flag = SelectOp::create(rewriter, loc, cmp, one, zero);
// Pack {result, flag} into the original op's first result type.
204 Value
result = CompositeConstructOp::create(rewriter, loc,
205 op->getResultTypes().front(),
206 llvm::ArrayRef({out, flag}));
208 rewriter.replaceOp(op,
result);
// IAddCarry expanded via IAdd plus an unsigned "sum < lhs" carry check.
213using ExpandAddCarryPattern =
214 ExpandAddCarryOrSubBorrowPattern<IAddCarryOp, IAddOp>;
// ISubBorrow expanded via ISub plus an unsigned "lhs < rhs" borrow check.
215using ExpandSubBorrowPattern =
216 ExpandAddCarryOrSubBorrowPattern<ISubBorrowOp, ISubOp>;
// Replaces every IsInfOp with a constant 0 of the op's result type, so
// values are treated as never infinite.
// NOTE(review): the result type is presumably boolean, making the constant
// `false`; confirm.
218struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
221 LogicalResult matchAndRewrite(IsInfOp op,
222 PatternRewriter &rewriter)
const override {
224 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
225 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
// Replaces every IsNanOp with a constant 0 of the op's result type, so
// values are treated as never NaN.
// NOTE(review): the result type is presumably boolean, making the constant
// `false`; confirm.
230struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
233 LogicalResult matchAndRewrite(IsNanOp op,
234 PatternRewriter &rewriter)
const override {
236 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
237 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
245struct WebGPUPreparePass final
247 void runOnOperation()
override {
265 patterns.
add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
266 ExpandAddCarryPattern, ExpandSubBorrowPattern>(
274 patterns.
add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.
getContext());
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.