#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
                                      memref::ReinterpretCastOp::Adaptor adaptor,
                                      memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");

  size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
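  // Illustrative example: with i4 emulated by i8, elementsPerByte = 2, so a
  // reinterpret_cast with static offset 4 and size 6 is rewritten over the
  // packed buffer with offset 4 / 2 = 2 and size ceil(6 / 2) = 3.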
                                  int sourceBits, int targetBits,
  assert(targetBits % sourceBits == 0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
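  // Illustrative example: with sourceBits = 4 and targetBits = 8, scaleFactor
  // is 2, so the narrow element at linear index 5 sits at bit offset
  // (5 % 2) * 4 = 4 within its packed byte.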
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = arith::ConstantOp::create(
      builder, loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
  return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
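  // Worked example: for srcBits = 4, dstBits = 8, and bitwidthOffset = 4,
  // maskRightAligned is 0x0F and writeMaskInverse is 0xF0; XOR with the
  // all-ones value gives 0x0F, i.e. a mask with zeros only at the nibble being
  // overwritten, so AND-ing it with the packed byte clears that slot.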
  int64_t scaler = dstBits / srcBits;
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
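  // E.g., with srcBits = 4 and dstBits = 8 the scaler is 2, so narrow element
  // index 5 maps to packed element index floordiv(5, 2) = 2.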
  auto stridedMetadata =
      memref::ExtractStridedMetadataOp::create(builder, loc, memref);
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));

    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
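    // Illustrative effect, assuming an 8-bit load/store width: an
    // `memref.alloc() : memref<6xi4>` is rewritten to allocate memref<3xi8>,
    // since 6 i4 values need ceil(6 * 4 / 8) = 3 bytes.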
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    Location loc = op.getLoc();

    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
                                        ValueRange{});

          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(),

                                                  srcBits, dstBits, rewriter);
      bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    auto conversionTy =
        resultTy.isInteger()
            ? resultTy
            : IntegerType::get(rewriter.getContext(),
                               resultTy.getIntOrFloatBitWidth());
    if (conversionTy == convertedElementType) {
      auto mask = arith::ConstantOp::create(
          rewriter, loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
    } else {
      result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
    }

    if (conversionTy != resultTy) {
      result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
    }

    rewriter.replaceOp(op, result);
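    // Illustrative trace: reading element 3 of a memref<8xi4> emulated as
    // memref<4xi8> loads the byte at index floordiv(3, 2) = 1, shifts it right
    // by (3 % 2) * 4 = 4 bits, then masks or truncates to the low 4 bits.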
struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));

    // Only 0-D and 1-D results are handled.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  ConvertMemrefStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<memref::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    Location loc = op.getLoc();

    Value input = adaptor.getValue();
    if (!input.getType().isInteger()) {
      input = arith::BitcastOp::create(
          rewriter, loc,
          IntegerType::get(rewriter.getContext(),
                           input.getType().getIntOrFloatBitWidth()),
          input);
    }
    Value extendedInput =
        arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);

    if (convertedType.getRank() == 0) {
      memref::StoreOp::create(rewriter, loc, extendedInput, adaptor.getMemref(),
                              ValueRange{});
      rewriter.eraseOp(op);
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

        rewriter, loc, linearizedIndices, srcBits, dstBits);

                                          dstBits, bitwidthOffset, rewriter);

    Value alignedVal =
        arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);

    if (disableAtomicRMW) {
      // Read-modify-write without atomics: load, clear the target bits, OR in
      // the shifted value, and store back.
      Value origValue = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(), storeIndices);
      Value clearedValue =
          arith::AndIOp::create(rewriter, loc, origValue, writeMask);
      Value newValue =
          arith::OrIOp::create(rewriter, loc, clearedValue, alignedVal);
      memref::StoreOp::create(rewriter, loc, newValue, adaptor.getMemref(),
                              storeIndices);
    } else {
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
                                  writeMask, adaptor.getMemref(), storeIndices);
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
                                  alignedVal, adaptor.getMemref(),
                                  storeIndices);
    }
    rewriter.eraseOp(op);
  bool disableAtomicRMW;
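  // Worked store example: writing the i4 value 0xA at index 3 of a
  // memref<8xi4> emulated as memref<4xi8> targets byte 1 with
  // bitwidthOffset = 4, so writeMask = 0x0F and alignedVal = 0xA0; the andi
  // clears the upper nibble and the ori (or the non-atomic
  // load/and/or/store sequence when disableAtomicRMW is set) writes 0xA0 in.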
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");

      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");

    auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, subViewOp.getViewSource());
    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (newTy.getRank() != 1)

    rewriter.replaceOp(collapseShapeOp, srcVal);
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (newTy.getRank() != 1)

    rewriter.replaceOp(expandShapeOp, srcVal);
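    // Since the converted source is already a linearized rank-1 memref, the
    // reshape carries no information after emulation (the collapse pattern
    // above folds the same way), so the op is replaced by its source value.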
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  patterns.insert<ConvertMemrefStore>(patterns.getContext(), disableAtomicRMW);
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }

  int scale = dstBits / srcBits;
  // Round up so the packed buffer can hold all narrow values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
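  // Example, assuming an 8-bit load/store width: memref<5x3xi4> linearizes to
  // 5 * 3 = 15 narrow elements; with scale = 2 this rounds up to a packed
  // shape of {8}.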
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        Type elementType = ty.getElementType();

        if (width >= loadStoreWidth)
          return ty;

        if (failed(ty.getStridesAndOffset(strides, offset)))

        if (!strides.empty() && strides.back() != 1)

        auto newElemTy = IntegerType::get(
            ty.getContext(), loadStoreWidth,
            elementType.isInteger()
                ? cast<IntegerType>(elementType).getSignedness()
                : IntegerType::SignednessSemantics::Signless);

        StridedLayoutAttr layoutAttr;

        if (offset == ShapedType::kDynamic) {
          layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                              ArrayRef<int64_t>{1});

        if ((offset * width) % loadStoreWidth != 0)

        offset = (offset * width) / loadStoreWidth;

        layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                            ArrayRef<int64_t>{1});

                               newElemTy, layoutAttr, ty.getMemorySpace());
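        // Illustrative conversion, assuming an 8-bit load/store width:
        // memref<6xi4, strided<[1], offset: 8>> has offset * width = 32 bits,
        // divisible by 8, so it becomes memref<3xi8, strided<[1], offset: 4>>.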