25 #include "llvm/Support/FormatVariadic.h"
27 #include <type_traits>
39 template <
typename MemRefOpTy>
41 typename MemRefOpTy::Adaptor adaptor,
42 MemRefOpTy op, MemRefType newTy) {
43 static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
44 std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
45 "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
47 auto convertedElementType = newTy.getElementType();
48 auto oldElementType = op.getType().getElementType();
49 int srcBits = oldElementType.getIntOrFloatBitWidth();
50 int dstBits = convertedElementType.getIntOrFloatBitWidth();
51 if (dstBits % srcBits != 0) {
53 "only dstBits % srcBits == 0 supported");
57 if (llvm::any_of(op.getStaticStrides(),
58 [](int64_t stride) { return stride != 1; })) {
60 "stride != 1 is not supported");
63 auto sizes = op.getStaticSizes();
64 int64_t offset = op.getStaticOffset(0);
66 if (llvm::any_of(sizes,
67 [](int64_t size) {
return size == ShapedType::kDynamic; }) ||
68 offset == ShapedType::kDynamic) {
70 op->
getLoc(),
"dynamic size or offset is not supported");
73 int elementsPerByte = dstBits / srcBits;
74 if (offset % elementsPerByte != 0) {
76 op->
getLoc(),
"offset not multiple of elementsPerByte is not "
82 size.push_back(
ceilDiv(sizes[0], elementsPerByte));
83 offset = offset / elementsPerByte;
86 *adaptor.getODSOperands(0).begin(),
87 offset, size, op.getStaticStrides());
99 int sourceBits,
int targetBits,
101 assert(targetBits % sourceBits == 0);
104 int scaleFactor = targetBits / sourceBits;
105 AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
110 return builder.
create<arith::IndexCastOp>(loc, dstType, bitOffset);
118 int64_t srcBits, int64_t dstBits,
121 auto maskRightAlignedAttr =
123 Value maskRightAligned = builder.
create<arith::ConstantOp>(
124 loc, dstIntegerType, maskRightAlignedAttr);
125 Value writeMaskInverse =
126 builder.
create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
129 builder.
create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
130 return builder.
create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
138 int64_t srcBits, int64_t dstBits) {
141 int64_t scaler = dstBits / srcBits;
143 builder, loc, s0.
floorDiv(scaler), {linearizedIndex});
151 auto stridedMetadata =
152 builder.
create<memref::ExtractStridedMetadataOp>(loc, memref);
154 std::tie(std::ignore, linearizedIndices) =
156 builder, loc, srcBits, srcBits,
157 stridedMetadata.getConstifiedMixedOffset(),
158 stridedMetadata.getConstifiedMixedSizes(),
159 stridedMetadata.getConstifiedMixedStrides(), indices);
160 return linearizedIndices;
169 template <
typename OpTy>
174 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
176 static_assert(std::is_same<OpTy, memref::AllocOp>() ||
177 std::is_same<OpTy, memref::AllocaOp>(),
178 "expected only memref::AllocOp or memref::AllocaOp");
179 auto currentType = cast<MemRefType>(op.getMemref().getType());
180 auto newResultType = dyn_cast<MemRefType>(
181 this->getTypeConverter()->convertType(op.getType()));
182 if (!newResultType) {
185 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
189 if (currentType.getRank() == 0) {
191 adaptor.getSymbolOperands(),
192 adaptor.getAlignmentAttr());
201 int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
202 int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
207 rewriter, loc, srcBits, dstBits, zero, sizes);
209 if (!newResultType.hasStaticShape()) {
215 adaptor.getSymbolOperands(),
216 adaptor.getAlignmentAttr());
225 struct ConvertMemRefAssumeAlignment final
230 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
232 Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
235 op->
getLoc(), llvm::formatv(
"failed to convert memref type: {0}",
236 op.getMemref().getType()));
240 op, adaptor.getMemref(), adaptor.getAlignmentAttr());
253 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
255 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
256 auto convertedElementType = convertedType.getElementType();
257 auto oldElementType = op.getMemRefType().getElementType();
258 int srcBits = oldElementType.getIntOrFloatBitWidth();
259 int dstBits = convertedElementType.getIntOrFloatBitWidth();
260 if (dstBits % srcBits != 0) {
262 op,
"only dstBits % srcBits == 0 supported");
268 if (convertedType.getRank() == 0) {
269 bitsLoad = rewriter.
create<memref::LoadOp>(loc, adaptor.getMemref(),
275 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
278 loc, adaptor.getMemref(),
285 srcBits, dstBits, rewriter);
286 bitsLoad = rewriter.
create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
295 auto resultTy = getTypeConverter()->convertType(oldElementType);
296 if (resultTy == convertedElementType) {
297 auto mask = rewriter.
create<arith::ConstantOp>(
298 loc, convertedElementType,
299 rewriter.
getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
301 result = rewriter.
create<arith::AndIOp>(loc, bitsLoad, mask);
303 result = rewriter.
create<arith::TruncIOp>(loc, resultTy, bitsLoad);
317 struct ConvertMemRefReinterpretCast final
322 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
325 dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
329 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
333 if (op.getType().getRank() > 1) {
335 op->
getLoc(),
"subview with rank > 1 is not supported");
350 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
352 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
353 int srcBits = op.getMemRefType().getElementTypeBitWidth();
354 int dstBits = convertedType.getElementTypeBitWidth();
356 if (dstBits % srcBits != 0) {
358 op,
"only dstBits % srcBits == 0 supported");
362 Value extendedInput = rewriter.
create<arith::ExtUIOp>(loc, dstIntegerType,
366 if (convertedType.getRank() == 0) {
367 rewriter.
create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
368 extendedInput, adaptor.getMemref(),
375 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
377 rewriter, loc, linearizedIndices, srcBits, dstBits);
381 dstBits, bitwidthOffset, rewriter);
384 rewriter.
create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
387 rewriter.
create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
388 writeMask, adaptor.getMemref(),
391 rewriter.
create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
392 alignedVal, adaptor.getMemref(),
411 matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
414 dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
418 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
422 if (op.getType().getRank() != 1) {
424 op->
getLoc(),
"subview with rank > 1 is not supported");
438 struct ConvertMemRefCollapseShape final
443 matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
445 Value srcVal = adaptor.getSrc();
446 auto newTy = dyn_cast<MemRefType>(srcVal.
getType());
450 if (newTy.getRank() != 1)
453 rewriter.
replaceOp(collapseShapeOp, srcVal);
469 patterns.
add<ConvertMemRefAllocation<memref::AllocOp>,
470 ConvertMemRefAllocation<memref::AllocaOp>,
471 ConvertMemRefCollapseShape, ConvertMemRefLoad,
472 ConvertMemrefStore, ConvertMemRefAssumeAlignment,
473 ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
480 if (ty.getRank() == 0)
483 int64_t linearizedShape = 1;
484 for (
auto shape : ty.getShape()) {
485 if (shape == ShapedType::kDynamic)
486 return {ShapedType::kDynamic};
487 linearizedShape *= shape;
489 int scale = dstBits / srcBits;
492 linearizedShape = (linearizedShape + scale - 1) / scale;
493 return {linearizedShape};
499 [&typeConverter](MemRefType ty) -> std::optional<Type> {
500 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
504 unsigned width = intTy.getWidth();
506 if (width >= loadStoreWidth)
514 if (!strides.empty() && strides.back() != 1)
518 intTy.getSignedness());
522 StridedLayoutAttr layoutAttr;
526 if (offset == ShapedType::kDynamic) {
532 if ((offset * width) % loadStoreWidth != 0)
534 offset = (offset * width) / loadStoreWidth;
542 newElemTy, layoutAttr, ty.getMemorySpace());
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
unsigned getLoadStoreBitwidth() const
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands, then immediately attempts to fold it.
void populateMemRefNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata of its source.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
For a memref with offset, sizes and strides, returns the offset and size to use for the linearized memref.
OpFoldResult linearizedSize