23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/MathExtras.h"
// Fragment of a helper that rewrites memref.reinterpret_cast over a narrow
// element type into an equivalent op over the wider emulated type.
// Visible preconditions it enforces (each bails out via notifyMatchFailure):
//   - the result type must actually be converted by narrow-type emulation,
//   - dstBits % srcBits == 0,
//   - innermost static stride must be 1,
//   - all sizes static; result memref row-major contiguous,
//   - static offsets unless assumeAligned (presumably checked on the elided
//     line before L81 — TODO confirm against the full source).
// NOTE(review): extraction-garbled fragment — original line numbers are fused
// into the code and interior lines are missing; not compilable as-is.
43 memref::ReinterpretCastOp::Adaptor adaptor,
44 memref::ReinterpretCastOp op, MemRefType newTy,
// Nothing to do if the type converter left the result type unchanged.
46 if (newTy == op.getType()) {
47 return rewriter.notifyMatchFailure(
48 op,
"result type was not converted by narrow-type emulation");
51 Type convertedElementType = newTy.getElementType();
52 Type oldElementType = op.getType().getElementType();
// Only whole-number packing ratios (e.g. i4 into i8) are supported.
55 if (dstBits % srcBits != 0) {
56 return rewriter.notifyMatchFailure(op,
57 "only dstBits % srcBits == 0 supported");
61 if (!staticStrides.empty() && staticStrides.back() != 1) {
62 return rewriter.notifyMatchFailure(
63 op->getLoc(),
"innermost stride != 1 is not supported");
68 if (llvm::is_contained(op.getStaticSizes(), ShapedType::kDynamic)) {
69 return rewriter.notifyMatchFailure(op,
"dynamic sizes are not supported");
73 return rewriter.notifyMatchFailure(
74 op,
"result memref is not row-major contiguous");
81 llvm::is_contained(op.getStaticOffsets(), ShapedType::kDynamic)) {
82 return rewriter.notifyMatchFailure(
83 op,
"dynamic offsets require assumeAligned=true to ensure the offset "
84 "is a multiple of dstBits / srcBits");
// Rank-0 (empty sizes) case: scale the offset down by the packing ratio.
95 if (mixedSizes.empty()) {
96 int64_t elementsPerByte = dstBits / srcBits;
100 rewriter, loc, s0.
floorDiv(elementsPerByte), {origOffset});
// The remainder (s0 % elementsPerByte) is the intra-byte offset; presumably
// required to be zero on an elided line — TODO confirm.
102 rewriter, loc, s0 % elementsPerByte, {origOffset});
109 rewriter, loc, srcBits, dstBits, origOffset, mixedSizes,
114 newStrides.push_back(rewriter.getIndexAttr(1));
118 return rewriter.notifyMatchFailure(
119 op,
"offset is provably not a multiple of dstBits / srcBits");
// Replace with a reinterpret_cast over the wider type and linearized layout.
122 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
123 op, newTy, adaptor.getSource(), newOffset, newSizes, newStrides);
// Fragment: computes the bit offset of a narrow element within its containing
// wide element. Requires targetBits to be a multiple of sourceBits (asserted).
// The affine expr (s0 % scaleFactor) * sourceBits converts an element index
// into a bit offset; the result is index_cast to dstType.
// NOTE(review): extraction-garbled fragment — signature start and interior
// lines are missing; not compilable as-is.
                                           int sourceBits,
                                           int targetBits,
137 assert(targetBits % sourceBits == 0);
140 int scaleFactor = targetBits / sourceBits;
141 AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
146 return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
// Fragment: builds the inverted write mask for a sub-byte store. A
// right-aligned mask is shifted left by the bit offset, then XOR'd with an
// all-ones constant (flipVal, presumably -1 — TODO confirm the elided attr)
// so the final mask has zeros at the destination bit positions.
// NOTE(review): extraction-garbled fragment — not compilable as-is.
157 auto maskRightAlignedAttr =
159 Value maskRightAligned = arith::ConstantOp::create(
160 builder, loc, dstIntegerType, maskRightAlignedAttr);
161 Value writeMaskInverse =
162 arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
165 arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
166 return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
// Fragment: scales a linearized narrow-element index down to a wide-element
// index via floorDiv by the packing ratio (dstBits / srcBits).
// NOTE(review): extraction-garbled fragment — surrounding function is missing.
177 int64_t scaler = dstBits / srcBits;
179 builder, loc, s0.
floorDiv(scaler), {linearizedIndex});
// Fragment: linearizes multi-dimensional indices into a single index using
// the memref's runtime offset/sizes/strides obtained from
// memref.extract_strided_metadata. srcBits is passed for both bit widths, so
// no packing rescale happens here — only linearization.
// NOTE(review): extraction-garbled fragment — not compilable as-is.
187 auto stridedMetadata =
188 memref::ExtractStridedMetadataOp::create(builder, loc,
memref);
190 std::tie(std::ignore, linearizedIndices) =
192 builder, loc, srcBits, srcBits,
193 stridedMetadata.getConstifiedMixedOffset(),
194 stridedMetadata.getConstifiedMixedSizes(),
195 stridedMetadata.getConstifiedMixedStrides(),
indices);
196 return linearizedIndices;
// Converts memref.alloc / memref.alloca over a narrow element type into an
// allocation of the wider emulated type with a linearized (1-D) shape.
// Rank-0 allocations are replaced directly; otherwise the sizes are
// linearized via the LinearizedMemRefInfo helper before re-emitting the op.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines (incl. return statements and closing braces).
205template <
typename OpTy>
206struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
207 using OpConversionPattern<OpTy>::OpConversionPattern;
210 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
211 ConversionPatternRewriter &rewriter)
const override {
// Compile-time guard: this pattern is only instantiated for alloc/alloca.
212 static_assert(std::is_same<OpTy, memref::AllocOp>() ||
213 std::is_same<OpTy, memref::AllocaOp>(),
214 "expected only memref::AllocOp or memref::AllocaOp");
215 auto currentType = cast<MemRefType>(op.getMemref().getType());
217 this->getTypeConverter()->template convertType<MemRefType>(
219 if (!newResultType) {
220 return rewriter.notifyMatchFailure(
222 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
// Rank-0: no sizes to linearize; just swap in the converted result type.
226 if (currentType.getRank() == 0) {
227 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType,
ValueRange{},
228 adaptor.getSymbolOperands(),
229 adaptor.getAlignmentAttr());
233 Location loc = op.getLoc();
234 OpFoldResult zero = rewriter.getIndexAttr(0);
237 int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
238 int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
239 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
// Compute the linearized size of the packed allocation (offset = 0).
241 memref::LinearizedMemRefInfo linearizedMemRefInfo =
243 rewriter, loc, srcBits, dstBits, zero, sizes);
244 SmallVector<Value> dynamicLinearizedSize;
245 if (!newResultType.hasStaticShape()) {
250 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
251 adaptor.getSymbolOperands(),
252 adaptor.getAlignmentAttr());
// Converts memref.assume_alignment by converting its memref type and
// re-emitting the op with the same alignment attribute. Fails the match if
// the type converter cannot convert the memref type.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines; not compilable as-is.
261struct ConvertMemRefAssumeAlignment final
262 : OpConversionPattern<memref::AssumeAlignmentOp> {
263 using OpConversionPattern::OpConversionPattern;
266 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
267 ConversionPatternRewriter &rewriter)
const override {
268 Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
270 return rewriter.notifyMatchFailure(
271 op->getLoc(), llvm::formatv(
"failed to convert memref type: {0}",
272 op.getMemref().getType()));
275 rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
276 op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
// Converts memref.copy by re-emitting it with the converted (adaptor)
// operands. Bails out when both sides are ranked and have distinct layouts,
// since distinct-layout emulation is unimplemented.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// closing lines; not compilable as-is.
285struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
286 using OpConversionPattern::OpConversionPattern;
289 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
290 ConversionPatternRewriter &rewriter)
const override {
// dyn_cast: unranked memrefs yield null and skip the layout check below.
291 auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
292 auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
293 if (maybeRankedSource && maybeRankedDest &&
294 maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
295 return rewriter.notifyMatchFailure(
296 op, llvm::formatv(
"memref.copy emulation with distinct layouts ({0} "
297 "and {1}) is currently unimplemented",
298 maybeRankedSource.getLayout(),
299 maybeRankedDest.getLayout()));
300 rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
301 adaptor.getTarget());
// Converts memref.dealloc by re-emitting it with the converted memref
// operand. Pure pass-through of the adaptor value.
// NOTE(review): extraction-garbled fragment — missing return/closing lines.
310struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
311 using OpConversionPattern::OpConversionPattern;
314 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
315 ConversionPatternRewriter &rewriter)
const override {
316 rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
// Converts memref.load of a narrow element: loads the containing wide
// element, shifts it right by the element's bit offset, then either masks
// (AndI with (1 << srcBits) - 1) or truncates down to the narrow width, and
// finally bitcasts if the logical result type differs from the integer
// conversion type.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines; not compilable as-is.
325struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
326 using OpConversionPattern::OpConversionPattern;
329 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter)
const override {
331 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
332 auto convertedElementType = convertedType.getElementType();
333 auto oldElementType = op.getMemRefType().getElementType();
334 int srcBits = oldElementType.getIntOrFloatBitWidth();
335 int dstBits = convertedElementType.getIntOrFloatBitWidth();
336 if (dstBits % srcBits != 0) {
337 return rewriter.notifyMatchFailure(
338 op,
"only dstBits % srcBits == 0 supported");
341 Location loc = op.getLoc();
// Rank-0 memrefs have no indices: load the single wide element directly.
344 if (convertedType.getRank() == 0) {
345 bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
// Ranked case: linearize indices, load the wide element, then shift the
// narrow value down to bit position 0.
351 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
353 Value newLoad = memref::LoadOp::create(
354 rewriter, loc, adaptor.getMemref(),
361 srcBits, dstBits, rewriter);
// Arithmetic shift right; high bits are cleared/truncated below.
362 bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
371 auto resultTy = getTypeConverter()->convertType(oldElementType);
375 : IntegerType::get(rewriter.getContext(),
376 resultTy.getIntOrFloatBitWidth());
// If the conversion type is already the wide type, mask instead of trunc.
377 if (conversionTy == convertedElementType) {
378 auto mask = arith::ConstantOp::create(
379 rewriter, loc, convertedElementType,
380 rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
382 result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
384 result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
// Bitcast integer result back to the logical element type (e.g. float).
387 if (conversionTy != resultTy) {
388 result = arith::BitcastOp::create(rewriter, loc, resultTy,
result);
391 rewriter.replaceOp(op,
result);
// Converts memref.cast by converting the result type and re-emitting the op.
// Fails the match if the type cannot be converted; the `newTy == op.getType()`
// check at L182 presumably bails when no emulation is needed — the elided
// line after it is missing, TODO confirm.
// NOTE(review): extraction-garbled fragment — not compilable as-is.
402struct ConvertMemRefCast final : OpConversionPattern<memref::CastOp> {
403 using OpConversionPattern::OpConversionPattern;
406 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter)
const override {
408 Type newTy = getTypeConverter()->convertType(op.getType());
410 return rewriter.notifyMatchFailure(
412 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
414 if (newTy == op.getType())
417 rewriter.replaceOpWithNewOp<memref::CastOp>(op, newTy, adaptor.getSource());
// Converts memref.memory_space_cast by converting the destination type and
// re-emitting the op with the converted source operand.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines; not compilable as-is.
426struct ConvertMemRefMemorySpaceCast final
427 : OpConversionPattern<memref::MemorySpaceCastOp> {
428 using OpConversionPattern::OpConversionPattern;
431 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
432 ConversionPatternRewriter &rewriter)
const override {
433 Type newTy = getTypeConverter()->convertType(op.getDest().getType());
435 return rewriter.notifyMatchFailure(
436 op->getLoc(), llvm::formatv(
"failed to convert memref type: {0}",
437 op.getDest().getType()));
440 rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
441 adaptor.getSource());
// Pattern wrapper for memref.reinterpret_cast emulation. Carries an
// `assumeAligned` flag (set via the constructor) that relaxes the
// static-offset requirement; the body converts the result type and
// presumably delegates to the free helper near the top of this file —
// TODO confirm against the elided lines.
// NOTE(review): extraction-garbled fragment — not compilable as-is.
453struct ConvertMemRefReinterpretCast final
454 : OpConversionPattern<memref::ReinterpretCastOp> {
455 ConvertMemRefReinterpretCast(
const TypeConverter &typeConverter,
456 MLIRContext *context,
bool assumeAligned)
457 : OpConversionPattern<memref::ReinterpretCastOp>(typeConverter, context),
458 assumeAligned(assumeAligned) {}
461 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
462 ConversionPatternRewriter &rewriter)
const override {
464 getTypeConverter()->convertType<MemRefType>(op.getType());
466 return rewriter.notifyMatchFailure(
468 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
// Converts memref.store of a narrow element via read-modify-write on the
// containing wide element: the value is (bitcast if needed, then) zero-
// extended, shifted into position, and combined with a clearing write mask.
// Two strategies, selected by `disableAtomicRMW`:
//   - disabled:  plain load, AndI with mask, OrI with value, store
//     (non-atomic RMW — only safe when the caller guarantees no races);
//   - enabled:   memref.atomic_rmw andi (clear) followed by atomic_rmw ori
//     (set).
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines; not compilable as-is.
486struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
487 using OpConversionPattern::OpConversionPattern;
489 ConvertMemrefStore(
const TypeConverter &typeConverter, MLIRContext *context,
490 bool disableAtomicRMW)
491 : OpConversionPattern<memref::StoreOp>(typeConverter, context),
492 disableAtomicRMW(disableAtomicRMW) {}
495 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
496 ConversionPatternRewriter &rewriter)
const override {
497 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
498 int srcBits = op.getMemRefType().getElementTypeBitWidth();
499 int dstBits = convertedType.getElementTypeBitWidth();
500 auto dstIntegerType = rewriter.getIntegerType(dstBits);
501 if (dstBits % srcBits != 0) {
502 return rewriter.notifyMatchFailure(
503 op,
"only dstBits % srcBits == 0 supported");
506 Location loc = op.getLoc();
// Non-integer narrow values are first bitcast to an integer of srcBits.
509 Value input = adaptor.getValue();
511 input = arith::BitcastOp::create(
513 IntegerType::get(rewriter.getContext(),
517 Value extendedInput =
518 arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
// Rank-0: single element, store the extended value directly (no RMW).
522 if (convertedType.getRank() == 0) {
523 memref::StoreOp::create(rewriter, loc, extendedInput, adaptor.getMemref(),
525 rewriter.eraseOp(op);
530 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
532 rewriter, loc, linearizedIndices, srcBits, dstBits);
536 dstBits, bitwidthOffset, rewriter);
// Shift the extended value into its bit position within the wide element.
539 arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
541 if (disableAtomicRMW) {
543 Value origValue = memref::LoadOp::create(
544 rewriter, loc, adaptor.getMemref(), storeIndices);
547 arith::AndIOp::create(rewriter, loc, origValue, writeMask);
551 arith::OrIOp::create(rewriter, loc, clearedValue, alignedVal);
552 memref::StoreOp::create(rewriter, loc, newValue, adaptor.getMemref(),
// Atomic path: clear the destination bits, then OR in the new bits.
557 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
558 writeMask, adaptor.getMemref(), storeIndices);
560 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
561 alignedVal, adaptor.getMemref(),
564 rewriter.eraseOp(op);
569 bool disableAtomicRMW;
// Converts memref.subview over a narrow element type into a subview of the
// linearized wide-element memref. Visible restrictions (each a match
// failure): dstBits % srcBits == 0, all strides == 1, contiguous result
// type, static sizes, and static offsets unless `assumeAligned`.
// The offset is linearized via extract_strided_metadata + the
// LinearizedMemRefInfo helper, and must be provably a multiple of the
// packing ratio.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines; not compilable as-is.
583struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
584 ConvertMemRefSubview(
const TypeConverter &typeConverter, MLIRContext *context,
586 : OpConversionPattern<memref::SubViewOp>(typeConverter, context),
587 assumeAligned(assumeAligned) {}
590 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter)
const override {
593 getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
595 return rewriter.notifyMatchFailure(
597 llvm::formatv(
"failed to convert memref type: {0}",
598 subViewOp.getType()));
601 Location loc = subViewOp.getLoc();
602 Type convertedElementType = newTy.getElementType();
603 Type oldElementType = subViewOp.getType().getElementType();
606 if (dstBits % srcBits != 0)
607 return rewriter.notifyMatchFailure(
608 subViewOp,
"only dstBits % srcBits == 0 supported");
// Any non-unit stride would break the linear packing assumption.
611 if (llvm::any_of(subViewOp.getStaticStrides(),
612 [](int64_t stride) { return stride != 1; })) {
613 return rewriter.notifyMatchFailure(subViewOp->getLoc(),
614 "stride != 1 is not supported");
618 return rewriter.notifyMatchFailure(
619 subViewOp,
"the result memref type is not contiguous");
622 auto sizes = subViewOp.getStaticSizes();
625 if (llvm::is_contained(sizes, ShapedType::kDynamic)) {
626 return rewriter.notifyMatchFailure(subViewOp->getLoc(),
627 "dynamic size is not supported");
632 if (!assumeAligned && llvm::is_contained(subViewOp.getStaticOffsets(),
633 ShapedType::kDynamic)) {
634 return rewriter.notifyMatchFailure(
636 "dynamic offsets require assumeAligned=true to ensure the offset "
637 "is a multiple of dstBits / srcBits");
// Linearize the subview offset using the source memref's runtime layout.
641 auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
642 rewriter, loc, subViewOp.getViewSource());
644 OpFoldResult linearizedIndices;
645 auto strides = stridedMetadata.getConstifiedMixedStrides();
646 memref::LinearizedMemRefInfo linearizedInfo;
647 std::tie(linearizedInfo, linearizedIndices) =
649 rewriter, loc, srcBits, dstBits,
650 stridedMetadata.getConstifiedMixedOffset(),
651 subViewOp.getMixedSizes(), strides,
657 return rewriter.notifyMatchFailure(
659 "subview offset is provably not a multiple of dstBits / srcBits");
662 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
663 subViewOp, newTy, adaptor.getSource(), linearizedIndices,
// Converts memref.collapse_shape: after emulation the converted source is
// already a rank-1 linearized memref, so the collapse is a no-op and the op
// is replaced by the converted source value. Only rank-1 converted sources
// are handled (the `!= 1` branch's body is elided here).
// NOTE(review): extraction-garbled fragment — not compilable as-is.
679struct ConvertMemRefCollapseShape final
680 : OpConversionPattern<memref::CollapseShapeOp> {
681 using OpConversionPattern::OpConversionPattern;
684 matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
685 ConversionPatternRewriter &rewriter)
const override {
686 Value srcVal = adaptor.getSrc();
687 auto newTy = dyn_cast<MemRefType>(srcVal.
getType());
691 if (newTy.getRank() != 1)
694 rewriter.replaceOp(collapseShapeOp, srcVal);
// Converts memref.expand_shape: mirror of ConvertMemRefCollapseShape — the
// converted rank-1 source already carries the linearized layout, so the
// expand is folded away by replacing the op with the converted source.
// NOTE(review): extraction-garbled fragment — not compilable as-is.
702struct ConvertMemRefExpandShape final
703 : OpConversionPattern<memref::ExpandShapeOp> {
704 using OpConversionPattern::OpConversionPattern;
707 matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter)
const override {
709 Value srcVal = adaptor.getSrc();
710 auto newTy = dyn_cast<MemRefType>(srcVal.
getType());
714 if (newTy.getRank() != 1)
717 rewriter.replaceOp(expandShapeOp, srcVal);
// Fragment of the pattern-population entry point: registers all the
// emulation patterns above. Patterns that need the `assumeAligned` flag
// (subview, reinterpret_cast) are added separately with that extra
// constructor argument.
// NOTE(review): extraction-garbled fragment — function signature and
// closing lines are missing; not compilable as-is.
733 .
add<ConvertMemRefAllocation<memref::AllocOp>,
734 ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCast,
735 ConvertMemRefCopy, ConvertMemRefDealloc, ConvertMemRefCollapseShape,
736 ConvertMemRefExpandShape, ConvertMemRefLoad,
737 ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast>(
739 patterns.
add<ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
740 typeConverter, patterns.
getContext(), assumeAligned);
// Fragment: computes the linearized 1-D shape of a narrow-type memref in
// wide elements. Any dynamic dimension makes the whole result dynamic;
// otherwise the product of the dims is divided by the packing ratio,
// rounded up (ceil division via (x + scale - 1) / scale).
// NOTE(review): extraction-garbled fragment — function signature and the
// rank-0 branch body are missing; not compilable as-is.
748 if (ty.getRank() == 0)
752 for (
auto shape : ty.getShape()) {
753 if (
shape == ShapedType::kDynamic)
754 return {ShapedType::kDynamic};
755 linearizedShape *=
shape;
757 int scale = dstBits / srcBits;
// Ceil-divide so partially-filled trailing wide elements are still covered.
760 linearizedShape = (linearizedShape + scale - 1) / scale;
761 return {linearizedShape};
// Fragment of the MemRefType conversion lambda registered on the type
// converter: widens a narrow integer element type to the converter's
// load/store width, preserving signedness for integer elements. Visible
// rejections: element width >= loadStoreWidth (no conversion needed — TODO
// confirm the elided return), unextractable strides/offset, or non-unit
// innermost stride. A dynamic offset is kept dynamic in the new strided
// layout; a static offset must scale exactly ((offset*width) %
// loadStoreWidth == 0) and is rescaled to wide-element units.
// NOTE(review): extraction-garbled fragment — fused line numbers and missing
// interior lines; not compilable as-is.
766 typeConverter.addConversion(
767 [&typeConverter](MemRefType ty) -> std::optional<Type> {
768 Type elementType = ty.getElementType();
774 if (width >= loadStoreWidth)
780 if (failed(ty.getStridesAndOffset(strides, offset)))
782 if (!strides.empty() && strides.back() != 1)
785 auto newElemTy = IntegerType::get(
786 ty.getContext(), loadStoreWidth,
788 ? cast<IntegerType>(elementType).getSignedness()
789 : IntegerType::SignednessSemantics::Signless);
793 StridedLayoutAttr layoutAttr;
797 if (offset == ShapedType::kDynamic) {
798 layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
// Static offset: must land on a wide-element boundary after scaling.
803 if ((offset * width) % loadStoreWidth != 0)
805 offset = (offset * width) / loadStoreWidth;
807 layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
813 newElemTy, layoutAttr, ty.getMemorySpace());
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents a single result from folding an operation.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
unsigned getLoadStoreBitwidth() const
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
void populateMemRefNarrowTypeEmulationPatterns(const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns, bool disableAtomicRMW=false, bool assumeAligned=false)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
OpFoldResult linearizedSize
OpFoldResult linearizedOffset
OpFoldResult intraDataOffset