#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <type_traits>

using namespace mlir;
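
//===----------------------------------------------------------------------===//
// Emulating narrow memref element types (e.g. i4) with a wider supported type
// (e.g. i8): memrefs are flattened to one dimension, indices are linearized
// and scaled, and individual elements are extracted/inserted with shifts and
// masks. As a rough sketch (not the exact IR the patterns below produce),
// loading the %i-th i4 element from a buffer emulated as memref<?xi8> becomes:
//
//   %word  = affine.apply ()[s0] -> (s0 floordiv 2) ()[%i]
//   %load  = memref.load %mem[%word] : memref<?xi8>
//   %shift = ... // (%i mod 2) * 4, index_cast'ed to i8
//   %bits  = arith.shrsi %load, %shift : i8
//   %elem  = arith.trunci %bits : i8 to i4
//===----------------------------------------------------------------------===//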
/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// memref must have rank <= 1 and stride 1, with static offset and size.
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
                                      memref::ReinterpretCastOp::Adaptor adaptor,
                                      memref::ReinterpretCastOp op,
                                      MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::any_of(
          sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (!sizes.empty())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}
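
/// When data is loaded/stored in `targetBits` granularity but used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// word is treated as an array of `sourceBits`-wide elements. Returns the bit
/// offset of the element at `srcIdx` within its word: e.g. for `sourceBits`
/// = 4 and `targetBits` = 8, element x sits at bit (x mod 2) * 4.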
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
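
/// When writing a sub-byte value, masked bitwise operations are used to only
/// modify the relevant bits. Returns the `and` mask that clears the bits of
/// one element in a sub-byte write; e.g. writing the second i4 in an i32
/// yields the mask 0xFFFFFF0F.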
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  // Mask for the low `srcBits` bits, e.g. 0x0F for i4-in-i8.
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  // Shift the mask to the element's position, then flip all bits.
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}
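
/// Scales `linearizedIndex` from `srcBits` granularity to `dstBits`
/// granularity, i.e. linearizedIndex floordiv (dstBits / srcBits).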
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaled = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaled);
}
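
/// Linearizes the given `indices` of `memref` in `srcBits` granularity, using
/// the offset, sizes, and strides recovered via extract_strided_metadata.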
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}
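
//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

/// Converts `memref.alloc`/`memref.alloca` to allocate the flattened buffer
/// in the wider element type, forwarding a dynamic linearized size when the
/// converted type does not have a static shape.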
namespace {
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Compute the linearized size in the wider element type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, zero, sizes);

    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};
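
//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

/// Re-creates `memref.assume_alignment` on the converted memref, preserving
/// the alignment attribute.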
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};
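
//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

/// Emulates a narrow `memref.load` by loading the wider element that contains
/// the requested value, shifting the requested bits to the right, and then
/// masking (if the computation type matches the emulated type) or truncating
/// to the original element type.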
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      // Special case 0-rank memref loads.
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices, load the wider element containing the
      // requested value, and shift its bits to the rightmost position.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // If the arith computation bitwidth equals the emulated bitwidth, mask off
    // the high bits; otherwise truncate to the original bitwidth.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
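
//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0-D and 1-D
/// cases are supported.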
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support the 0-D and 1-D cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits, then write the source bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};
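
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subviews has limited support: only a stride of 1
/// with static offsets and sizes is handled. Ideally the subview is folded
/// away before running narrow type emulation; this pattern covers the cases
/// that cannot be folded.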
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy = dyn_cast<MemRefType>(
        getTypeConverter()->convertType(subViewOp.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::any_of(
            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};
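
//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// `memref.collapse_shape` becomes a no-op after emulation: memrefs are
/// already flattened to a single dimension, so the collapse folds away.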
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace
void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>,
               ConvertMemRefCollapseShape, ConvertMemRefLoad,
               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
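
/// Computes the new 1-D shape: multiply all static dims together, then divide
/// (rounding up) by the number of narrow elements packed per wide element.
/// Any dynamic dimension makes the whole linearized shape dynamic.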
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) so all values fit.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}
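
/// Registers the MemRefType conversion: memrefs of sub-`loadStoreWidth`
/// integers become 1-D memrefs of `loadStoreWidth`-bit integers, with the
/// size and any static offset scaled accordingly.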
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Require the innermost stride to be 1 (contiguous rows).
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // A strided layout is only needed for a non-zero offset.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // The offset in bits must be a multiple of the load/store width.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}