21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/FormatVariadic.h"
30#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
// Returns a constant attribute holding `value`, suitable for building a
// ConstantOp of `type` (callers pass the attribute straight to
// ConstantOp::create / replaceOpWithNewOp<spirv::ConstantOp> with `type`).
//
// For a scalar IntegerType the result is an IntegerAttr built from
// `sizedValue`. NOTE(review): the definition of `sizedValue` and the
// handling of non-IntegerType (e.g. vector) types are not visible in this
// excerpt -- presumably a splat attribute, given the name; confirm against
// the full source.
41static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
// Scalar integer case.
43 if (
auto intTy = dyn_cast<IntegerType>(type))
44 return IntegerAttr::get(intTy, sizedValue);
// Expands an extended multiplication of `lhs` and `rhs` into ordinary SPIR-V
// integer arithmetic and returns the value that should replace `mulOp`.
//
// The product is computed with 16-bit "digits". Each operand is viewed as four
// digits, least significant first: [low, high, ext, ext]. Only the four least
// significant result digits are produced. They are recombined into two
// 32-bit words {low, high}. Digits are 16 bits wide, so a single digit
// product ((2^16 - 1)^2) always fits in 32 bits.
//
// NOTE(review): callers (ExpandMulExtendedPattern) only reach this for 32-bit
// integer element types; this function itself does not check the width.
//
// Parameters:
//   mulOp               - op being expanded; supplies the location and the
//                         result type of the returned composite.
//   rewriter            - used to create all new ops.
//   lhs, rhs            - multiplication operands; lhs's type is used for
//                         every constant created here.
//   signExtendArguments - true: the two upper operand digits come from
//                         getSignDigit (signed multiply); false: they are the
//                         zero constant (unsigned multiply).
// Returns: a CompositeConstructOp of {low word, high word}, typed as the first
// result type of `mulOp`.
49static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter, Value
lhs,
51 Value
rhs,
bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy =
lhs.getType();
// Mask selecting the low 16 bits, i.e. one digit.
67 Value cstLowMask = ConstantOp::create(
68 rewriter, loc,
lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70 return BitwiseAndOp::create(rewriter, loc, val, cstLowMask);
// Shift amount for moving between adjacent 16-bit digits.
73 Value cst16 = ConstantOp::create(rewriter, loc,
lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
// Logical right shift by 16: extracts bits [16, 32) as the high digit.
75 auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76 return ShiftRightLogicalOp::create(rewriter, loc, val, cst16);
// Digit used for the upper half of a sign-extended operand. It is derived
// from an arithmetic right shift by 16. NOTE(review): part of this lambda is
// elided. Presumably it wraps the shift in getHighDigit, yielding 0xFFFF for
// negative values and 0 otherwise; confirm.
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
85 ShiftRightArithmeticOp::create(rewriter, loc, val, cst16));
// Zero digit. Also used as the initial value of every result digit.
88 Value cst0 = ConstantOp::create(rewriter, loc,
lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
// Split each operand into its low/high digits plus the extension digit that
// fills result positions 2 and 3 (sign digit or zero).
91 Value lhsLow = getLowDigit(
lhs);
92 Value lhsHigh = getHighDigit(
lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(
lhs) : cst0;
94 Value rhsLow = getLowDigit(
rhs);
95 Value rhsHigh = getHighDigit(
rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(
rhs) : cst0;
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
// Schoolbook multiplication: digit product (i, j) contributes to result
// digit i + j.
102 for (
auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103 for (
auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
// Products landing at position >= 4 only affect bits above the 64-bit
// result. NOTE(review): the body of this `if` is elided; presumably it skips
// them.
104 if (i + j >= resultDigits.size())
// A product with the zero constant contributes nothing. This is an SSA-value
// identity check, so it only matches `cst0` itself, i.e. the unsigned
// extension digits. NOTE(review): the body is elided; presumably it skips the
// product.
107 if (lhsDigit == cst0 || rhsDigit == cst0)
// Add the product into digit i + j and keep only its low 16 bits. The
// overflow (high 16 bits) is carried into digit i + j + 1, except for the
// topmost digit, whose overflow lies outside the result.
110 Value &thisResDigit = resultDigits[i + j];
111 Value
mul = IMulOp::create(rewriter, loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit,
mul);
113 thisResDigit = getLowDigit(current);
115 if (i + j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i + j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
// Reassemble two 16-bit digits into one word: low | (high << 16).
124 auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125 Value highBits = ShiftLeftLogicalOp::create(rewriter, loc, high, cst16);
126 return BitwiseOrOp::create(rewriter, loc, low, highBits);
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
// Package the {low, high} words in mulOp's (first) result type.
131 return CompositeConstructOp::create(rewriter, loc,
132 mulOp->getResultTypes().front(),
133 llvm::ArrayRef({low, high}));
// Rewrite pattern replacing an extended-multiplication op (MulExtendedOp)
// with the plain-arithmetic expansion built by lowerExtendedMultiplication.
//
// Template parameters:
//   MulExtendedOp       - the op to match (instantiated below for
//                         SMulExtendedOp and UMulExtendedOp).
//   SignExtendArguments - forwarded to lowerExtendedMultiplication; true
//                         treats the operands as signed.
//
// Only 32-bit integer element types are rewritten. Any other width produces a
// match failure with an "Unexpected integer type for WebGPU" diagnostic.
// NOTE(review): the computation of `elemTy` (original lines 149-152) and the
// final `success()` return are elided from this excerpt.
140template <
typename MulExtendedOp,
bool SignExtendArguments>
141struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
142 using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter)
const override {
146 Location loc = op->getLoc();
147 Value
lhs = op.getOperand1();
148 Value
rhs = op.getOperand2();
// The 16-bit digit expansion assumes 32-bit operands.
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
156 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
158 Value
mul = lowerExtendedMultiplication(op, rewriter,
lhs,
rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op,
mul);
// Signed extended multiply: the upper operand digits are sign digits.
165using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
// Unsigned extended multiply: the upper operand digits are zero.
167using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
// Rewrite pattern replacing IAddCarryOp with a plain IAddOp plus an explicit
// carry computation. The result is a composite {sum, carry}, where carry is 1
// if the unsigned addition overflowed and 0 otherwise.
//
// Only 32-bit integer element types are rewritten. Any other width produces a
// match failure. NOTE(review): the computation of `elemTy` and the final
// `success()` return are elided from this excerpt.
170struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
173 LogicalResult matchAndRewrite(IAddCarryOp op,
174 PatternRewriter &rewriter)
const override {
175 Location loc = op->getLoc();
176 Value
lhs = op.getOperand1();
177 Value
rhs = op.getOperand2();
181 Type argTy =
lhs.getType();
183 if (elemTy.getIntOrFloatBitWidth() != 32)
184 return rewriter.notifyMatchFailure(
186 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
// Carry values, typed like the operands.
188 Value one = ConstantOp::create(rewriter, loc, argTy,
189 getScalarOrSplatAttr(argTy, 1));
190 Value zero = ConstantOp::create(rewriter, loc, argTy,
191 getScalarOrSplatAttr(argTy, 0));
// Modular (wrapping) sum.
194 Value out = IAddOp::create(rewriter, loc,
lhs,
rhs);
// An unsigned addition overflowed exactly when the wrapped sum is smaller
// than one of the operands; comparing against lhs is sufficient.
195 Value cmp = ULessThanOp::create(rewriter, loc, out,
lhs);
196 Value carry = SelectOp::create(rewriter, loc, cmp, one, zero);
// Package {sum, carry} in the op's result type.
198 Value
add = CompositeConstructOp::create(rewriter, loc,
199 op->getResultTypes().front(),
200 llvm::ArrayRef({out, carry}));
202 rewriter.replaceOp(op,
add);
// Rewrite pattern replacing IsInfOp with a constant 0 of the op's result type
// (for a boolean result this is `false`). The check is thus folded away as if
// the operand were never infinite.
// NOTE(review): the rationale comment (original line 212) is elided. This
// presumably reflects WebGPU's lack of guaranteed non-finite semantics;
// confirm.
207struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
210 LogicalResult matchAndRewrite(IsInfOp op,
211 PatternRewriter &rewriter)
const override {
213 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
214 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
// Rewrite pattern replacing IsNanOp with a constant 0 of the op's result type
// (for a boolean result this is `false`). The check is thus folded away as if
// the operand were never NaN.
// NOTE(review): the rationale comment (original line 224) is elided. This
// presumably reflects WebGPU's lack of guaranteed non-finite semantics;
// confirm.
219struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
222 LogicalResult matchAndRewrite(IsNanOp op,
223 PatternRewriter &rewriter)
const override {
225 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
226 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
// Pass that prepares SPIR-V code for a WebGPU target.
// NOTE(review): its base class and the body of runOnOperation (original lines
// 235 and 237-253) are elided from this excerpt. Presumably the body collects
// the expansion patterns defined above and applies them greedily; confirm
// against the full source.
234struct WebGPUPreparePass final
236 void runOnOperation()
override {
// Registers the expansions for signed/unsigned extended multiplication and
// add-with-carry.
// NOTE(review): the enclosing function header is elided. Given the original
// line numbers, this is presumably the body of
// populateSPIRVExpandExtendedMultiplicationPatterns rather than part of
// runOnOperation; confirm.
254 patterns.
add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255 ExpandAddCarryPattern>(patterns.
getContext());
// Registers the expansions that fold IsInf/IsNan checks to constants.
// NOTE(review): the enclosing function header is elided. This is presumably
// the body of populateSPIRVExpandNonFiniteArithmeticPatterns; confirm.
262 patterns.
add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.
getContext());
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void populateSPIRVExpandNonFiniteArithmeticPatterns(RewritePatternSet &patterns)
Appends patterns to expand non-finite arithmetic ops IsNan and IsInf.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends patterns to expand extended multiplication and addition ops into regular arithmetic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.