22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
31 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
32 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
// Returns an Attribute encoding `value` for `type`; lowerExtendedMultiplication
// uses it to build the 0xFFFF mask, the shift amount 16 and the zero constant
// passed to ConstantOp.
// NOTE(review): this excerpt ends right after the IntegerType check (original
// lines 45-49 are missing), so the non-integer (presumably vector/splat) path
// implied by the name is not visible here — confirm against the full file.
42 Attribute getScalarOrSplatAttr(
Type type, int64_t value) {
44 if (
auto intTy = dyn_cast<IntegerType>(type))
// Emulates the extended (double-width) multiplication `mulOp` of `lhs` by
// `rhs` using ordinary SPIR-V dialect ops (and/shift/mul/add/or), presumably
// because the WebGPU target lacks a native extended multiply — TODO confirm.
//
// Each argument is split into four 16-bit "digits": its low half, its high
// half, and two copies of an extension digit (the sign digit when
// `signExtendArguments` is true, the zero constant otherwise). Schoolbook
// multiplication over these digits yields four 16-bit result digits, which are
// recombined into two words: `low` from digits 0-1 and `high` from digits 2-3.
// The returned value is a CompositeConstructOp of type
// `mulOp->getResultTypes().front()` built from {low, high}.
//
// The low/high split covers only 32 bits of each argument, so callers must
// guarantee a 32-bit element type.
//
// NOTE(review): this excerpt omits several original lines, including the
// `lhs`/`rhs` parameter declarations (the caller passes them between
// `rewriter` and `signExtendArguments`), part of getSignDigit, and the loop
// headers around the partial-product section.
50 Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
52 bool signExtendArguments) {
53 Location loc = mulOp->getLoc();
54 Type argTy = lhs.getType();
// Mask selecting the low 16 bits of a value.
68 Value cstLowMask = rewriter.create<ConstantOp>(
69 loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
70 auto getLowDigit = [&rewriter, loc, cstLowMask](
Value val) {
71 return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
// Shift amount used both to extract the high 16 bits and to recombine digits.
74 Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
75 getScalarOrSplatAttr(argTy, 16));
76 auto getHighDigit = [&rewriter, loc, cst16](
Value val) {
77 return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
// Extension digit for signed arguments: an arithmetic shift right by 16 keeps
// the sign. The rest of this lambda is missing from the excerpt; presumably it
// applies getHighDigit to that shift, giving 0xFFFF for negative inputs and 0
// otherwise — confirm.
80 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](
Value val) {
86 rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
// Shared zero constant: the initial value of every result digit and the
// extension digit for unsigned arguments. Because it is one Value, the
// `== cst0` checks below compare Value identity, not runtime contents.
89 Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
90 getScalarOrSplatAttr(argTy, 0));
92 Value lhsLow = getLowDigit(lhs);
93 Value lhsHigh = getHighDigit(lhs);
94 Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
95 Value rhsLow = getLowDigit(rhs);
96 Value rhsHigh = getHighDigit(rhs);
97 Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
// Little-endian digit arrays: index k carries weight 2^(16*k), matching how
// combineDigits packs digits 0-1 and 2-3 below.
99 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
100 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
101 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
// (Loop headers over digit indices i and j are omitted from this excerpt.)
// Products landing at index >= 4 only affect bits at or above 2^64, beyond
// the result; the statement guarded here is not shown (presumably `continue`).
105 if (i +
j >= resultDigits.size())
// Skip partial products where either digit is the shared zero constant, e.g.
// the extension digits in the unsigned case; no IR is emitted for them.
108 if (lhsDigit == cst0 || rhsDigit == cst0)
// Add the partial product into result digit i + j and keep only its low
// 16 bits there...
111 Value &thisResDigit = resultDigits[i +
j];
112 Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
113 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
114 thisResDigit = getLowDigit(current);
// ...then carry the bits above bit 15 of the sum into the next digit. The top
// digit (index 3) has no successor; its overflow lies above bit 63.
116 if (i +
j + 1 != resultDigits.size()) {
117 Value &nextResDigit = resultDigits[i +
j + 1];
118 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
119 getHighDigit(current));
120 nextResDigit = carry;
// Packs two 16-bit digits into one word: low | (high << 16).
125 auto combineDigits = [loc, cst16, &rewriter](
Value low,
Value high) {
126 Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
127 return rewriter.create<BitwiseOrOp>(loc, low, highBits);
129 Value low = combineDigits(resultDigits[0], resultDigits[1]);
130 Value high = combineDigits(resultDigits[2], resultDigits[3]);
// The op's result is the {low, high} pair.
132 return rewriter.create<CompositeConstructOp>(
133 loc, mulOp->getResultTypes().front(),
llvm::ArrayRef({low, high}));
// Rewrite pattern that replaces a MulExtendedOp with the digit-wise emulation
// produced by lowerExtendedMultiplication. `SignExtendArguments` selects
// signed (true) or unsigned (false) treatment of the operands. Only 32-bit
// element types are accepted; any other width fails to match with a
// notifyMatchFailure message.
// NOTE(review): original lines 142-143, 149-152, 155 and 161-164 are missing
// from this excerpt, including the computation of `elemTy` (presumably the
// operands' element type) and the function's return — confirm.
140 template <
typename MulExtendedOp,
bool SignExtendArguments>
141 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
144 LogicalResult matchAndRewrite(MulExtendedOp op,
145 PatternRewriter &rewriter)
const override {
146 Location loc = op->getLoc();
147 Value lhs = op.getOperand1();
148 Value rhs = op.getOperand2();
// The emulation splits each argument into two 16-bit halves, so only 32-bit
// elements can be expanded.
153 if (elemTy.getIntOrFloatBitWidth() != 32)
154 return rewriter.notifyMatchFailure(
156 llvm::formatv(
"Unexpected integer type for WebGPU: '{0}'", elemTy));
// Replace the op with the emulated {low, high} composite.
158 Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
159 SignExtendArguments);
160 rewriter.replaceOp(op, mul);
// Signed variant sign-extends the operands; unsigned variant zero-extends
// them (its extension digits are the zero constant).
165 using ExpandSMulExtendedPattern =
166 ExpandMulExtendedPattern<SMulExtendedOp, true>;
167 using ExpandUMulExtendedPattern =
168 ExpandMulExtendedPattern<UMulExtendedOp, false>;
// Pass built on the generated SPIRVWebGPUPreparePassBase (see
// GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS) that registers the signed and unsigned
// MulExtended expansion patterns.
// NOTE(review): this excerpt stops mid-statement; the rest of runOnOperation
// (original lines 177-193 and everything after 194), including how `patterns`
// is created and applied, is not shown.
173 class WebGPUPreparePass
174 :
public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
176 void runOnOperation()
override {
// Register both expansion patterns (arguments not visible in this excerpt).
194 patterns.
add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
static MLIRContext * getContext(OpFoldResult val)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Type
An inlay hint for a type annotation.
void populateSPIRVExpandExtendedMultiplicationPatterns(RewritePatternSet &patterns)
Appends to a pattern list additional patterns to expand extended multiplication ops into regular arit...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.