#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

#include <type_traits>
/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have rank at most 1, a stride of 1, and
/// static offset and size.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only a stride of 1 is supported.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  // Only static sizes and offsets are supported.
  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  // The offset in narrow elements must start on a boundary of the wider
  // element type.
  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  // Scale the size and offset from narrow elements to wide elements.
  SmallVector<int64_t> size;
  if (!sizes.empty())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}
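// Worked example (added for illustration, not from the upstream source): with
// srcBits = 4 and dstBits = 8, elementsPerByte is 2. A cast with static
// offset 6 and size 7 over i4 becomes offset 6 / 2 = 3 and size
// divideCeil(7, 2) = 4 over i8; an odd offset such as 5 is rejected above
// because it does not start on an i8 boundary.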
/// When data is loaded or stored in `targetBits` granularity but used in
/// `sourceBits` granularity (with `sourceBits` < `targetBits`), each wide
/// element is treated as an array of narrow elements. Returns the bit offset
/// of the narrow element at position `srcIdx` within its wide element.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
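// Worked example (illustrative, not from the upstream source): for
// sourceBits = 4 and targetBits = 8, scaleFactor is 2 and element x sits at
// bit (x % 2) * 4, so even indices map to bit 0 and odd indices to bit 4 of
// the containing i8.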
/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. Returns an "and" mask that clears the
/// destination bits of a subbyte write.
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  // All-ones constant used to flip the shifted mask.
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}
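// Worked example (illustrative): for srcBits = 4, dstBits = 8 and a bitwidth
// offset of 4, maskRightAligned is 0x0F, writeMaskInverse is 0x0F << 4 = 0xF0,
// and the returned mask is 0xF0 ^ 0xFF = 0x0F, which clears the destination
// nibble when and-ed with the stored byte.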
/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes: the input `linearizedIndex` has the granularity of `srcBits`, the
/// returned index has the granularity of `dstBits`.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}
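// Worked example (illustrative): with srcBits = 4 and dstBits = 8, scaler is
// 2, so the linearized i4 index 5 maps to i8 index floorDiv(5, 2) = 2; the
// remaining "5 mod 2" part is recovered separately by getOffsetForBitwidth.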
/// Linearizes the given `indices` of `memref` into a single index, still
/// counted in the granularity of the original (narrow) element type.
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}
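// Illustrative note (assumes a row-major source, not taken from the upstream
// comments): for a memref<3x4xi4> accessed at [i, j], the strided metadata
// yields strides [4, 1], so the linearized index is i * 4 + j, still counted
// in i4 elements; scaling to the wider storage type happens in the callers.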
//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Compute the linearized size in units of the wider element type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset=*/zero, sizes);

    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};
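// Illustrative before/after sketch (hand-written, assuming a load/store width
// of 8):
//
//   %a = memref.alloc() : memref<3x4xi4>
//
// is rewritten to a flattened allocation of divideCeil(12, 2) = 6 bytes:
//
//   %a = memref.alloc() : memref<6xi8>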
//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case: 0-rank memref loads need no index arithmetic.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load op; the scaling to the
      // wider type is accounted for below.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Shift the requested bits to the rightmost position.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // If the converted element type matches the wide type, mask out the low
    // srcBits; otherwise truncate to the converted result type.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
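// Illustrative emulation sketch for a load (hand-written, assuming i4
// emulated over i8): loading element %i of a memref<8xi4> becomes roughly
//
//   %byte  = memref.load %m[%i floordiv 2] : memref<4xi8>
//   %off   = arith.index_cast ((%i mod 2) * 4) : index to i8
//   %shift = arith.shrsi %byte, %off : i8
//   %res   = arith.trunci %shift : i8 to i4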
//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(
        op, newTy, adaptor.getSource());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support the 0 or 1 dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Zero-extend the stored value to the wider storage type.
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case: 0-rank memref stores need no masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to be written with its destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits, then OR in the new bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};
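// Illustrative emulation sketch for a store (hand-written, i4 over i8):
// storing %v at element %i becomes a masked read-modify-write on the byte at
// %i floordiv 2, kept race-free by using two memref.atomic_rmw ops:
//
//   %ext = arith.extui %v : i4 to i8, shifted left by the bit offset
//   memref.atomic_rmw andi %mask, %m[%idx]  // clear the destination nibble
//   memref.atomic_rmw ori  %ext,  %m[%idx]  // or in the new bits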
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Narrow-type emulation of memref.subview has limited support: only static
/// offsets and sizes, a stride of 1, and contiguous result types. Ideally,
/// subviews are folded away before running the emulation.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy = dyn_cast<MemRefType>(
        getTypeConverter()->convertType(subViewOp.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only a stride of 1 is supported.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    // Only static sizes and offsets are supported.
    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};
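// Worked example (illustrative): a subview of memref<8x16xi4> at offsets
// [2, 0], sizes [2, 16], strides [1, 1] covers 2 * 16 = 32 contiguous i4
// values starting at linear element 2 * 16 = 32, so with i8 storage it is
// rewritten as a memref.reinterpret_cast at element offset 16 with a
// linearized size of 16 bytes.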
//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation,
/// leaving no dimension to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefExpandShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation,
/// so the expansion would immediately be flattened again.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
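// Illustrative note (hand-written): after emulation every converted memref is
// already rank 1, e.g. memref<3x4xi4> has become memref<6xi8>, so collapse
// and expand shapes between such flattened views carry no information and
// both patterns simply forward the converted source value.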
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
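// Illustrative usage sketch (hand-written; it mirrors the in-tree test pass,
// but the surrounding boilerplate shown here is hypothetical, and a real pass
// would also add conversions for the other dialects it uses):
//
//   MLIRContext *ctx = op->getContext();
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
//
//   RewritePatternSet patterns(ctx);
//   memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
//
//   ConversionTarget target(*ctx);
//   target.markUnknownOpDynamicallyLegal(
//       [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();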
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    // Any dynamic dimension makes the entire linearized shape dynamic.
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Round the size up to ceilDiv(linearizedShape, scale) so that all the
  // narrow values fit.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}
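// Worked example (illustrative): for memref<3x5xi4> with srcBits = 4 and
// dstBits = 8, linearizedShape is 15 and scale is 2, so the returned shape is
// {(15 + 2 - 1) / 2} = {8}; the rounding up leaves half a byte of padding.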
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle an innermost stride of 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, no strided layout is needed since the stride
        // is 1.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // The offset in bits must be a multiple of loadStoreWidth; if so,
            // scale it to the wider element type.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}
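// Worked examples (illustrative) with loadStoreWidth = 8:
//
//   memref<3x4xi4> -> memref<6xi8>   (12 i4 values packed into 6 bytes)
//   memref<5xi4>   -> memref<3xi8>   (divideCeil(5, 2) bytes)
//   memref<4xi4, strided<[1], offset: 2>>
//                  -> memref<2xi8, strided<[1], offset: 1>>
//                     (offset 2 * 4 bits = 8 bits = 1 byte)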