#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

#include <type_traits>
/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size.
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
                                      memref::ReinterpretCastOp::Adaptor adaptor,
                                      memref::ReinterpretCastOp op,
                                      MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only a stride of 1 is supported.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only static sizes and offsets are supported.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  // Rescale the size and offset from narrow-element units to wide-element
  // units, rounding the size up so all narrow elements fit.
  SmallVector<int64_t> size;
  if (!sizes.empty())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}
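// Standalone sketch of the rescaling above (hypothetical helpers, not part of
// the pass): with i4 packed into i8, elementsPerByte == 2, so a static offset
// of 6 narrow elements maps to wide word 3, and a size of 5 narrow elements
// maps to ceil(5 / 2) == 3 wide words.
static int64_t exampleScaledOffset(int64_t offset, int elementsPerByte) {
  return offset / elementsPerByte; // 6 -> 3
}
static int64_t exampleScaledSize(int64_t size, int elementsPerByte) {
  return (size + elementsPerByte - 1) / elementsPerByte; // 5 -> 3
}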
/// When data is loaded/stored in `targetBits` granularity but used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// word is treated as an array of `sourceBits`-wide elements. Returns the bit
/// offset of the element at position `srcIdx` within its containing word.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
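// Plain-arithmetic model of the affine map built above (hypothetical helper):
// with sourceBits == 4 and targetBits == 8, scaleFactor == 2, so element 5
// lives at bit (5 % 2) * 4 == 4 of its byte.
static int64_t exampleBitOffset(int64_t idx, int sourceBits, int targetBits) {
  return (idx % (targetBits / sourceBits)) * sourceBits; // 5 -> 4
}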
/// When writing a sub-byte value, masked bitwise operations are used so that
/// only the relevant bits are modified. Returns an "and" mask for clearing
/// the destination bits of the write. E.g., when writing the second i4 in an
/// i32, the mask 0xFFFFFF0F is created.
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  // XOR with all-ones flips the shifted mask into the final write mask.
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}
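// Plain-arithmetic model of the mask construction above (hypothetical
// helper): right-align the mask, shift it to the element's bit offset, then
// invert. For the second i4 in an i32 this yields 0xFFFFFF0F.
static uint32_t exampleWriteMask(int srcBits, unsigned bitOffset) {
  uint32_t rightAligned = (1u << srcBits) - 1; // 0x0000000F for i4
  return ~(rightAligned << bitOffset);         // offset 4 -> 0xFFFFFF0F
}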
/// Returns the scaled linearized index: the input `linearizedIndex` has the
/// granularity of `srcBits`, the returned index has the granularity of
/// `dstBits`.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc,
                                         scaledLinearizedIndices);
}
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}
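// Model of the linearization above (hypothetical helper): for a strided
// memref, the multi-dimensional indices collapse to sum(indices[i] *
// strides[i]). E.g., index (2, 1) of a memref<3x4xi4> with strides [4, 1]
// linearizes to 2 * 4 + 1 == 9.
static int64_t exampleLinearize(ArrayRef<int64_t> indices,
                                ArrayRef<int64_t> strides) {
  int64_t linear = 0;
  for (auto [idx, stride] : llvm::zip_equal(indices, strides))
    linear += idx * stride;
  return linear;
}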
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Compute the linearized size in units of the wider element type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, zero, sizes)
            .first;

    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));

    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      // Special case zero-rank memref loads.
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load, load the containing wide
      // word, and shift the requested bits to the rightmost position.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Mask off the upper bits, unless the result type is narrower than the
    // converted element type, in which case a truncation suffices.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
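// Scalar model of the emulated load above (hypothetical helper): fetch the
// containing wide word, shift the sub-element down to the least significant
// bits, then mask. For words = {0x21, 0x43} holding the i4 values 1, 2, 3, 4,
// indices 0..3 return 1, 2, 3, 4.
static uint8_t exampleLoadI4(const uint8_t *words, int64_t idx) {
  uint8_t word = words[idx / 2];  // getIndicesForLoadOrStore
  unsigned shift = (idx % 2) * 4; // getOffsetForBitwidth
  return (word >> shift) & 0xF;   // shrsi + andi
}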
struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(
        op, newTy, adaptor.getSource());
    return success();
  }
};
/// Output types should be at most one dimensional, so only the zero- and
/// one-dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only rank 0 and 1 are supported.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Zero-extend the input value to the wider element type.
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case zero-rank memref stores: no masking is needed.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                srcBits, dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits into the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};
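// Scalar model of the emulated store above (hypothetical helper): clear the
// destination bits with the write mask (the atomic andi), then or in the
// shifted value (the atomic ori). Storing 7 at index 1 of words = {0x21}
// yields 0x71.
static void exampleStoreI4(uint8_t *words, int64_t idx, uint8_t val) {
  unsigned shift = (idx % 2) * 4;
  uint8_t mask = static_cast<uint8_t>(~(0xFu << shift)); // getSubByteWriteMask
  words[idx / 2] =
      (words[idx / 2] & mask) | static_cast<uint8_t>((val & 0xF) << shift);
}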
/// Emulating narrow ints on subviews has limited support: only a static
/// offset and size with a stride of 1 are supported. Ideally, subviews should
/// be folded away before running narrow type emulation, so this pattern only
/// needs to handle the cases that cannot be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only a stride of 1 is supported.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only static sizes and offsets are supported.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes, and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};
/// After the emulation has flattened memrefs to a single dimension, a
/// memref.collapse_shape over such a memref is a no-op and can be folded
/// away.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};
/// Likewise, a memref.expand_shape over an already-flattened 1-D memref is
/// folded away.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the flattened extent into the wider element type, rounding up.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}
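// Worked example of the linearization above (hypothetical helper): a
// memref<3x4xi4> flattens to 12 narrow elements; with an 8-bit load/store
// width the scale is 2, so the converted shape is ceil(12 / 2) == 6, i.e.
// memref<6xi8>.
static int64_t exampleLinearizedExtent(ArrayRef<int64_t> shape, int scale) {
  int64_t linear = 1;
  for (int64_t dim : shape)
    linear *= dim;                       // {3, 4} -> 12
  return (linear + scale - 1) / scale;   // 12 -> 6
}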
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Only an innermost stride of 1 is supported.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());

        // A strided layout is only needed for a non-zero offset: a dynamic
        // offset is kept as-is, while a static offset is rescaled to the
        // wider element type and must divide evenly.
        StridedLayoutAttr layoutAttr;
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}
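// Worked example of the offset rescaling above (hypothetical values): a
// memref<3x4xi4, strided<[4, 1], offset: 8>> with an 8-bit load/store width
// has (8 * 4) % 8 == 0, so the new static offset is (8 * 4) / 8 == 4 and the
// converted type is memref<6xi8, strided<[1], offset: 4>>.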