26 using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
30 setHasBoundedRewriteRecursion();
// Rewrite hook for gpu::ShuffleOp. From the statements visible here: it only
// proceeds for 64-bit int/float shuffled values, splits the value into a low
// and a high 32-bit half, creates one gpu::ShuffleOp per half, reassembles the
// two shuffled halves into a 64-bit value and replaces the original op.
// NOTE(review): this listing has gaps — the declarations of i32/i64,
// lo/hi, loRes/hiRes and the return statements are not visible here; confirm
// the details against the full file.
32 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
33 PatternRewriter &rewriter)
const override {
// `loc` is the location of the shuffle op; `valueLoc` is the location of the
// value being shuffled.
34 auto loc = op.getLoc();
35 auto value = op.getValue();
36 auto valueType = value.getType();
37 auto valueLoc = value.getLoc();
// Reject anything that is not an int/float type exactly 64 bits wide.
// NOTE(review): the message below is presumably passed to
// rewriter.notifyMatchFailure — the call line itself is not visible.
42 if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
44 op,
"only 64-bit int/float types are supported");
// Float inputs are bitcast to i64 so the integer ops below can be applied.
// (i64/i32 are declared on lines not shown; presumably the 64-/32-bit
// integer types — TODO confirm.)
49 if (isa<FloatType>(valueType))
50 value = arith::BitcastOp::create(rewriter, valueLoc, i64, value);
// Low half: truncate the 64-bit value to i32.
53 lo = arith::TruncIOp::create(rewriter, valueLoc, i32, value);
// High half: unsigned right shift by c32, then truncate to i32.
// (The constant's value is on a line not shown; presumably 32 — TODO confirm.)
56 auto c32 = arith::ConstantOp::create(rewriter, valueLoc,
58 hi = arith::ShRUIOp::create(rewriter, valueLoc, value, c32);
59 hi = arith::TruncIOp::create(rewriter, valueLoc, i32, hi);
// Shuffle each half separately, reusing the original op's offset, width and
// mode.
63 gpu::ShuffleOp::create(rewriter, op.getLoc(), lo, op.getOffset(),
64 op.getWidth(), op.getMode())
67 gpu::ShuffleOp::create(rewriter, op.getLoc(), hi, op.getOffset(),
68 op.getWidth(), op.getMode())
// Zero-extend the first result of each half-shuffle back to i64 and shift the
// high half left by c32.
72 lo = arith::ExtUIOp::create(rewriter, valueLoc, i64, loRes[0]);
75 hi = arith::ExtUIOp::create(rewriter, valueLoc, i64, hiRes[0]);
76 hi = arith::ShLIOp::create(rewriter, valueLoc, hi, c32);
// Merge the two halves with a bitwise OR.
79 value = arith::OrIOp::create(rewriter, loc, hi, lo);
// If the input was a float, bitcast the merged integer back to the original
// type.
82 if (isa<FloatType>(valueType))
83 value = arith::BitcastOp::create(rewriter, valueLoc, valueType, value);
// Combined validity: AND of the second result of both half-shuffles.
86 auto validity = arith::AndIOp::create(rewriter, loc, loRes[1], hiRes[1]);
// Replace the original op's two results with the merged value and the
// combined validity.
89 rewriter.
replaceOp(op, {value, validity});
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
IntegerAttr getIntegerAttr(Type type, int64_t value)
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Include the generated interface declarations.
void populateGpuShufflePatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.