#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
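
/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The number of bits in the offset must evenly divide the
/// bitwidth of the new converted type.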
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }
  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }
  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }
  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;
  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}
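
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits`
/// value is treated as an array of elements of width `sourceBits`. Returns
/// the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` is 4 and `targetBits` is 8, there are two elements in one i8,
/// and the x-th element is located at (x % 2) * 4 bits.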
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}
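
/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.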
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = arith::ConstantOp::create(
      builder, loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
  return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}
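
/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.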
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc,
                                         scaledLinearizedIndices);
}
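
/// Linearizes the given `indices` of `memref` into a single index expressed
/// in the granularity of `srcBits`, using the memref's strided metadata.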
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      memref::ExtractStridedMetadataOp::create(builder, loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {
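
//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

/// Converts memref.alloc/memref.alloca to the converted element type,
/// linearizing the shape and scaling the allocation size by dstBits/srcBits.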
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, srcBits,
                                                 dstBits, zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};
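
//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

/// Converts memref.assume_alignment to operate on the converted memref type.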
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};
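
//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

/// Converts memref.copy; identical layouts simply copy the wider backing
/// memrefs, while distinct layouts are not yet supported.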
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};
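
//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

/// Converts memref.dealloc to deallocate the converted (wider) memref.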
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};
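
//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

/// Converts memref.load into a load of the wider element followed by
/// shift/mask arithmetic that extracts the requested narrow value.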
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
                                        ValueRange{});
    } else {
      // Linearize the indices of the original load. The scaling to the wider
      // granularity is applied separately below.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Shift the requested bits to the rightmost position.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
    }

    // Extract the corresponding bits: mask when the converted type already
    // matches, otherwise truncate; then bitcast back to the result type.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    auto conversionTy =
        resultTy.isInteger()
            ? resultTy
            : IntegerType::get(rewriter.getContext(),
                               resultTy.getIntOrFloatBitWidth());
    if (conversionTy == convertedElementType) {
      auto mask = arith::ConstantOp::create(
          rewriter, loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
      result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
    } else {
      result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
    }

    if (conversionTy != resultTy) {
      result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
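
//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

/// Converts memref.memory_space_cast to operate on the converted memref type.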
struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
    return success();
  }
};
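
//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0-D and 1-D
/// cases are supported.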
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support for 0-D and 1-D cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();

    // Pad the input value with 0s on the left.
    Value input = adaptor.getValue();
    if (!input.getType().isInteger()) {
      input = arith::BitcastOp::create(
          rewriter, loc,
          IntegerType::get(rewriter.getContext(),
                           input.getType().getIntOrFloatBitWidth()),
          input);
    }
    Value extendedInput =
        arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
                                  extendedInput, adaptor.getMemref(),
                                  ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
                                writeMask, adaptor.getMemref(), storeIndices);
    // Write the source bits to the destination.
    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
                                alignedVal, adaptor.getMemref(), storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};
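
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating a subview of a narrow-type memref has limited support: only
/// subviews with stride 1, static sizes and offsets, and a contiguous
/// row-major result type are handled, as checked below.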
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes, and strides according to the emulation.
    auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter.getContext()));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};
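
//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a memref.collapse_shape becomes a no-op after emulation, since
/// memrefs are flattened to a single dimension and there is nothing left to
/// collapse.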
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};
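
//===----------------------------------------------------------------------===//
// ConvertMemRefExpandShape
//===----------------------------------------------------------------------===//

/// Likewise, a memref.expand_shape becomes a no-op after emulation, since
/// memrefs are flattened to a single dimension.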
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // namespace
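
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

/// Appends patterns for emulating memref operations over narrow types with
/// ops over wider types.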
void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
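
/// Returns the flattened shape of `ty` scaled down by `dstBits / srcBits`;
/// any dynamic dimension makes the whole linearized shape dynamic.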
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to fit all the values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}
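
/// Appends type conversions for emulating memref operations over narrow types
/// with ops over wider types.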
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        Type elementType = ty.getElementType();
        if (!elementType.isIntOrFloat())
          return ty;

        unsigned width = elementType.getIntOrFloatBitWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle strided memrefs with an innermost stride of 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(
            ty.getContext(), loadStoreWidth,
            elementType.isInteger()
                ? cast<IntegerType>(elementType).getSignedness()
                : IntegerType::SignednessSemantics::Signless);

        StridedLayoutAttr layoutAttr;
        // A strided layout is only needed for a non-zero offset.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // The offset in srcBits must be a multiple of loadStoreWidth;
            // if so, rescale it to the wider granularity.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}