#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

#include <type_traits>
/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have rank at most 1, a stride of 1, and a
/// static offset and size.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only a stride of 1 is supported.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  // Only static sizes and offsets are supported.
  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  // Rescale the size and offset from narrow elements to wide words, rounding
  // the size up so no elements are dropped.
  SmallVector<int64_t> size;
  if (!sizes.empty())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}
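// Illustrative example (assumed shapes, not from the original source): with
// srcBits = 4 and dstBits = 8, i.e. elementsPerByte = 2, a cast
//   memref.reinterpret_cast %src to offset: [4], sizes: [6], strides: [1]
// over i4 elements becomes
//   memref.reinterpret_cast %src to offset: [2], sizes: [3], strides: [1]
// over i8 elements, since 4 / 2 = 2 and ceil(6 / 2) = 3.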
/// When data is loaded/stored in `targetBits` granularity but computed on in
/// `sourceBits` granularity (sourceBits < targetBits), each `targetBits` word
/// holds targetBits / sourceBits narrow elements. Returns the bit offset of
/// the element at position `srcIdx` within its word.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}
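// Worked example (illustrative): for sourceBits = 4 and targetBits = 8,
// scaleFactor = 2 and element x sits at bit offset (x % 2) * 4, so the
// element at index 5 occupies the upper nibble of its byte (bit offset 4).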
/// Sub-byte stores use masked bitwise operations so that only the relevant
/// bits of the wide word are modified. Returns the AND mask that clears the
/// destination bits targeted by the write and preserves all others.
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  // A run of srcBits ones, shifted into position and then inverted.
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = arith::ConstantOp::create(
      builder, loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
  return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}
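// Worked example (illustrative): when writing the second i4 of an i32 word
// (srcBits = 4, dstBits = 32, bit offset 4), maskRightAligned = 0xF,
// writeMaskInverse = 0xF0, and the XOR with -1 yields the mask 0xFFFFFF0F.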
/// Returns the scaled linearized index. The input `linearizedIndex` is
/// expressed in `srcBits` granularity; the returned index addresses the wider
/// `dstBits` words.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc,
                                         scaledLinearizedIndices);
}
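// Worked example (illustrative): with srcBits = 4 and dstBits = 8 the scaler
// is 2, so narrow index 5 maps to wide storage index 5 floordiv 2 = 2.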
/// Linearizes the indices of a load/store on the original narrow-type memref
/// using the strided metadata of the source buffer.
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      memref::ExtractStridedMetadataOp::create(builder, loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}
//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs: there are no sizes to linearize.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Compute the linearized size in units of the wider element type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, srcBits,
                                                 dstBits, zero, sizes);

    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};
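// Illustrative example (assumed shapes, not from the original source): with
// an i8 load/store width, memref.alloc(%d) : memref<?xi4> is rewritten to
// memref.alloc(%n) : memref<?xi8>, where %n is the dynamic linearized size
// ceil-divided by 2; a static memref<3x5xi4> would become memref<8xi8>.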
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));

    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
                                        ValueRange{});
    } else {
      // Linearize the indices of the original load, load the containing wide
      // word, then shift the requested bits to the rightmost position.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
    }

    // If the computation type matches the wide type, mask out the low bits;
    // otherwise truncate down to the computation type.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    auto conversionTy =
        resultTy.isInteger()
            ? resultTy
            : rewriter.getIntegerType(resultTy.getIntOrFloatBitWidth());
    if (conversionTy == convertedElementType) {
      auto mask = arith::ConstantOp::create(
          rewriter, loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
      result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
    } else {
      result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
    }

    if (conversionTy != resultTy) {
      result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
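// Illustrative lowering (assumed i4-in-i8 emulation, not from the original
// source): memref.load %m[%i] : memref<16xi4> becomes, roughly,
//   %word = memref.load %m8[%i floordiv 2] : memref<8xi8>
//   %off  = (%i mod 2) * 4, index_cast'ed to i8
//   %bits = arith.shrsi %word, %off : i8
//   %val  = arith.trunci %bits : i8 to i4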
struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(
        op, newTy, adaptor.getSource());
    return success();
  }
};
/// Output types should be at most one dimensional, so only the 0-D and 1-D
/// cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional result types.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();

    // Pad the input value with 0s on the left, bitcasting non-integer values
    // to an integer of the same width first.
    Value input = adaptor.getValue();
    if (!input.getType().isInteger())
      input = arith::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(input.getType().getIntOrFloatBitWidth()),
          input);
    Value extendedInput =
        arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);

    // Special case 0-rank memref stores: no masking is needed.
    if (convertedType.getRank() == 0) {
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
                                  extendedInput, adaptor.getMemref(),
                                  ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with its destination bit offset.
    Value alignedVal =
        arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);

    // Clear the destination bits, then OR in the new value. Both steps are
    // atomic read-modify-writes so that concurrent stores to neighboring
    // sub-byte elements do not race on the shared wide word.
    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
                                writeMask, adaptor.getMemref(), storeIndices);
    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
                                alignedVal, adaptor.getMemref(), storeIndices);

    rewriter.eraseOp(op);
    return success();
  }
};
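// Illustrative lowering (assumed i4-in-i8 emulation, not from the original
// source): memref.store %v, %m[%i] : memref<16xi4> becomes, roughly,
//   %ext  = arith.extui %v : i4 to i8
//   %val  = arith.shli %ext, %off
//   memref.atomic_rmw andi %mask, %m8[%i floordiv 2]  // clear the old nibble
//   memref.atomic_rmw ori  %val,  %m8[%i floordiv 2]  // merge in the new one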
/// Emulating narrow ints on subviews has limited support: only static offsets
/// and sizes with a stride of 1 are handled, by rewriting the subview as a
/// reinterpret_cast of the linearized buffer.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only strides of 1 are supported.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    // Only static sizes and offsets are supported.
    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes, and strides according to the emulation.
    auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter.getContext()));

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};
/// Emulating a `memref.collapse_shape` becomes a no-op after emulation:
/// memrefs are already flattened to a single dimension, so there is nothing
/// left to collapse.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};
/// Likewise, a `memref.expand_shape` of an already-flattened 1-D memref is a
/// no-op after emulation, since the expanded dimensions are re-linearized by
/// the surrounding loads and stores anyway.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
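// Minimal usage sketch (assumed pass boilerplate, not from the original
// source): a pass driving this emulation would typically do
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
//   RewritePatternSet patterns(ctx);
//   memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();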
/// Returns the new 1-D shape: the product of all static dimensions, scaled
/// down (rounding up) by the ratio of the wide to the narrow bitwidth. Any
/// dynamic dimension makes the whole result dynamic.
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to the ceiling so all narrow elements still fit.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}
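// Worked example (illustrative): memref<3x5xi4> with an i8 target gives
// linearizedShape = 3 * 5 = 15 and scale = 2, so the converted shape is
// (15 + 2 - 1) / 2 = 8, i.e. memref<8xi8>.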
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        Type elementType = ty.getElementType();
        if (!elementType.isIntOrFloat())
          return ty;

        unsigned width = elementType.getIntOrFloatBitWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only an innermost stride of 1 is handled.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(
            ty.getContext(), loadStoreWidth,
            elementType.isInteger()
                ? cast<IntegerType>(elementType).getSignedness()
                : IntegerType::SignednessSemantics::Signless);
        if (!newElemTy)
          return nullptr;

        // A strided layout is only needed for a non-zero offset; a static
        // offset is rescaled from narrow elements to wide words.
        StridedLayoutAttr layoutAttr;
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // The offset in bits must be a multiple of the load/store width.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}
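// Worked example (illustrative): with loadStoreWidth = 8,
//   memref<8xi4, strided<[1], offset: 4>>
// converts to
//   memref<4xi8, strided<[1], offset: 2>>
// since the offset rescales as (4 * 4) / 8 = 2 and the shape as ceil(8/2) = 4.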