#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

  return arith::ConstantOp::materialize(builder, value, type, loc);
  auto cast = operand.get().getDefiningOp<CastOp>();
  if (cast && operand.get() != inner &&
      !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
    operand.set(cast.getOperand());
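// Illustrative effect (not verbatim from a test): foldMemRefCast lets a
// consumer use the pre-cast operand directly, e.g.
//   %0 = memref.cast %arg : memref<4xf32> to memref<?xf32>
//   memref.store %v, %0[%i] : memref<?xf32>
// can fold so that the store addresses %arg instead of %0.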
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
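// Example mapping: memref<4x?xf32> becomes tensor<4x?xf32>, and memref<*xf32>
// becomes tensor<*xf32>; only shape and element type carry over, while layout
// and memory space are dropped.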
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  auto memrefType = llvm::cast<MemRefType>(value.getType());
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    if (ShapedType::isStatic(cstVal)) {
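    // Illustrative example: with values = [%a, %b] and constValues =
    // [4, ShapedType::kDynamic], the static entry %a is replaced by the
    // attribute 4 while %b stays an SSA value, yielding the mixed list [4, %b].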
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);

  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))

  return std::make_tuple(castOp, *tgtTy, *srcTy);
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>
  llvm::append_range(operands, op->getOperands());

  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(

  return std::optional<SmallVector<Value>>(
void AllocOp::getAsmResultNames(
  setNameFn(getResult(), "alloc");

void AllocaOp::getAsmResultNames(
  setNameFn(getResult(), "alloca");
template <typename AllocLikeOp>
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
    return op.emitOpError("result must be a memref");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "operand count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();
LogicalResult AllocaOp::verify() {
        "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
          return constSizeArg.isNonNegative();

    auto memrefType = alloc.getType();

    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
          constSizeArg.isNonNegative()) {
        newShapeConstants.push_back(constSizeArg.getZExtValue());
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);

    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
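    // Illustrative effect of this pattern (not verbatim from a test):
    //   %c8 = arith.constant 8 : index
    //   %0 = memref.alloc(%c8) : memref<?xf32>
    // becomes
    //   %0 = memref.alloc() : memref<8xf32>
    // followed by a memref.cast back to memref<?xf32> for existing users.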
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
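    // Illustrative effect: an allocation whose only uses are deallocations
    // (or stores into it) is dead, e.g.
    //   %0 = memref.alloc() : memref<4xf32>
    //   memref.dealloc %0 : memref<4xf32>
    // both ops are erased.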
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")

  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")

  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")

  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")

  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
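// Illustrative constraint from the verifier above (assumed syntax): a realloc
// to a dynamically sized result takes exactly one size operand, e.g.
//   %1 = memref.realloc %0(%sz) : memref<?xf32> to memref<?xf32>
// while a fully static result type takes none.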
  bool printBlockTerminators = false;

  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;

                printBlockTerminators);

  result.regions.reserve(1);

  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
void AllocaScopeOp::getSuccessorRegions(

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
          interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
          interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

    bool hasPotentialAlloca =
    if (hasPotentialAlloca) {

    if (!lastParentWithoutScope ||
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
    assert(containingRegion && "op must be contained in a region");

          return containingRegion->isAncestor(v.getParentRegion());
        toHoist.push_back(alloc);

    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");

void AssumeAlignmentOp::getAsmResultNames(
  setNameFn(getResult(), "assume_align");

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (source.getAlignment() != getAlignment())
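  // Illustrative fold (assumed from the checks above): chained
  // assume_alignment ops with the same alignment collapse, e.g.
  //   %0 = memref.assume_alignment %m, 16 : memref<8xf32>
  //   %1 = memref.assume_alignment %0, 16 : memref<8xf32>
  // lets %1 fold away in favor of the first op's result.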
FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {

FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
  return getMixedSize(builder, getLoc(), getMemref(), dim);
LogicalResult DistinctObjectsOp::verify() {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");

  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

LogicalResult DistinctObjectsOp::inferReturnTypes(
  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));

  setNameFn(getResult(), "cast");
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  if (!sourceType || !resultType)

  if (sourceType.getElementType() != resultType.getElementType())

  if (sourceType.getRank() != resultType.getRank())

  int64_t sourceOffset, resultOffset;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))

  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))

  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
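// Illustrative example of the rule above:
//   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
// can fold into a consumer of %1 (the consumer then uses %0 directly),
// because the source carries more static information; a cast in the
// opposite, static-asserting direction cannot be folded away.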
  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())

    if (aT.getLayout() != bT.getLayout()) {
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())

        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      if (!checkCompatible(aOffset, bOffset))
        if (aT.getDimSize(index) == 1)
        if (!checkCompatible(aStride, bStrides[index]))

    if (aT.getMemorySpace() != bT.getMemorySpace())

    if (aT.getRank() != bT.getRank())

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
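// Illustrative compatible casts implied by the checks above:
//   memref.cast %m : memref<4x?xf32> to memref<?x4xf32>   // ranked <-> ranked
//   memref.cast %m : memref<4xf32> to memref<*xf32>       // ranked <-> unranked
// Element type and memory space must always match.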
FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {

  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())

  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {

  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
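// Illustrative effect of the two patterns registered above:
//   memref.copy %a, %a : memref<8xf32> to memref<8xf32>   // self-copy
//   memref.copy %a, %b : memref<0xf32> to memref<0xf32>   // empty memref
// are both erased as no-ops.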
  for (OpOperand &operand : op->getOpOperands()) {
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
LogicalResult CopyOp::fold(FoldAdaptor adaptor,

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,

  setNameFn(getResult(), "dim");

  auto loc = result.location;
  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)

  if (rankedSourceType.getRank() <= constantIndex)

  setResultRange(getResult(),

  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
static FailureOr<llvm::SmallBitVector>
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())

  int64_t originalOffset, candidateOffset;
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))

  std::map<int64_t, unsigned> currUnaccountedStrides =
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      currUnaccountedStrides[originalStride]--;
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      unusedDims.reset(dim);
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());

  if (indexVal < 0 || indexVal >= memrefType.getRank())

  if (!memrefType.isDynamicDim(index.getInt())) {

  unsigned unsignedIndex = index.getValue().getZExtValue();

  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
      if (resultIndex == unsignedIndex) {
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
          dim, "Dim op is not defined by a reshape op.");

    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
              "dim.getIndex is not defined before reshape in the same block.");
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
          dim, "dim.getIndex does not dominate reshape.");

    Location loc = dim.getLoc();
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);

  results.add<DimOfMemRefReshape>(context);
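// Illustrative rewrite performed by DimOfMemRefReshape:
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
//   %d = memref.dim %r, %i : memref<*xf32>
// becomes a direct load of the requested extent:
//   %d = memref.load %shape[%i] : memref<?xindex>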
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
    result.addOperands({stride, elementsPerStride});

  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
        "expected two stride related operands");

  if (types.size() != 3)

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
    return emitOpError("expected source indices to be of index type");

  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
    return emitOpError("expected destination indices to be of index type");
    return emitOpError("expected num elements to be of index type");

  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
    return emitOpError("expected tag indices to be of index type");

  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  if (!getStride().getType().isIndex() ||
      !getNumElementsPerStride().getType().isIndex())
        "expected stride and num elements per stride to be of type index");
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,

LogicalResult DmaWaitOp::verify() {
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
  setNameFn(getResult(), "intptr");
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  inferredReturnTypes.push_back(memrefType);
  inferredReturnTypes.push_back(indexType);
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);

void ExtractStridedMetadataOp::getAsmResultNames(
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
template <typename Container>
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
        llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    atLeastOneReplacement = true;
  return atLeastOneReplacement;

ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                                      getConstifiedMixedOffset());
                                     getConstifiedMixedSizes());
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;

  return success(atLeastOneReplacement);
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);
LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

          "body of 'memref.generic_atomic_rmw' should contain "
          "only operations with no side effects");

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,

  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  if (!op.isExternal()) {
    if (op.isUninitialized())
      p << "uninitialized";

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

      initialValue = UnitAttr::get(parser.getContext());

    if (!llvm::isa<ElementsAttr>(initialValue))
             << "initial value should be a unit or elements attribute";

LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")

  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")

    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());

         << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
        << resultType << " does not match type " << global.getType()
        << " of the global memref @" << getName();
LogicalResult LoadOp::verify() {
    return emitOpError("incorrect number of indices for load, expected ")

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {

void MemorySpaceCastOp::getAsmResultNames(
  setNameFn(getResult(), "memspacecast");

  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout())
    if (aT.getShape() != bT.getShape())

    return uaT.getElementType() == ubT.getElementType();
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
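  // Illustrative fold: chained casts collapse so only the endpoints remain,
  // e.g.
  //   %1 = memref.memory_space_cast %0 : memref<8xf32, 1> to memref<8xf32>
  //   %2 = memref.memory_space_cast %1 : memref<8xf32> to memref<8xf32, 2>
  // becomes a single cast from memref<8xf32, 1> to memref<8xf32, 2>.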
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                               PtrLikeTypeInterface src) {
  return isa<BaseMemRefType>(tgt) &&
         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;

MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);

bool MemorySpaceCastOp::isSourcePromotable() {
  return getDest().getType().getMemorySpace() == nullptr;

  p << " " << getMemref() << '[';
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
      (*this)->getAttrs(),
      {"localityHint", "isWrite", "isDataCache"});
  IntegerAttr localityHint;
  StringRef readOrWrite, cacheType;

  if (readOrWrite != "read" && readOrWrite != "write")
        "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),

  if (cacheType != "data" && cacheType != "instr")
        "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),

LogicalResult PrefetchOp::verify() {

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
  setNameFn(getResult(), "reinterpret_cast");

                              MemRefType resultType, Value source,
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

  auto sourceType = cast<BaseMemRefType>(source.getType());
  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);

                              MemRefType resultType, Value source,
        return b.getI64IntegerAttr(v);
        return b.getI64IntegerAttr(v);
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);

                              MemRefType resultType, Value source, Value offset,
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
LogicalResult ReinterpretCastOp::verify() {
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;

  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")

  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
      return prev.getSource();
      return prev.getSource();
      return prev.getSource();

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);

  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
struct ReinterpretCastOpExtractStridedMetadataFolder
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)

    auto isReinterpretCastNoop = [&]() -> bool {
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();

    if (!isReinterpretCastNoop()) {
      op.getSourceMutable().assign(extractStridedMetadata.getSource());

    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
                                       extractStridedMetadata.getSource());

struct ReinterpretCastOpConstantFolder
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    unsigned srcStaticCount = llvm::count_if(
        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                   op.getMixedStrides()),
        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });

    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();

    if (srcStaticCount ==
        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))

    auto newReinterpretCast = ReinterpretCastOp::create(
        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);

  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
              ReinterpretCastOpConstantFolder>(context);

FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapse_shape");

void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expand_shape");

LogicalResult ExpandShapeOp::reifyResultShapes(
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
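// Illustrative expand/collapse pair (assumed shapes):
//   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [2, 4, 8]
//       : memref<8x8xf32> into memref<2x4x8xf32>
//   %c = memref.collapse_shape %e [[0, 1], [2]]
//       : memref<2x4x8xf32> into memref<8x8xf32>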
                   bool allowMultipleDynamicDimsPerGroup) {
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  for (const auto &it : llvm::enumerate(reassociation)) {
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
               << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;

    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
             << ") must be dynamic if and only if reassociation group is "

    if (!foundDynamic) {
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";

  if (collapsedShape.empty()) {
    for (int64_t d : expandedShape)
          "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
      getReassociationIndices());

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
      getReassociationIndices());

static FailureOr<StridedLayoutAttr>
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =

  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
  return *outputShape;

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        dynamicOutputShape, staticOutputShape);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";

                                resultType.getShape(),
                                getReassociationIndices(),

  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()

  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    auto cast = op.getSrc().getDefiningOp<CastOp>();
    if (!CastOp::canFoldIntoConsumerOp(cast))

    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
      if (!sizeOpt.has_value()) {
        newOutputShapeSizes.push_back(ShapedType::kDynamic);
      newOutputShapeSizes.push_back(sizeOpt.value());
      newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());

    Value castSource = cast.getSource();
    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
        op.getReassociationIndices();
    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
      auto newOutputShapeSizesSlice =
          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
      bool newOutputDynamic =
          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
            op, "folding cast will result in changing dynamicity in "
                "reassociation group");

    FailureOr<MemRefType> newResultTypeOrFailure =
        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
                                           reassociationIndices);

    if (failed(newResultTypeOrFailure))
          op, "could not compute new expanded type after folding cast");

    if (*newResultTypeOrFailure == op.getResultType()) {
          op, [&]() { op.getSrcMutable().assign(castSource); });
      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
                                          *newResultTypeOrFailure, castSource,
                                          reassociationIndices, newOutputShape);

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
      ExpandShapeOpMemRefCastFolder>(context);
FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {

static FailureOr<StridedLayoutAttr>
                          bool strict = false) {
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))

  resultStrides.reserve(reassociation.size());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
      resultStrides.push_back(ShapedType::kDynamic);

  unsigned resultStrideIndex = resultStrides.size() - 1;
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      if (srcShape[idx - 1] == 1)
      if (strict && (stride.saturated || srcStride.saturated))
      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity())

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
    for (int64_t srcDim : group)
    resultShape.push_back(groupSize.asInteger());

  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  build(b, result, resultType, src, attrs);
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";

                                srcType.getShape(), getReassociationIndices(),

  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());

    FailureOr<StridedLayoutAttr> computedLayout =
    if (failed(computedLayout))
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!CastOp::canFoldIntoConsumerOp(cast))

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
          CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
                                  op.getReassociationIndices());

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                memref::DimOp, MemRefType>,
      CollapseShapeOpMemRefCastFolder>(context);

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
                                                       adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());

FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {

void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");
LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
          "length of shape operand differs from the result's memref rank");

FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
LogicalResult StoreOp::verify() {
    return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {

void SubViewOp::getAsmResultNames(
  setNameFn(getResult(), "subview");
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);

  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);

  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
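// Illustrative inference (assumed values): a subview of memref<8x8xf32> with
// offsets [2, 0], sizes [4, 4], strides [1, 2] yields
//   memref<4x4xf32, strided<[8, 2], offset: 16>>
// since targetOffset = 0 + 2*8 + 0*1 = 16 and targetStrides = [1*8, 2*1].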
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
  assert(dimsToProject.has_value() && "invalid rank reduction");

  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3014void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3015 MemRefType resultType, Value source,
3016 ArrayRef<OpFoldResult> offsets,
3017 ArrayRef<OpFoldResult> sizes,
3018 ArrayRef<OpFoldResult> strides,
3019 ArrayRef<NamedAttribute> attrs) {
3020 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3021 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3025 auto sourceMemRefType = llvm::cast<MemRefType>(source.
getType());
3028 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3029 staticSizes, staticStrides);
3031 result.addAttributes(attrs);
3032 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
3033 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
3034 b.getDenseI64ArrayAttr(staticSizes),
3035 b.getDenseI64ArrayAttr(staticStrides));
3040void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3041 ArrayRef<OpFoldResult> offsets,
3042 ArrayRef<OpFoldResult> sizes,
3043 ArrayRef<OpFoldResult> strides,
3044 ArrayRef<NamedAttribute> attrs) {
3045 build(
b,
result, MemRefType(), source, offsets, sizes, strides, attrs);
3049void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3050 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3051 ArrayRef<int64_t> strides,
3052 ArrayRef<NamedAttribute> attrs) {
3053 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3054 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
3055 return b.getI64IntegerAttr(v);
3057 SmallVector<OpFoldResult> sizeValues =
3058 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3059 return b.getI64IntegerAttr(v);
3061 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3062 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3063 return b.getI64IntegerAttr(v);
3065 build(
b,
result, source, offsetValues, sizeValues, strideValues, attrs);
3070void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3071 MemRefType resultType, Value source,
3072 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3073 ArrayRef<int64_t> strides,
3074 ArrayRef<NamedAttribute> attrs) {
3075 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3076 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
3077 return b.getI64IntegerAttr(v);
3079 SmallVector<OpFoldResult> sizeValues =
3080 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3081 return b.getI64IntegerAttr(v);
3083 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3084 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3085 return b.getI64IntegerAttr(v);
3087 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues,
3093void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3094 MemRefType resultType, Value source,
ValueRange offsets,
3096 ArrayRef<NamedAttribute> attrs) {
3097 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3098 llvm::map_range(offsets, [](Value v) -> OpFoldResult {
return v; }));
3099 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3100 llvm::map_range(sizes, [](Value v) -> OpFoldResult {
return v; }));
3101 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3102 llvm::map_range(strides, [](Value v) -> OpFoldResult {
return v; }));
3103 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues);
// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
Value SubViewOp::getViewSource() { return getSource(); }
/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}
/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
/// marked as dropped in `droppedDims`.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
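// Illustrative example (not from the original file): for
// t1 = memref<4x1x8xf32, strided<[8, 8, 1]>> rank-reduced to
// t2 = memref<4x8xf32, strided<[8, 1]>> with droppedDims marking dimension 1,
// that dimension of t1 is skipped and the remaining strides [8, 1] compare
// equal pairwise, so the two types have compatible strides.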
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes), but got "
           << op.getType();
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError(
               "expected result and source memory spaces to match, but got ")
           << op.getType();
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout), but "
              "got "
           << op.getType();
  }
  llvm_unreachable("unexpected subview verification result");
}
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  ArrayRef<int64_t> staticStrides = getStaticStrides();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify result type against inferred type.
  SliceVerificationResult result = isRankReducedType(expectedType, subViewType);
  if (result != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(result, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // The only thing that is left to verify now are the strides. First, compute
  // the unused dimensions.
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(expectedType, subViewType,
                                     getMixedSizes());

  // Sizes and offsets are guaranteed to match by the isRankReducedType check
  // above, so the computation of the unused dims should not fail.
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Strides must match.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
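// Illustrative example (not from the original file): the subview below
// verifies because the inferred type matches the annotated result type; the
// static offset is 1 * 8 + 1 * 1 = 9 and the strides carry over from the
// source:
//
//   %0 = memref.subview %src[1, 1] [2, 2] [1, 1]
//       : memref<8x8xf32> to memref<2x2xf32, strided<[8, 1], offset: 9>>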
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
/// Return the list of Range (i.e. offset, size and stride). Each Range entry
/// contains either the dynamic value or a ConstantIndexOp constructed with
/// `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}
FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}
/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are one, and the source
/// shape is same as the size of the subview. In such cases, the subview can
/// be folded into its source.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}
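// Illustrative example (not from the original file): the subview below is
// trivial (zero offsets, unit strides, and sizes equal to the full source
// shape), so the canonicalizer can replace it with its source:
//
//   %0 = memref.subview %src[0, 0] [8, 8] [1, 1]
//       : memref<8x8xf32> to memref<8x8xf32>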
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments. This essentially
/// pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = SubViewOp::create(
        rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
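// Illustrative example (not from the original file) of what the pattern above
// does: it pushes a memref.cast past its consuming subview,
//
//   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
//   %1 = memref.subview %0[0, 0] [3, 4] [1, 1]
//       : memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
//
// becoming a subview of the statically shaped source followed by a cast back
// to the original result type:
//
//   %0 = memref.subview %V[0, 0] [3, 4] [1, 1]
//       : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1]>>
//   %1 = memref.cast %0 : memref<3x4xf32, strided<[16, 1]>>
//       to memref<3x4xf32, strided<[?, 1], offset: ?>>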
/// Canonicalize subview ops that are no-ops.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace
/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = resTy;

    // Directly return the non-rank-reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank-reduced type.
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)), where both subviews have the same size and the
  // second subview's offsets are all zero. (I.e., the second subview is a
  // no-op.)
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
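// Illustrative example (not from the original file): in the IR below, the
// second subview has zero offsets, unit strides, and the same sizes as its
// source subview, so fold() replaces %1 with %0:
//
//   %0 = memref.subview %src[2] [4] [1]
//       : memref<8xf32> to memref<4xf32, strided<[1], offset: 2>>
//   %1 = memref.subview %0[0] [4] [1]
//       : memref<4xf32, strided<[1], offset: 2>>
//         to memref<4xf32, strided<[1], offset: 2>>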
FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
void SubViewOp::inferStridedMetadataRanges(
    ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
    SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
  auto isUninitialized =
      +[](IntegerValueRange range) { return range.isUninitialized(); };

  // Bail early if any of the operand ranges are not ready.
  SmallVector<IntegerValueRange> offsetOperands =
      getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
  if (llvm::any_of(offsetOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> sizeOperands =
      getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
  if (llvm::any_of(sizeOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> stridesOperands =
      getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
  if (llvm::any_of(stridesOperands, isUninitialized))
    return;

  StridedMetadataRange sourceRange =
      ranges[getSourceMutable().getOperandNumber()];
  if (sourceRange.isUninitialized())
    return;

  ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();

  // Get the dropped dims.
  llvm::SmallBitVector droppedDims = getDroppedDims();

  // Compute the new offset, strides and sizes.
  ConstantIntRanges offset = sourceRange.getOffsets()[0];
  SmallVector<ConstantIntRanges> strides, sizes;

  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
    bool dropped = droppedDims.test(i);
    // Fold the i-th offset contribution into the total offset.
    ConstantIntRanges off =
        intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
    offset = intrange::inferAdd({offset, off});

    // Skip dropped dimensions.
    if (dropped)
      continue;
    // Multiply the strides.
    strides.push_back(
        intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
    // Get the sizes.
    sizes.push_back(sizeOperands[i].getValue());
  }

  setMetadata(getResult(),
              StridedMetadataRange::getRanked(
                  SmallVector<ConstantIntRanges>({std::move(offset)}),
                  std::move(sizes), std::move(strides)));
}
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() ==
         static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
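// Illustrative example (not from the original file): applying the permutation
// map (d0, d1) -> (d1, d0) to memref<3x4xf32>, whose identity strides are
// [4, 1], permutes both sizes and strides and yields
// memref<4x3xf32, strided<[1, 4]>>.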
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}
// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();
  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // First check for identity permutation; we can fold it away if the input
  // and result types are identical already.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose ops into one by composing their
  // permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
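// Illustrative example (not from the original file): two transposes with map
// (d0, d1) -> (d1, d0) compose to the identity map, so the second
// memref.transpose is rewritten in place to read the original input and, once
// its type equals that input's type, folds away via the identity case above.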
FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}
LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}
Value ViewOp::getViewSource() { return getSource(); }
OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
    return getViewSource();

  return {};
}
namespace {
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from the old memref view type.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
                                    viewOp.getOperand(0),
                                    viewOp.getByteShift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};
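// Illustrative example (not from the original file): if %sz is defined by
// arith.constant 16, the pattern above rewrites
//
//   %0 = memref.view %buf[%off][%sz] : memref<256xi8> to memref<?x4xf32>
//
// into a statically shaped view plus a cast back to the original type:
//
//   %1 = memref.view %buf[%off][] : memref<256xi8> to memref<16x4xf32>
//   %2 = memref.cast %1 : memref<16x4xf32> to memref<?x4xf32>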
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();

    rewriter.replaceOpWithNewOp<ViewOp>(
        viewOp, viewOp.getType(), memrefCastOp.getSource(),
        viewOp.getByteShift(), viewOp.getSizes());
    return success();
  }
};
} // namespace
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
LogicalResult AtomicRMWOp::verify() {
  if (getMemRefType().getRank() != getNumOperands() - 2)
    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maximumf:
  case arith::AtomicRMWKind::minimumf:
  case arith::AtomicRMWKind::mulf:
    if (!llvm::isa<FloatType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::xori:
  case arith::AtomicRMWKind::andi:
    if (!llvm::isa<IntegerType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}
OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}
FailureOr<std::optional<SmallVector<Value>>>
AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(), getResult());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.