26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/MathExtras.h"
29 #include <type_traits>
41 template <
typename MemRefOpTy>
43 typename MemRefOpTy::Adaptor adaptor,
44 MemRefOpTy op, MemRefType newTy) {
45 static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
46 std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
47 "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
49 auto convertedElementType = newTy.getElementType();
50 auto oldElementType = op.getType().getElementType();
51 int srcBits = oldElementType.getIntOrFloatBitWidth();
52 int dstBits = convertedElementType.getIntOrFloatBitWidth();
53 if (dstBits % srcBits != 0) {
55 "only dstBits % srcBits == 0 supported");
59 if (llvm::any_of(op.getStaticStrides(),
60 [](int64_t stride) { return stride != 1; })) {
62 "stride != 1 is not supported");
65 auto sizes = op.getStaticSizes();
66 int64_t offset = op.getStaticOffset(0);
68 if (llvm::any_of(sizes,
69 [](int64_t size) {
return size == ShapedType::kDynamic; }) ||
70 offset == ShapedType::kDynamic) {
72 op->
getLoc(),
"dynamic size or offset is not supported");
75 int elementsPerByte = dstBits / srcBits;
76 if (offset % elementsPerByte != 0) {
78 op->
getLoc(),
"offset not multiple of elementsPerByte is not "
84 size.push_back(
ceilDiv(sizes[0], elementsPerByte));
85 offset = offset / elementsPerByte;
88 *adaptor.getODSOperands(0).begin(),
89 offset, size, op.getStaticStrides());
101 int sourceBits,
int targetBits,
103 assert(targetBits % sourceBits == 0);
106 int scaleFactor = targetBits / sourceBits;
107 AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
112 return builder.
create<arith::IndexCastOp>(loc, dstType, bitOffset);
120 int64_t srcBits, int64_t dstBits,
123 auto maskRightAlignedAttr =
125 Value maskRightAligned = builder.
create<arith::ConstantOp>(
126 loc, dstIntegerType, maskRightAlignedAttr);
127 Value writeMaskInverse =
128 builder.
create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
131 builder.
create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
132 return builder.
create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
140 int64_t srcBits, int64_t dstBits) {
143 int64_t scaler = dstBits / srcBits;
145 builder, loc, s0.
floorDiv(scaler), {linearizedIndex});
153 auto stridedMetadata =
154 builder.
create<memref::ExtractStridedMetadataOp>(loc, memref);
156 std::tie(std::ignore, linearizedIndices) =
158 builder, loc, srcBits, srcBits,
159 stridedMetadata.getConstifiedMixedOffset(),
160 stridedMetadata.getConstifiedMixedSizes(),
161 stridedMetadata.getConstifiedMixedStrides(), indices);
162 return linearizedIndices;
171 template <
typename OpTy>
176 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
178 static_assert(std::is_same<OpTy, memref::AllocOp>() ||
179 std::is_same<OpTy, memref::AllocaOp>(),
180 "expected only memref::AllocOp or memref::AllocaOp");
181 auto currentType = cast<MemRefType>(op.getMemref().getType());
182 auto newResultType = dyn_cast<MemRefType>(
183 this->getTypeConverter()->convertType(op.getType()));
184 if (!newResultType) {
187 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
191 if (currentType.getRank() == 0) {
193 adaptor.getSymbolOperands(),
194 adaptor.getAlignmentAttr());
203 int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
204 int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
209 rewriter, loc, srcBits, dstBits, zero, sizes);
211 if (!newResultType.hasStaticShape()) {
217 adaptor.getSymbolOperands(),
218 adaptor.getAlignmentAttr());
227 struct ConvertMemRefAssumeAlignment final
232 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
234 Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
237 op->
getLoc(), llvm::formatv(
"failed to convert memref type: {0}",
238 op.getMemref().getType()));
242 op, adaptor.getMemref(), adaptor.getAlignmentAttr());
255 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
257 auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
258 auto convertedElementType = convertedType.getElementType();
259 auto oldElementType = op.getMemRefType().getElementType();
260 int srcBits = oldElementType.getIntOrFloatBitWidth();
261 int dstBits = convertedElementType.getIntOrFloatBitWidth();
262 if (dstBits % srcBits != 0) {
264 op,
"only dstBits % srcBits == 0 supported");
270 if (convertedType.getRank() == 0) {
271 bitsLoad = rewriter.
create<memref::LoadOp>(loc, adaptor.getMemref(),
277 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
280 loc, adaptor.getMemref(),
287 srcBits, dstBits, rewriter);
288 bitsLoad = rewriter.
create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
297 auto resultTy = getTypeConverter()->convertType(oldElementType);
298 if (resultTy == convertedElementType) {
299 auto mask = rewriter.
create<arith::ConstantOp>(
300 loc, convertedElementType,
301 rewriter.
getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
303 result = rewriter.
create<arith::AndIOp>(loc, bitsLoad, mask);
305 result = rewriter.
create<arith::TruncIOp>(loc, resultTy, bitsLoad);
319 struct ConvertMemRefReinterpretCast final
324 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
327 dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
331 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
335 if (op.getType().getRank() > 1) {
337 op->
getLoc(),
"subview with rank > 1 is not supported");
352 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
354 auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
355 int srcBits = op.getMemRefType().getElementTypeBitWidth();
356 int dstBits = convertedType.getElementTypeBitWidth();
358 if (dstBits % srcBits != 0) {
360 op,
"only dstBits % srcBits == 0 supported");
364 Value extendedInput = rewriter.
create<arith::ExtUIOp>(loc, dstIntegerType,
368 if (convertedType.getRank() == 0) {
369 rewriter.
create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
370 extendedInput, adaptor.getMemref(),
377 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
379 rewriter, loc, linearizedIndices, srcBits, dstBits);
383 dstBits, bitwidthOffset, rewriter);
386 rewriter.
create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
389 rewriter.
create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
390 writeMask, adaptor.getMemref(),
393 rewriter.
create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
394 alignedVal, adaptor.getMemref(),
413 matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
416 dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
420 llvm::formatv(
"failed to convert memref type: {0}", op.getType()));
424 if (op.getType().getRank() != 1) {
426 op->
getLoc(),
"subview with rank > 1 is not supported");
444 patterns.
add<ConvertMemRefAllocation<memref::AllocOp>,
445 ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
446 ConvertMemrefStore, ConvertMemRefAssumeAlignment,
447 ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
454 if (ty.getRank() == 0)
457 int64_t linearizedShape = 1;
458 for (
auto shape : ty.getShape()) {
459 if (shape == ShapedType::kDynamic)
460 return {ShapedType::kDynamic};
461 linearizedShape *= shape;
463 int scale = dstBits / srcBits;
466 linearizedShape = (linearizedShape + scale - 1) / scale;
467 return {linearizedShape};
473 [&typeConverter](MemRefType ty) -> std::optional<Type> {
474 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
478 unsigned width = intTy.getWidth();
480 if (width >= loadStoreWidth)
488 if (!strides.empty() && strides.back() != 1)
492 intTy.getSignedness());
496 StridedLayoutAttr layoutAttr;
500 if (offset == ShapedType::kDynamic) {
506 if ((offset * width) % loadStoreWidth != 0)
508 offset = (offset * width) / loadStoreWidth;
516 newElemTy, layoutAttr, ty.getMemorySpace());
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a pointer to the underlying location storage object.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
unsigned getLoadStoreBitwidth() const
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.
void populateMemRefNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata of its source.
Include the generated interface declarations.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs) - 1].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
For a memref with offset, sizes and strides, returns the offset and size to use for the linearized memref.
OpFoldResult linearizedSize