#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

  return arith::ConstantOp::materialize(builder, value, type, loc);

    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());

  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());

  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  auto memrefType = llvm::cast<MemRefType>(value.getType());
  for (int64_t i = 0; i < memrefType.getRank(); ++i)

  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    if (ShapedType::isStatic(cstVal)) {

static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);
  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
  return std::make_tuple(castOp, *tgtTy, *srcTy);

template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>
  llvm::append_range(operands, op->getOperands());
  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
  return std::optional<SmallVector<Value>>(
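// Illustrative sketch (not part of the upstream file): the passthrough
// rewrite above takes an op fed by a promotable `memref.memory_space_cast`,
// recreates the op on the cast's source, and rebuilds the cast on the new
// result. Using a hypothetical consumer `some.op` purely for illustration:
//
//   %c = memref.memory_space_cast %src : memref<4xf32, 1> to memref<4xf32>
//   %r = "some.op"(%c) ...
//
// is rewritten so "some.op" consumes %src directly and the memory-space cast
// is applied to the new result instead.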
void AllocOp::getAsmResultNames(
  setNameFn(getResult(), "alloc");

void AllocaOp::getAsmResultNames(
  setNameFn(getResult(), "alloca");
template <typename AllocLikeOp>
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
    return op.emitOpError("result must be a memref");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();
LogicalResult AllocaOp::verify() {
        "requires an ancestor op with AutomaticAllocationScope trait");

template <typename AllocLikeOp>
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
          return constSizeArg.isNonNegative();

    auto memrefType = alloc.getType();
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
          constSizeArg.isNonNegative()) {
        newShapeConstants.push_back(constSizeArg.getZExtValue());
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);

    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
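// For illustration (hedged sketch, not upstream text): this pattern turns an
// alloc-like op whose dynamic sizes are non-negative constants into a
// statically shaped allocation, e.g.
//
//   %c4 = arith.constant 4 : index
//   %0 = memref.alloc(%c4) : memref<?xf32>
//
// can be rewritten to
//
//   %0 = memref.alloc() : memref<4xf32>
//
// with a `memref.cast` back to `memref<?xf32>` for the remaining uses.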
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
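// Reading of the pattern above (comment added for clarity): an allocation is
// dead when every user is either a `memref.dealloc` or a `memref.store` that
// only writes into it (the allocated value itself is never stored as data),
// so the users and then the alloc itself can be erased.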
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")

  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")

  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")

  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
  bool printBlockTerminators = false;
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
                printBlockTerminators);

  result.regions.reserve(1);
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),

void AllocaScopeOp::getSuccessorRegions(

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
          interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
          interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))

    bool hasPotentialAlloca =
    if (hasPotentialAlloca) {

    if (!lastParentWithoutScope ||
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
    assert(containingRegion && "op must be contained in a region");

          return containingRegion->isAncestor(v.getParentRegion());
        toHoist.push_back(alloc);

    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");

void AssumeAlignmentOp::getAsmResultNames(
  setNameFn(getResult(), "assume_align");

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (source.getAlignment() != getAlignment())

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {

FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
  return getMixedSize(builder, getLoc(), getMemref(), dim);

LogicalResult DistinctObjectsOp::verify() {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");
  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

LogicalResult DistinctObjectsOp::inferReturnTypes(
  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));

  setNameFn(getResult(), "cast");

bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
  if (!sourceType || !resultType)
  if (sourceType.getElementType() != resultType.getElementType())
  if (sourceType.getRank() != resultType.getRank())

  int64_t sourceOffset, resultOffset;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))

  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))

  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
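// Hedged example of the rule above: a cast that only erases static
// information can be folded into its consumer, since the consumer may use the
// more-static source directly, e.g.
//
//   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
//   "consumer"(%0)   // can become "consumer"(%V)
//
// whereas a cast that *adds* static information (dynamic -> static size,
// offset, or stride) must be kept.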
  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout()) {
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())

        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      if (!checkCompatible(aOffset, bOffset))
        if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
        if (!checkCompatible(aStride, bStrides[index]))

    if (aT.getMemorySpace() != bT.getMemorySpace())
    if (aT.getRank() != bT.getRank())

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;

FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {

  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())

  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {

  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
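// Net effect of the two patterns registered above (explanatory comment): a
// `memref.copy` is erased when it copies a value onto itself, or when either
// side is a zero-element memref such as memref<0xf32>, since no data moves in
// either case.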
  for (OpOperand &operand : op->getOpOperands()) {
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());

LogicalResult CopyOp::fold(FoldAdaptor adaptor,

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,

  setNameFn(getResult(), "dim");

  auto loc = result.location;
  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
  if (rankedSourceType.getRank() <= constantIndex)

  setResultRange(getResult(),

  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;

static FailureOr<llvm::SmallBitVector>
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())

  int64_t originalOffset, candidateOffset;
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))

  std::map<int64_t, unsigned> currUnaccountedStrides =
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      currUnaccountedStrides[originalStride]--;
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      unusedDims.reset(dim);
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (indexVal < 0 || indexVal >= memrefType.getRank())

  if (!memrefType.isDynamicDim(index.getInt())) {

  unsigned unsignedIndex = index.getValue().getZExtValue();

  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
      if (resultIndex == unsignedIndex) {
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);

  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
          dim, "Dim op is not defined by a reshape op.");

    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
              "dim.getIndex is not defined before reshape in the same block.");
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
          dim, "dim.getIndex does not dominate reshape.");

    Location loc = dim.getLoc();
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);

  results.add<DimOfMemRefReshape>(context);
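// Hedged illustration of DimOfMemRefReshape: asking for a dimension of a
// reshape result can be answered by reading the shape operand, e.g.
//
//   %r = memref.reshape %src(%shape)
//     : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i : memref<?x?xf32>
//
// becomes a `memref.load %shape[%i]` (index-cast if the types differ).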
                        Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
    result.addOperands({stride, elementsPerStride});

  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
    p << ", " << getStride() << ", " << getNumElementsPerStride();
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
        "expected two stride related operands");

  if (types.size() != 3)

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
    return emitOpError("expected source indices to be of index type");

  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
    return emitOpError("expected destination indices to be of index type");

    return emitOpError("expected num elements to be of index type");

  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
    return emitOpError("expected tag indices to be of index type");

  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  if (!getStride().getType().isIndex() ||
      !getNumElementsPerStride().getType().isIndex())
        "expected stride and num elements per stride to be of type index");

LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,

LogicalResult DmaWaitOp::verify() {
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
  setNameFn(getResult(), "intptr");

LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());

  inferredReturnTypes.push_back(memrefType);

  inferredReturnTypes.push_back(indexType);

  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);

void ExtractStridedMetadataOp::getAsmResultNames(
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");

  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
template <typename Container>
  assert(values.size() == maybeConstants.size() &&
         "expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
        llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    atLeastOneReplacement = true;
  return atLeastOneReplacement;
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                                              getConstifiedMixedOffset());
                                              getConstifiedMixedSizes());
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;

  return success(atLeastOneReplacement);

ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();

  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

          "body of 'memref.generic_atomic_rmw' should contain "
          "only operations with no side effects");

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,

  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  if (!op.isExternal()) {
    if (op.isUninitialized())
      p << "uninitialized";

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

      initialValue = UnitAttr::get(parser.getContext());

    if (!llvm::isa<ElementsAttr>(initialValue))
             << "initial value should be a unit or elements attribute";

LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")

  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")

    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
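// Hedged examples of globals these verifiers accept:
//
//   memref.global "private" @workspace : memref<4xf32> = uninitialized
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//
// The initializer, when present, must be a unit attribute or an elements
// attribute whose element type and shape match the memref type exactly.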
         << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();

LogicalResult LoadOp::verify() {
    return emitOpError("incorrect number of indices for load, expected ")

  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
      getGlobalOp, getGlobalOp.getNameAttr());
      dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
  return splatAttr.getSplatValue<Attribute>();

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {

void MemorySpaceCastOp::getAsmResultNames(
  setNameFn(getResult(), "memspacecast");

  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout())
    if (aT.getShape() != bT.getShape())

    return uaT.getElementType() == ubT.getElementType();

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
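// The fold above collapses chains of casts (hedged sketch):
//
//   %1 = memref.memory_space_cast %0 : memref<4xf32, 1> to memref<4xf32>
//   %2 = memref.memory_space_cast %1 : memref<4xf32> to memref<4xf32, 2>
//
// is folded so %2 casts directly from %0, bypassing the intermediate cast.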
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                               PtrLikeTypeInterface src) {
  return isa<BaseMemRefType>(tgt) &&
         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;

MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);

bool MemorySpaceCastOp::isSourcePromotable() {
  return getDest().getType().getMemorySpace() == nullptr;

  p << " " << getMemref() << '[';
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
                          (*this)->getAttrs(),
                          {"localityHint", "isWrite", "isDataCache"});

  IntegerAttr localityHint;
  StringRef readOrWrite, cacheType;

  if (readOrWrite != "read" && readOrWrite != "write")
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),

  if (cacheType != "data" && cacheType != "instr")
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
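// The printer/parser pair above round-trips the custom form, e.g.:
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xf32>
//
// where the rw specifier is 'read' or 'write', locality is a small integer
// hint, and the cache kind is 'data' or 'instr'.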
LogicalResult PrefetchOp::verify() {

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();

void ReinterpretCastOp::getAsmResultNames(
  setNameFn(getResult(), "reinterpret_cast");

                              MemRefType resultType, Value source,
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

  auto sourceType = cast<BaseMemRefType>(source.getType());
  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);

                              MemRefType resultType, Value source,
        return b.getI64IntegerAttr(v);
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);

                              MemRefType resultType, Value source, Value offset,
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);

LogicalResult ReinterpretCastOp::verify() {
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;

  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;

  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")

  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
      return prev.getSource();

      return prev.getSource();

        return prev.getSource();

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);

  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);

struct ReinterpretCastOpExtractStridedMetadataFolder
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)

    auto isReinterpretCastNoop = [&]() -> bool {
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))

      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))

      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();

    if (!isReinterpretCastNoop()) {
      op.getSourceMutable().assign(extractStridedMetadata.getSource());

    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
                                        extractStridedMetadata.getSource());

struct ReinterpretCastOpConstantFolder
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    unsigned srcStaticCount = llvm::count_if(
        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                   op.getMixedStrides()),
        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });

    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();

    if (srcStaticCount ==
        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))

    auto newReinterpretCast = ReinterpretCastOp::create(
        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);

  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
              ReinterpretCastOpConstantFolder>(context);

FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {

void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapse_shape");

void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expand_shape");

LogicalResult ExpandShapeOp::reifyResultShapes(
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};

                  bool allowMultipleDynamicDimsPerGroup) {
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  for (const auto &it : llvm::enumerate(reassociation)) {
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
               << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;

    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
             << ") must be dynamic if and only if reassociation group is "

    if (!foundDynamic) {
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";

  if (collapsedShape.empty()) {
    for (int64_t d : expandedShape)
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
      getReassociationIndices());

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
      getReassociationIndices());

static FailureOr<StridedLayoutAttr>
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =

  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
  return *outputShape;

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        dynamicOutputShape, staticOutputShape);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);

LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";

                                  resultType.getShape(),
                                  getReassociationIndices(),

  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))

  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()

  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    auto cast = op.getSrc().getDefiningOp<CastOp>();
    if (!CastOp::canFoldIntoConsumerOp(cast))

    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
      if (!sizeOpt.has_value()) {
        newOutputShapeSizes.push_back(ShapedType::kDynamic);
      newOutputShapeSizes.push_back(sizeOpt.value());
      newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());

    Value castSource = cast.getSource();
    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
        op.getReassociationIndices();
    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
      auto newOutputShapeSizesSlice =
          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
      bool newOutputDynamic =
          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
            op, "folding cast will result in changing dynamicity in "
                "reassociation group");

    FailureOr<MemRefType> newResultTypeOrFailure =
        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
                                           reassociationIndices);
    if (failed(newResultTypeOrFailure))
          op, "could not compute new expanded type after folding cast");

    if (*newResultTypeOrFailure == op.getResultType()) {
          op, [&]() { op.getSrcMutable().assign(castSource); });
      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
                                          *newResultTypeOrFailure, castSource,
                                          reassociationIndices, newOutputShape);

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
      ExpandShapeOpMemRefCastFolder>(context);

FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {

static FailureOr<StridedLayoutAttr>
                          bool strict = false) {
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))

  resultStrides.reserve(reassociation.size());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
      resultStrides.push_back(ShapedType::kDynamic);

  unsigned resultStrideIndex = resultStrides.size() - 1;
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      if (srcShape[idx - 1] == 1)
      if (strict && (stride.saturated || srcStride.saturated))
      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);

bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity())

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
    for (int64_t srcDim : group)
    resultShape.push_back(groupSize.asInteger());

  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  build(b, result, resultType, src, attrs);

LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";

                                  srcType.getShape(), getReassociationIndices(),

  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());

    FailureOr<StridedLayoutAttr> computedLayout =
    if (failed(computedLayout))
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!CastOp::canFoldIntoConsumerOp(cast))

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
          CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
                                  op.getReassociationIndices());

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                memref::DimOp, MemRefType>,
      CollapseShapeOpMemRefCastFolder>(context);

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
                                                       adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                       adaptor.getOperands());

FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {

void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
          "length of shape operand differs from the result's memref rank");

FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {

LogicalResult StoreOp::verify() {
    return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {

void SubViewOp::getAsmResultNames(
  setNameFn(getResult(), "subview");

MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);

  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);

  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
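// Arithmetic implemented above (stated for clarity; dynamic values saturate):
//
//   targetOffset   = sourceOffset + sum_i(staticOffset_i * sourceStride_i)
//   targetStride_i = sourceStride_i * staticStride_i
//
// so e.g. a subview at offsets [2, 3] of a memref with strides [8, 1] and
// offset 0 gets offset 2*8 + 3*1 = 19.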
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
  assert(dimsToProject.has_value() && "invalid rank reduction");

  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,

void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues =
      llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
  SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
      sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
  SmallVector<OpFoldResult> strideValues =
      llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues =
      llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
  SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
      sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
  SmallVector<OpFoldResult> strideValues =
      llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,

void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
      offsets, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> sizeValues =
      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
      strides, [](Value v) -> OpFoldResult { return v; });
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

Value SubViewOp::getViewSource() { return getSource(); }
/// Return true if `t1` and `t2` have equal offsets (both dynamic or of the
/// same static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}
/// Return true if `t1` and `t2` have equal strides (both dynamic or of the
/// same static value). Dimensions of `t1` may be dropped in `t2`; these must
/// be marked as dropped in `droppedDims`.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
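// Worked example (illustrative): dropping dimension 1 of
// memref<4x1x8xf32, strided<[8, 8, 1]>> yields
// memref<4x8xf32, strided<[8, 1]>>. With droppedDims = 010, position i = 1
// is skipped and t1Strides[0]/t1Strides[2] are compared against
// t2Strides[0]/t2Strides[1].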
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes), but got "
           << op.getType();
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError(
               "expected result and source memory spaces to match, but got ")
           << op.getType();
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout), but "
              "got "
           << op.getType();
  }
  llvm_unreachable("unexpected subview verification result");
}
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  ArrayRef<int64_t> staticStrides = getStaticStrides();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify the result type against the inferred type, allowing for rank
  // reductions.
  SliceVerificationResult shapeResult =
      isRankReducedType(expectedType, subViewType);
  if (shapeResult != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapeResult, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Compute the rank-reduction mask; failure means the sizes cannot be
  // matched against the inferred type.
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(expectedType, subViewType,
                                     getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // The strides of the remaining dimensions must be compatible.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out-of-bounds with
  // respect to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
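// IR accepted by this verifier looks like (illustrative):
//   %0 = memref.subview %arg[2, 0] [2, 8] [1, 1]
//       : memref<4x8xf32> to memref<2x8xf32, strided<[8, 1], offset: 16>>
// The inferred type keeps strides [8, 1] and gets offset 2 * 8 = 16, and the
// bounds check confirms 2 + 2 <= 4 and 0 + 8 <= 8.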
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
/// Return the list of Range (i.e. offset, size and stride). Each Range entry
/// contains either the dynamic value or a ConstantIndexOp constructed with
/// `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
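// Usage sketch (illustrative), assuming `subViewOp` implements
// OffsetSizeAndStrideOpInterface:
//   SmallVector<Range, 8> ranges = getOrCreateRanges(subViewOp, b, loc);
// Every static offset/size/stride is materialized as an arith.constant of
// index type, so each Range member is usable as a Value afterwards.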
/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
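// For instance (illustrative), with sourceType memref<?x?xf32>, sizes
// [1, %sz], and a rank-1 current result type, the non-rank-reduced inference
// yields memref<1x?xf32, strided<[?, 1], offset: ?>>; dimension 0 is marked
// unused, so the returned type is memref<?xf32, strided<[1], offset: ?>>.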
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}
FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}
/// Helper method to check if a `subview` operation is trivially a no-op: all
/// offsets are zero, all strides are one, and the sizes match the source
/// shape.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check that all offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check that all strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check that all sizes are static and match the source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met: the subview is a no-op.
  return true;
}
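// A trivially foldable subview looks like (illustrative):
//   %0 = memref.subview %src[0, 0] [4, 8] [1, 1]
//       : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>
// All offsets are 0, all strides are 1, and the sizes match the source
// shape, so %0 can be replaced by %src (through a cast if the types differ).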
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, return early and let the
    // constant-argument folder kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = SubViewOp::create(
        rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
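// Sketch of the rewrite this pattern performs (illustrative):
//   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
//   %1 = memref.subview %0[0, 0] [3, 4] [1, 1]
//       : memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
// becomes a subview of %V directly, followed by a memref.cast from the more
// static subview type back to the original result type of %1.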
/// Canonicalize subview ops that are no-ops.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace
/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = resTy;

    // Directly return the non-rank-reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank-reduced type.
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)) to subview(x) if the second subview selects the
  // whole of the first: all offsets zero, all strides one, equal sizes.
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
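// The second fold handles chains such as (illustrative):
//   %0 = memref.subview %src[%i, %j] [4, 4] [1, 1] : ... to memref<4x4xf32, ...>
//   %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
//       : memref<4x4xf32, ...> to memref<4x4xf32, ...>
// Here %1 selects all of %0 with zero offsets and unit strides, so it folds
// to %0.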
FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
void SubViewOp::inferStridedMetadataRanges(
    ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
    SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
  auto isUninitialized =
      +[](IntegerValueRange range) { return range.isUninitialized(); };

  // Bail early if any of the operand ranges are not yet initialized.
  SmallVector<IntegerValueRange> offsetOperands =
      getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
  if (llvm::any_of(offsetOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> sizeOperands =
      getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
  if (llvm::any_of(sizeOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> stridesOperands =
      getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
  if (llvm::any_of(stridesOperands, isUninitialized))
    return;

  StridedMetadataRange sourceRange =
      ranges[getSourceMutable().getOperandNumber()];
  if (!sourceRange.isRanked())
    return;

  ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();

  // Compute the new offset, sizes, and strides, skipping dropped dimensions.
  llvm::SmallBitVector droppedDims = getDroppedDims();
  ConstantIntRanges offset = sourceRange.getOffsets()[0];
  SmallVector<ConstantIntRanges> strides, sizes;

  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
    bool dropped = droppedDims.test(i);
    // Accumulate this dimension's contribution to the offset.
    ConstantIntRanges off =
        intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
    offset = intrange::inferAdd({offset, off});
    if (dropped)
      continue;
    strides.push_back(
        intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
    sizes.push_back(sizeOperands[i].getValue());
  }

  setMetadata(getResult(),
              StridedMetadataRange::getRanked(
                  SmallVector<ConstantIntRanges>({std::move(offset)}),
                  std::move(sizes), std::move(strides)));
}
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() ==
         static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
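// For example (illustrative), permuting memref<4x8xf32> (strides [8, 1])
// with (d0, d1) -> (d1, d0) yields memref<8x4xf32, strided<[1, 8]>>: sizes
// and strides are permuted together while the offset is unchanged.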
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute the result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // An identity permutation folds away when the input and result types are
  // already identical.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose ops into one by composing their
  // permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
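// Composition example (illustrative):
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0) : ...
//   %2 = memref.transpose %1 (d0, d1) -> (d1, d0) : ...
// The composed permutation is the identity, so %2 is rewritten to transpose
// %0 directly and can then fold away entirely when the types match.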
void ViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}
LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have an identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have an identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}
Value ViewOp::getViewSource() { return getSource(); }
OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();

  // view(x) -> x when the types match, the shape is fully static, and the
  // byte shift is zero.
  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
    return getViewSource();

  return {};
}
SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  Builder b(getContext());
  for (int64_t dim : getType().getShape()) {
    if (ShapedType::isDynamic(dim)) {
      result.push_back(getSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(dim));
    }
  }
  return result;
}
/// Given a memref type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static MemRefType foldDynamicToStaticDimSizes(
    MemRefType type, ValueRange dynamicSizes,
    SmallVectorImpl<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
    if (ShapedType::isStatic(dimSize))
      continue;
    Value dynamicSize = dynamicSizes[ctr++];
    std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
    if (cst.has_value()) {
      // Dynamic sizes must be non-negative to be promoted to static.
      if (cst.value() < 0) {
        foldedDynamicSizes.push_back(dynamicSize);
        continue;
      }
      staticShape[dim] = cst.value();
    } else {
      foldedDynamicSizes.push_back(dynamicSize);
    }
  }

  return MemRefType::Builder(type).setShape(staticShape);
}
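// Sketch (illustrative): for type memref<?x8xf32> with dynamicSizes = {%c4}
// where %c4 is a constant 4, the result is memref<4x8xf32> and
// foldedDynamicSizes stays empty; a non-constant size would instead be
// forwarded into foldedDynamicSizes unchanged.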
namespace {
/// Change the type of a `memref.view` to a more static type when its dynamic
/// sizes are constants.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    MemRefType resultType = viewOp.getType();
    MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
        resultType, viewOp.getSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedMemRefType == resultType)
      return failure();

    // Create a new view op with the folded type and cast back to the
    // original result type.
    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
                                    viewOp.getSource(), viewOp.getByteShift(),
                                    foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};
/// Fold a `memref.cast` feeding a `memref.view` into the view itself.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(
        viewOp, viewOp.getType(), memrefCastOp.getSource(),
        viewOp.getByteShift(), viewOp.getSizes());
    return success();
  }
};
} // namespace
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
LogicalResult AtomicRMWOp::verify() {
  if (getMemRefType().getRank() != getNumOperands() - 2)
    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maximumf:
  case arith::AtomicRMWKind::minimumf:
  case arith::AtomicRMWKind::mulf:
    if (!llvm::isa<FloatType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::xori:
  case arith::AtomicRMWKind::andi:
    if (!llvm::isa<IntegerType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}

FailureOr<std::optional<SmallVector<Value>>>
AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(), getResult());
}
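// A well-formed op checked by the verifier above (illustrative):
//   %old = memref.atomic_rmw addf %val, %mem[%i] : (f32, memref<8xf32>) -> f32
// One subscript matches the rank-1 memref, and kind addf is paired with a
// floating-point value type.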
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.