#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

return arith::ConstantOp::materialize(builder, value, type, loc);
auto cast = operand.get().getDefiningOp<CastOp>();
if (cast && operand.get() != inner &&
    !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
  operand.set(cast.getOperand());

return success(folded);
if (auto memref = llvm::dyn_cast<MemRefType>(type))
if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
auto memrefType = llvm::cast<MemRefType>(value.getType());
if (memrefType.isDynamicDim(dim))
  return builder.createOrFold<memref::DimOp>(loc, value, dim);
auto memrefType = llvm::cast<MemRefType>(value.getType());
for (int64_t i = 0; i < memrefType.getRank(); ++i)
assert(constValues.size() == values.size() &&
       "incorrect number of const values");
if (ShapedType::isStatic(cstVal)) {
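// Sketch of what constifyIndexValues accomplishes, with illustrative values
// (not from this file): given mixed values [%size, 42] and static constValues
// [8, 42], the static entry 8 replaces the dynamic %size, so values becomes
// the fully static list [8, 42].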
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
MemorySpaceCastOpInterface castOp =
    MemorySpaceCastOpInterface::getIfPromotableCast(src);

FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
    castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);

FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
    castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);

if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))

return std::make_tuple(castOp, *tgtTy, *srcTy);
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>

llvm::append_range(operands, op->getOperands());

auto newOp = ConcreteOpTy::create(
    builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
    llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(

return std::optional<SmallVector<Value>>(
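// A sketch of the bubble-down rewrite this helper implements (IR names are
// illustrative, not from this file): an op consuming a promotable memory-space
// cast is recreated on the cast's source, and the cast is re-applied to the
// new result, i.e.
//   %v = memref.memory_space_cast %src : memref<8xf32, 1> to memref<8xf32>
//   %r = "some.op"(%v)
// becomes
//   %r0 = "some.op"(%src)
//   %r = memref.memory_space_cast %r0 : memref<8xf32, 1> to memref<8xf32>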
void AllocOp::getAsmResultNames(
  setNameFn(getResult(), "alloc");

void AllocaOp::getAsmResultNames(
  setNameFn(getResult(), "alloca");
template <typename AllocLikeOp>
static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
              "applies to only alloc or alloca");
auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  return op.emitOpError("result must be a memref");

if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
  return op.emitOpError("dimension operand count does not equal memref "
                        "dynamic dimension count");
unsigned numSymbols = 0;
if (!memRefType.getLayout().isIdentity())
  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
if (op.getSymbolOperands().size() != numSymbols)
  return op.emitOpError("symbol operand count does not equal memref symbol "
                        "count: expected ")
         << numSymbols << ", got " << op.getSymbolOperands().size();
216 "requires an ancestor op with AutomaticAllocationScope trait");
223 template <
typename AllocLikeOp>
227 LogicalResult matchAndRewrite(AllocLikeOp alloc,
231 if (llvm::none_of(alloc.getDynamicSizes(), [](
Value operand) {
233 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
235 return constSizeArg.isNonNegative();
auto memrefType = alloc.getType();
newShapeConstants.reserve(memrefType.getRank());

unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
  int64_t dimSize = memrefType.getDimSize(dim);
  if (ShapedType::isStatic(dimSize)) {
    newShapeConstants.push_back(dimSize);
  auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      constSizeArg.isNonNegative()) {
    newShapeConstants.push_back(constSizeArg.getZExtValue());
    newShapeConstants.push_back(ShapedType::kDynamic);
    dynamicSizes.push_back(dynamicSize);

MemRefType newMemRefType =
assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                    dynamicSizes, alloc.getSymbolOperands(),
                                    alloc.getAlignmentAttr());
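// Sketch of the rewrite SimplifyAllocConst performs, assuming standard memref
// assembly syntax (values are illustrative):
//   %c64 = arith.constant 64 : index
//   %0 = memref.alloc(%c64) : memref<?xf32>
// becomes
//   %0 = memref.alloc() : memref<64xf32>
//   %1 = memref.cast %0 : memref<64xf32> to memref<?xf32>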
template <typename T>
LogicalResult matchAndRewrite(T alloc,
  if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
        if (auto storeOp = dyn_cast<StoreOp>(op))
          return storeOp.getValue() == alloc;
        return !isa<DeallocOp>(op);

for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
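// Sketch of the dead-alloc case SimplifyDeadAlloc erases (illustrative IR):
// a buffer that is only stored into and deallocated, never read,
//   %0 = memref.alloc() : memref<8xf32>
//   memref.store %cst, %0[%c0] : memref<8xf32>
//   memref.dealloc %0 : memref<8xf32>
// is removed entirely; a store of %0 itself as the stored value counts as an
// escape and blocks the rewrite.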
results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
MemRefType resultType = getType();

if (!sourceType.getLayout().isIdentity())
  return emitError("unsupported layout for source memref type ")

if (!resultType.getLayout().isIdentity())
  return emitError("unsupported layout for result memref type ")
if (sourceType.getMemorySpace() != resultType.getMemorySpace())
  return emitError("different memory spaces specified for source memref "
                   "type ")
         << sourceType << " and result memref type " << resultType;

if (sourceType.getElementType() != resultType.getElementType())
  return emitError("different element types specified for source memref "
                   "type ")
         << sourceType << " and result memref type " << resultType;
if (resultType.getNumDynamicDims() && !getDynamicResultSize())
  return emitError("missing dimension operand for result type ")
if (!resultType.getNumDynamicDims() && getDynamicResultSize())
  return emitError("unnecessary dimension operand for result type ")

results.add<SimplifyDeadAlloc<ReallocOp>>(context);
bool printBlockTerminators = false;

if (!getResults().empty()) {
  p << " -> (" << getResultTypes() << ")";
  printBlockTerminators = true;

printBlockTerminators);

AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
void AllocaScopeOp::getSuccessorRegions(

MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
if (isa<SideEffects::AutomaticAllocationScopeResource>(
        effect->getResource()))

MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
if (isa<SideEffects::AutomaticAllocationScopeResource>(
        effect->getResource()))
bool hasPotentialAlloca =
if (hasPotentialAlloca) {

if (!lastParentWithoutScope ||
lastParentWithoutScope = lastParentWithoutScope->getParentOp();
if (!lastParentWithoutScope ||

Region *containingRegion = nullptr;
for (auto &r : lastParentWithoutScope->getRegions()) {
  if (r.isAncestor(op->getParentRegion())) {
    assert(containingRegion == nullptr &&
           "only one region can contain the op");
    containingRegion = &r;
assert(containingRegion && "op must be contained in a region");

return containingRegion->isAncestor(v.getParentRegion());

toHoist.push_back(alloc);

for (auto *op : toHoist) {
  auto *cloned = rewriter.clone(*op);
  rewriter.replaceOp(op, cloned->getResults());
if (!llvm::isPowerOf2_32(getAlignment()))
  return emitOpError("alignment must be power of 2");

void AssumeAlignmentOp::getAsmResultNames(
  setNameFn(getResult(), "assume_align");

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (source.getAlignment() != getAlignment())

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
if (getOperandTypes() != getResultTypes())
  return emitOpError("operand types and result types must match");

if (getOperandTypes().empty())
  return emitOpError("expected at least one operand");

LogicalResult DistinctObjectsOp::inferReturnTypes(

setNameFn(getResult(), "cast");
MemRefType sourceType =
    llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

if (!sourceType || !resultType)

if (sourceType.getElementType() != resultType.getElementType())

if (sourceType.getRank() != resultType.getRank())

int64_t sourceOffset, resultOffset;
if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
    failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))

for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
  auto ss = std::get<0>(it), st = std::get<1>(it);
  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))

if (sourceOffset != resultOffset)
  if (ShapedType::isDynamic(sourceOffset) &&
      ShapedType::isStatic(resultOffset))

for (auto it : llvm::zip(sourceStrides, resultStrides)) {
  auto ss = std::get<0>(it), st = std::get<1>(it);
  if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
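// Sketch of the foldability rule checked above (illustrative IR): a cast that
// only erases static information can be folded into its consumer,
//   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
//   "consumer"(%1)
// can use %0 directly; the reverse direction, where a dynamic source shape,
// offset, or stride becomes static, would add information and cannot fold.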
if (inputs.size() != 1 || outputs.size() != 1)

Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<MemRefType>(a);
auto bT = llvm::dyn_cast<MemRefType>(b);

auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

if (aT.getElementType() != bT.getElementType())

if (aT.getLayout() != bT.getLayout()) {
  int64_t aOffset, bOffset;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())

  auto checkCompatible = [](int64_t a, int64_t b) {
    return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);

  if (!checkCompatible(aOffset, bOffset))

  for (const auto &aStride : enumerate(aStrides))
    if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))

if (aT.getMemorySpace() != bT.getMemorySpace())

if (aT.getRank() != bT.getRank())

for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
  int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
  if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&

auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
if (aEltType != bEltType)

auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
return aMemSpace == bMemSpace;
FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {

LogicalResult matchAndRewrite(CopyOp copyOp,
  if (copyOp.getSource() != copyOp.getTarget())

LogicalResult matchAndRewrite(CopyOp copyOp,
  if (isEmptyMemRef(copyOp.getSource().getType()) ||
      isEmptyMemRef(copyOp.getTarget().getType())) {

results.add<FoldEmptyCopy, FoldSelfCopy>(context);

for (OpOperand &operand : op->getOpOperands()) {
  operand.set(castOp.getOperand());
LogicalResult CopyOp::fold(FoldAdaptor adaptor,

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,

setNameFn(getResult(), "dim");

build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
if (!rankedSourceType)

setResultRange(getResult(),

std::map<int64_t, unsigned> numOccurences;
for (auto val : vals)
  numOccurences[val]++;
return numOccurences;
static FailureOr<llvm::SmallBitVector>
llvm::SmallBitVector unusedDims(originalType.getRank());
if (originalType.getRank() == reducedType.getRank())

if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
  if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
    unusedDims.set(dim.index());

if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
    originalType.getRank())

int64_t originalOffset, candidateOffset;
    originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
    reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))

std::map<int64_t, unsigned> currUnaccountedStrides =
std::map<int64_t, unsigned> candidateStridesNumOccurences =
for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
  if (!unusedDims.test(dim))
  int64_t originalStride = originalStrides[dim];
  if (currUnaccountedStrides[originalStride] >
      candidateStridesNumOccurences[originalStride]) {
    currUnaccountedStrides[originalStride]--;
  if (currUnaccountedStrides[originalStride] ==
      candidateStridesNumOccurences[originalStride]) {
    unusedDims.reset(dim);
  if (currUnaccountedStrides[originalStride] <
      candidateStridesNumOccurences[originalStride]) {

if ((int64_t)unusedDims.count() + reducedType.getRank() !=
    originalType.getRank())
MemRefType sourceType = getSourceType();
MemRefType resultType = getType();
FailureOr<llvm::SmallBitVector> unusedDims =
assert(succeeded(unusedDims) && "unable to find unused dims of subview");
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());

int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= memrefType.getRank())

if (!memrefType.isDynamicDim(index.getInt())) {
  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);

unsigned unsignedIndex = index.getValue().getZExtValue();

Operation *definingOp = getSource().getDefiningOp();

if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
  return *(alloc.getDynamicSizes().begin() +
           memrefType.getDynamicDimIndex(unsignedIndex));

if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
  return *(alloca.getDynamicSizes().begin() +
           memrefType.getDynamicDimIndex(unsignedIndex));

if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
  return *(view.getDynamicSizes().begin() +
           memrefType.getDynamicDimIndex(unsignedIndex));

if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
  unsigned resultIndex = 0;
  unsigned sourceRank = subview.getSourceType().getRank();
  unsigned sourceIndex = 0;
  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
    if (unusedDims.test(i))
    if (resultIndex == unsignedIndex) {
  assert(subview.isDynamicSize(sourceIndex) &&
         "expected dynamic subview size");
  return subview.getDynamicSize(sourceIndex);

if (auto sizeInterface =
        dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
         "Expected dynamic subview size");
  return sizeInterface.getDynamicSize(unsignedIndex);
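// Sketch of the fold implemented above (illustrative IR): a dim of a static
// dimension folds to a constant, and a dim of a dynamic dimension folds to
// the matching size operand of the defining alloc/alloca/view/subview,
//   %0 = memref.alloc(%d) : memref<4x?xf32>
//   %1 = memref.dim %0, %c1 : memref<4x?xf32>   // folds to %d
//   %2 = memref.dim %0, %c0 : memref<4x?xf32>   // folds to constant 4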
LogicalResult matchAndRewrite(DimOp dim,
  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
      dim, "Dim op is not defined by a reshape op.");

if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
  if (auto *definingOp = dim.getIndex().getDefiningOp()) {
    if (reshape->isBeforeInBlock(definingOp)) {
          "dim.getIndex is not defined before reshape in the same block.");
} else if (dim->getBlock() != reshape->getBlock() &&
           !dim.getIndex().getParentRegion()->isProperAncestor(
               reshape->getParentRegion())) {
      dim, "dim.getIndex does not dominate reshape.");

    LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
if (load.getType() != dim.getType())
  load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);

results.add<DimOfMemRefReshape>(context);
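// Sketch of DimOfMemRefReshape (illustrative IR): the size of a reshaped
// memref is read directly from the shape operand,
//   %r = memref.reshape %src(%shape)
//        : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %idx : memref<?x?xf32>
// becomes
//   %d = memref.load %shape[%idx] : memref<2xindex>
// with an arith.index_cast inserted when the shape element type is not index.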
    Value elementsPerStride) {

p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';

p << ", " << getStride() << ", " << getNumElementsPerStride();

p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
  << ", " << getTagMemRef().getType();
bool isStrided = strideInfo.size() == 2;
if (!strideInfo.empty() && !isStrided) {
      "expected two stride related operands");

if (types.size() != 3)

unsigned numOperands = getNumOperands();

if (numOperands < 4)
  return emitOpError("expected at least 4 operands");

if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
  return emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4)
  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
if (!getSrcIndices().empty() &&
    !llvm::all_of(getSrcIndices().getTypes(),
  return emitOpError("expected source indices to be of index type");

if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
  return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
  return emitOpError() << "expected at least " << numExpectedOperands
if (!getDstIndices().empty() &&
    !llvm::all_of(getDstIndices().getTypes(),
  return emitOpError("expected destination indices to be of index type");

  return emitOpError("expected num elements to be of index type");

if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
  return emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
  return emitOpError() << "expected at least " << numExpectedOperands
if (!getTagIndices().empty() &&
    !llvm::all_of(getTagIndices().getTypes(),
  return emitOpError("expected tag indices to be of index type");

if (numOperands != numExpectedOperands &&
    numOperands != numExpectedOperands + 2)
  return emitOpError("incorrect number of operands");

    !getNumElementsPerStride().getType().isIndex())
      "expected stride and num elements per stride to be of type index");
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,

unsigned numTagIndices = getTagIndices().size();
unsigned tagMemRefRank = getTagMemRefRank();
if (numTagIndices != tagMemRefRank)
  return emitOpError() << "expected tagIndices to have the same number of "
                          "elements as the tagMemRef rank, expected "
                       << tagMemRefRank << ", but got " << numTagIndices;

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
  setNameFn(getResult(), "intptr");
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());

  unsigned sourceRank = sourceType.getRank();
      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  inferredReturnTypes.push_back(memrefType);
  inferredReturnTypes.push_back(indexType);
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);

void ExtractStridedMetadataOp::getAsmResultNames(
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");

  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
template <typename Container>
assert(values.size() == maybeConstants.size() &&
       " expected values and maybeConstants of the same size");

bool atLeastOneReplacement = false;
for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
  assert(isa<Attribute>(maybeConstant) &&
         "The constified value should be either unchanged (i.e., == result) "
      llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
  for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
    atLeastOneReplacement = true;

return atLeastOneReplacement;

ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
    getConstifiedMixedOffset());
    getConstifiedMixedSizes());
    builder, getLoc(), getStrides(), getConstifiedMixedStrides());

return success(atLeastOneReplacement);
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
  Type elementType = memrefType.getElementType();

auto &body = getRegion();
if (body.getNumArguments() != 1)
  return emitOpError("expected single number of entry block arguments");

if (getResult().getType() != body.getArgument(0).getType())
  return emitOpError("expected block argument of the same type result type");

    "body of 'memref.generic_atomic_rmw' should contain "
    "only operations with no side effects");
p << ' ' << getMemref() << "[" << getIndices()
  << "] : " << getMemref().getType() << ' ';

Type parentType = (*this)->getParentOp()->getResultTypes().front();
Type resultType = getResult().getType();
if (parentType != resultType)
  return emitOpError() << "types mismatch between yield op: " << resultType
                       << " and its parent: " << parentType;
if (!op.isExternal()) {
  if (op.isUninitialized())
    p << "uninitialized";

auto memrefType = llvm::dyn_cast<MemRefType>(type);
if (!memrefType || !memrefType.hasStaticShape())
    << "type should be static shaped memref, but got " << type;

if (!llvm::isa<ElementsAttr>(initialValue))
    << "initial value should be a unit or elements attribute";
auto memrefType = llvm::dyn_cast<MemRefType>(getType());
if (!memrefType || !memrefType.hasStaticShape())
  return emitOpError("type should be static shaped memref, but got ")

if (getInitialValue().has_value()) {
  Attribute initValue = getInitialValue().value();
  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
    return emitOpError("initial value should be a unit or elements "
                       "attribute, but got ")

if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
  auto initElementType =
      cast<TensorType>(elementsAttr.getType()).getElementType();
  auto memrefElementType = memrefType.getElementType();

  if (initElementType != memrefElementType)
    return emitOpError("initial value element expected to be of type ")
           << memrefElementType << ", but was of type " << initElementType;

  auto initShape = elementsAttr.getShapedType().getShape();
  auto memrefShape = memrefType.getShape();
  if (initShape != memrefShape)
    return emitOpError("initial value shape expected to be ")
           << memrefShape << " but was " << initShape;

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return emitOpError("'")
         << getName() << "' does not reference a valid global memref";

Type resultType = getResult().getType();
if (global.getType() != resultType)
  return emitOpError("result type ")
         << resultType << " does not match type " << global.getType()
         << " of the global memref @" << getName();

  return emitOpError("incorrect number of indices for load, expected ")

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {

void MemorySpaceCastOp::getAsmResultNames(
  setNameFn(getResult(), "memspacecast");
if (inputs.size() != 1 || outputs.size() != 1)
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<MemRefType>(a);
auto bT = llvm::dyn_cast<MemRefType>(b);

auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

if (aT.getElementType() != bT.getElementType())
if (aT.getLayout() != bT.getLayout())
if (aT.getShape() != bT.getShape())

return uaT.getElementType() == ubT.getElementType();

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());

return cast<TypedValue<PtrLikeTypeInterface>>(getSource());

return cast<TypedValue<PtrLikeTypeInterface>>(getDest());

bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                               PtrLikeTypeInterface src) {
  return isa<BaseMemRefType>(tgt) &&
         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;

MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);

bool MemorySpaceCastOp::isSourcePromotable() {
  return getDest().getType().getMemorySpace() == nullptr;
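// Sketch of the fold above (illustrative IR): consecutive memory-space casts
// collapse into one,
//   %1 = memref.memory_space_cast %0 : memref<8xf32, 1> to memref<8xf32, 2>
//   %2 = memref.memory_space_cast %1 : memref<8xf32, 2> to memref<8xf32>
// folds %2 into a single cast taking %0 directly as its source.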
p << " " << getMemref() << '[';
p << ']' << ", " << (getIsWrite() ? "write" : "read");
p << ", locality<" << getLocalityHint();
p << ">, " << (getIsDataCache() ? "data" : "instr");
    (*this)->getAttrs(),
    {"localityHint", "isWrite", "isDataCache"});

IntegerAttr localityHint;
StringRef readOrWrite, cacheType;

if (readOrWrite != "read" && readOrWrite != "write")
      "rw specifier has to be 'read' or 'write'");
result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),

if (cacheType != "data" && cacheType != "instr")
      "cache type has to be 'data' or 'instr'");
result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),

  return emitOpError("too few indices");

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1833 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1845 auto type = getOperand().getType();
1846 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1847 if (shapedType && shapedType.hasRank())
1849 return IntegerAttr();
1856 void ReinterpretCastOp::getAsmResultNames(
1858 setNameFn(getResult(),
"reinterpret_cast");
    MemRefType resultType, Value source,
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

auto sourceType = cast<BaseMemRefType>(source.getType());
    b.getContext(), staticOffsets.front(), staticStrides);
auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                  stridedLayout, sourceType.getMemorySpace());
build(b, result, resultType, source, offset, sizes, strides, attrs);

    MemRefType resultType, Value source,
llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
    strideValues, attrs);

    MemRefType resultType, Value source, Value offset,
build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
auto resultType = llvm::cast<MemRefType>(getType());
if (srcType.getMemorySpace() != resultType.getMemorySpace())
  return emitError("different memory spaces specified for source type ")
         << srcType << " and result memref type " << resultType;
if (srcType.getElementType() != resultType.getElementType())
  return emitError("different element types specified for source type ")
         << srcType << " and result memref type " << resultType;

for (auto [idx, resultSize, expectedSize] :
  if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
    return emitError("expected result type with size = ")
           << (ShapedType::isDynamic(expectedSize)
                   ? std::string("dynamic")
                   : std::to_string(expectedSize))
           << " instead of " << resultSize << " in dim = " << idx;

int64_t resultOffset;
if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
  return emitError("expected result type to have strided layout but found ")

int64_t expectedOffset = getStaticOffsets().front();
if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
  return emitError("expected result type with offset = ")
         << (ShapedType::isDynamic(expectedOffset)
                 ? std::string("dynamic")
                 : std::to_string(expectedOffset))
         << " instead of " << resultOffset;

for (auto [idx, resultStride, expectedStride] :
  if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
    return emitError("expected result type with stride = ")
           << (ShapedType::isDynamic(expectedStride)
                   ? std::string("dynamic")
                   : std::to_string(expectedStride))
           << " instead of " << resultStride << " in dim = " << idx;
Value src = getSource();
auto getPrevSrc = [&]() -> Value {
  return prev.getSource();
  return prev.getSource();
  return prev.getSource();

if (auto prevSrc = getPrevSrc()) {
  getSourceMutable().assign(prevSrc);

LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
assert(succeeded(status) && "could not get strides from type");

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
struct ReinterpretCastOpExtractStridedMetadataFolder
LogicalResult matchAndRewrite(ReinterpretCastOp op,
  auto extractStridedMetadata =
      op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
  if (!extractStridedMetadata)

  auto isReinterpretCastNoop = [&]() -> bool {
    if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                     op.getConstifiedMixedStrides()))
    if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                     op.getConstifiedMixedSizes()))
    assert(op.getMixedOffsets().size() == 1 &&
           "reinterpret_cast with more than one offset should have been "
           "rejected by the verifier");
    return extractStridedMetadata.getConstifiedMixedOffset() ==
           op.getConstifiedMixedOffset();

  if (!isReinterpretCastNoop()) {
    op.getSourceMutable().assign(extractStridedMetadata.getSource());

  Type srcTy = extractStridedMetadata.getSource().getType();
  if (srcTy == op.getResult().getType())
    rewriter.replaceOp(op, extractStridedMetadata.getSource());
      extractStridedMetadata.getSource());

results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
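// Sketch of the folder registered above (illustrative IR): a reinterpret_cast
// that rebuilds exactly the metadata produced by extract_strided_metadata on
// its own source is a no-op,
//   %base, %offset, %size, %stride = memref.extract_strided_metadata %m
//   %r = memref.reinterpret_cast %base to offset: [%offset],
//        sizes: [%size], strides: [%stride]
// is replaced by %m (through a memref.cast when the types differ).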
FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapse_shape");

void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expand_shape");

reifiedResultShapes = {
    getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
static LogicalResult
    bool allowMultipleDynamicDimsPerGroup) {
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  int64_t nextDim = 0;
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
            << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;

    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
          << ") must be dynamic if and only if reassociation group is "

    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
            << collapsedShape[collapsedDim]
            << ") must equal reassociation group size (" << groupSize << ")";

  if (collapsedShape.empty()) {
    for (int64_t d : expandedShape)
        "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
        << expandedShape.size()
        << ") inconsistent with number of reassociation indices (" << nextDim

getReassociationIndices());
getReassociationIndices());
static FailureOr<StridedLayoutAttr>
if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

reverseResultStrides.reserve(resultShape.size());
unsigned shapeIndex = resultShape.size() - 1;
for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
  int64_t currentStrideToExpand = std::get<1>(it);
  for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
    reverseResultStrides.push_back(currentStrideToExpand);
    currentStrideToExpand =

auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
resultStrides.resize(resultShape.size(), 1);

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
        srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
FailureOr<SmallVector<OpFoldResult>>
    MemRefType expandedType,
std::optional<SmallVector<OpFoldResult>> outputShape =
return *outputShape;

auto [staticOutputShape, dynamicOutputShape] =
build(builder, result, llvm::cast<MemRefType>(resultType), src,
      dynamicOutputShape, staticOutputShape);

MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
    builder, result.location, memrefResultTy, reassociation, inputShape);
assert(succeeded(outputShape) && "unable to infer output shape");
build(builder, result, memrefResultTy, src, reassociation, *outputShape);

auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
    ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
assert(succeeded(resultType) && "could not compute layout");
build(builder, result, *resultType, src, reassociation);

auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
    ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
assert(succeeded(resultType) && "could not compute layout");
build(builder, result, *resultType, src, reassociation, outputShape);
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();

if (srcType.getRank() > resultType.getRank()) {
  auto r0 = srcType.getRank();
  auto r1 = resultType.getRank();
  return emitOpError("has source rank ")
         << r0 << " and result rank " << r1 << ". This is not an expansion ("
         << r0 << " > " << r1 << ").";

    resultType.getShape(),
    getReassociationIndices(),

FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
    srcType, resultType.getShape(), getReassociationIndices());
if (failed(expectedResultType))
  return emitOpError("invalid source layout map");

if (*expectedResultType != resultType)
  return emitOpError("expected expanded type to be ")
         << *expectedResultType << " but found " << resultType;

if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
  return emitOpError("expected number of static shape bounds to be equal to "
                     "the output rank (")
         << resultType.getRank() << ") but found "
         << getStaticOutputShape().size() << " inputs instead";

if ((int64_t)getOutputShape().size() !=
    llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
  return emitOpError("mismatch in dynamic dims in output_shape and "
                     "static_output_shape: static_output_shape has ")
         << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
         << " dynamic dims while output_shape has " << getOutputShape().size()

if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
  return emitOpError("invalid output shape provided at pos ") << pos;

FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
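// Sketch of a well-formed expansion the verifier accepts (illustrative IR):
// reassociation [[0, 1], [2]] splits source dim 0 into result dims 0 and 1,
//   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [4, 2, 8]
//        : memref<8x8xf32> into memref<4x2x8xf32>
// where each group's result sizes multiply back to the source dim (4 * 2 == 8).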
static FailureOr<StridedLayoutAttr>
    bool strict = false) {
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))

  resultStrides.reserve(reassociation.size());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
      resultStrides.push_back(ShapedType::kDynamic);

  unsigned resultStrideIndex = resultStrides.size() - 1;
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      if (strict && (stride.saturated || srcStride.saturated))

      if (srcShape[idx - 1] == 1)

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)

bool CollapseShapeOp::isGuaranteedCollapsible(
  if (srcType.getLayout().isIdentity())

MemRefType CollapseShapeOp::computeCollapsedType(
  resultShape.reserve(reassociation.size());
    for (int64_t srcDim : group)
    resultShape.push_back(groupSize.asInteger());

  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
        srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
auto srcType = llvm::cast<MemRefType>(src.getType());
MemRefType resultType =
    CollapseShapeOp::computeCollapsedType(srcType, reassociation);
build(b, result, resultType, src, attrs);

MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();

if (srcType.getRank() < resultType.getRank()) {
  auto r0 = srcType.getRank();
  auto r1 = resultType.getRank();
  return emitOpError("has source rank ")
         << r0 << " and result rank " << r1 << ". This is not a collapse ("
         << r0 << " < " << r1 << ").";

    srcType.getShape(), getReassociationIndices(),

MemRefType expectedResultType;
if (srcType.getLayout().isIdentity()) {
  MemRefLayoutAttrInterface layout;
  expectedResultType =
      MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                      srcType.getMemorySpace());

  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
      "invalid source layout map or collapsing non-contiguous dims");
  expectedResultType =
      *computedLayout, srcType.getMemorySpace());

if (expectedResultType != resultType)
  return emitOpError("expected collapsed type to be ")
         << expectedResultType << " but found " << resultType;
auto cast = op.getOperand().getDefiningOp<CastOp>();

Type newResultType = CollapseShapeOp::computeCollapsedType(
    llvm::cast<MemRefType>(cast.getOperand().getType()),
    op.getReassociationIndices());

if (newResultType == op.getResultType()) {
    op, [&]() { op.getSrcMutable().assign(cast.getSource()); });

  CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
                          op.getReassociationIndices());

    memref::DimOp, MemRefType>,

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
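// Sketch of the reshape-of-reshape fold (illustrative IR): an expand of a
// collapse with matching types folds to the original value,
//   %1 = memref.collapse_shape %0 [[0, 1]] : memref<2x4xf32> into memref<8xf32>
//   %2 = memref.expand_shape %1 [[0, 1]] output_shape [2, 4]
//        : memref<8xf32> into memref<2x4xf32>
// folds %2 to %0, and symmetrically for a collapse of an expand.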
FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");

Type operandType = getSource().getType();
Type resultType = getResult().getType();

Type operandElementType =
    llvm::cast<ShapedType>(operandType).getElementType();
Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
if (operandElementType != resultElementType)
  return emitOpError("element types of source and destination memref "
                     "types should be the same");

if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
  if (!operandMemRefType.getLayout().isIdentity())
    return emitOpError("source memref type should have identity affine map");

auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
if (resultMemRefType) {
  if (!resultMemRefType.getLayout().isIdentity())
    return emitOpError("result memref type should have identity affine map");
  if (shapeSize == ShapedType::kDynamic)
    return emitOpError("cannot use shape operand with dynamic length to "
                       "reshape to statically-ranked memref type");
  if (shapeSize != resultMemRefType.getRank())
        "length of shape operand differs from the result's memref rank");

FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
  return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {

void SubViewOp::getAsmResultNames(
  setNameFn(getResult(), "subview");
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
  unsigned rank = sourceMemRefType.getRank();
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);

  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);

  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                             targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());

MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
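// Worked example of the inference above (illustrative values): for a source
// memref<8x16xf32> (strides [16, 1], offset 0) with offsets [2, 4],
// sizes [4, 4], and strides [1, 2]:
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2] = [16, 2]
// giving memref<4x4xf32, strided<[16, 2], offset: 36>>.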
MemRefType SubViewOp::inferRankReducedResultType(
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
  assert(dimsToProject.has_value() && "invalid rank reduction");

  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
                             inferredLayout.getOffset(),
                             rankReducedStrides),
                         inferredType.getMemorySpace());

MemRefType SubViewOp::inferRankReducedResultType(
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
    MemRefType resultType, Value source,
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                        staticSizes, staticStrides);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

    MemRefType resultType, Value source,
llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
build(b, result, resultType, source, offsetValues, sizeValues, strideValues,

build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

Value SubViewOp::getViewSource() { return getSource(); }
int64_t t1Offset, t2Offset;
auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;

    const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);

  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (t1Strides[i] != t2Strides[j])
    SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);

    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();

    return op->emitError("expected result type to be ")
           << " or a rank-reduced version. (mismatch of result sizes), but got "

    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();

    return op->emitError(
        "expected result and source memory spaces to match, but got ")

    return op->emitError("expected result type to be ")
           << " or a rank-reduced version. (mismatch of result layout), but "

  llvm_unreachable("unexpected subview verification result");

MemRefType baseType = getSourceType();
MemRefType subViewType = getType();
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
  return emitError("different memory spaces specified for base memref "
                   "type ")
         << baseType << " and subview memref type " << subViewType;
if (!baseType.isStrided())
  return emitError("base type ") << baseType << " is not strided";

MemRefType expectedType = SubViewOp::inferResultType(
    baseType, staticOffsets, staticSizes, staticStrides);

    expectedType, subViewType);

if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
      *this, expectedType);
      *this, expectedType);
      *this, expectedType);
      *this, expectedType);

    staticStrides, true);
  return getOperation()->emitError(boundsResult.errorMessage);
return os << "range " << range.offset << ":" << range.size << ":"

std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");

unsigned rank = ranks[0];
for (unsigned idx = 0; idx < rank; ++idx) {
      op.isDynamicOffset(idx)
          ? op.getDynamicOffset(idx)
      op.isDynamicSize(idx)
          ? op.getDynamicSize(idx)
      op.isDynamicStride(idx)
          ? op.getDynamicStride(idx)
  res.emplace_back(Range{offset, size, stride});
    MemRefType currentResultType, MemRefType currentSourceType,
MemRefType nonRankReducedType = SubViewOp::inferResultType(
    sourceType, mixedOffsets, mixedSizes, mixedStrides);
    currentSourceType, currentResultType, mixedSizes);

auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
unsigned numDimsAfterReduction =
    nonRankReducedType.getRank() - unusedDims->count();
shape.reserve(numDimsAfterReduction);
strides.reserve(numDimsAfterReduction);
for (const auto &[idx, size, stride] :
     llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
               nonRankReducedType.getShape(), layout.getStrides())) {
  if (unusedDims->test(idx))
  shape.push_back(size);
  strides.push_back(stride);
    layout.getOffset(), strides),
    nonRankReducedType.getMemorySpace());

auto memrefType = llvm::cast<MemRefType>(memref.getType());
unsigned rank = memrefType.getRank();
MemRefType targetType = SubViewOp::inferRankReducedResultType(
    targetShape, memrefType, offsets, sizes, strides);
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,

auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
assert(sourceMemrefType && "not a ranked memref type");
auto sourceShape = sourceMemrefType.getShape();
if (sourceShape.equals(desiredShape))
auto maybeRankReductionMask =
if (!maybeRankReductionMask)

if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())

auto mixedOffsets = subViewOp.getMixedOffsets();
auto mixedSizes = subViewOp.getMixedSizes();
auto mixedStrides = subViewOp.getMixedStrides();

return !intValue || intValue.value() != 0;
return !intValue || intValue.value() != 1;
if (!intValue || *intValue != sourceShape[size.index()])
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
LogicalResult matchAndRewrite(SubViewOp subViewOp,
  if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
        return matchPattern(operand, matchConstantIndex());

  auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();

      subViewOp.getType(), subViewOp.getSourceType(),
      llvm::cast<MemRefType>(castOp.getSource().getType()),
      subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
      subViewOp.getMixedStrides());

  Value newSubView = SubViewOp::create(
      rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
      subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
      subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
      subViewOp.getStaticStrides());

LogicalResult matchAndRewrite(SubViewOp subViewOp,
  if (subViewOp.getSourceType() == subViewOp.getType()) {
    rewriter.replaceOp(subViewOp, subViewOp.getSource());
      subViewOp.getSource());

MemRefType resTy = SubViewOp::inferResultType(
    op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);

MemRefType nonReducedType = resTy;

llvm::SmallBitVector droppedDims = op.getDroppedDims();
if (droppedDims.none())
  return nonReducedType;

auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
  if (droppedDims.test(i))
  targetStrides.push_back(nonReducedStrides[i]);
  targetShape.push_back(nonReducedType.getDimSize(i));
    offset, targetStrides),
    nonReducedType.getMemorySpace());

    SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
MemRefType sourceMemrefType = getSource().getType();
MemRefType resultMemrefType = getResult().getType();
    dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

if (resultMemrefType == sourceMemrefType &&
    resultMemrefType.hasStaticShape() &&
    (!resultLayout || resultLayout.hasStaticLayout())) {
  return getViewSource();

if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
  auto srcSizes = srcSubview.getMixedSizes();
  auto offsets = getMixedOffsets();
  auto strides = getMixedStrides();
  bool allStridesOne = llvm::all_of(strides, isOneInteger);
  bool allSizesSame = llvm::equal(sizes, srcSizes);
  if (allOffsetsZero && allStridesOne && allSizesSame &&
      resultMemrefType == sourceMemrefType)
    return getViewSource();
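// Sketch of the folds above (illustrative IR): a subview that spans the whole
// source with zero offsets, unit strides, and an identical result type,
//   %1 = memref.subview %0[0, 0] [8, 16] [1, 1]
//        : memref<8x16xf32> to memref<8x16xf32>
// folds to %0; likewise a subview of a subview that selects the entire
// intermediate result folds away.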
FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
void TransposeOp::getAsmResultNames(
  setNameFn(getResult(), "transpose");

auto originalSizes = memRefType.getShape();
auto [originalStrides, offset] = memRefType.getStridesAndOffset();
assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

    AffineMapAttr permutation,
auto permutationMap = permutation.getValue();
assert(permutationMap);

auto memRefType = llvm::cast<MemRefType>(in.getType());

result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
build(b, result, resultType, in, attrs);

p << " " << getIn() << " " << getPermutation();
p << " : " << getIn().getType() << " to " << getType();
MemRefType srcType, dstType;

result.addAttribute(TransposeOp::getPermutationAttrStrName(),

  return emitOpError("expected a permutation map");
if (getPermutation().getNumDims() != getIn().getType().getRank())
  return emitOpError("expected a permutation map of same rank as the input");

auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto resultType = llvm::cast<MemRefType>(getType());
    .canonicalizeStridedLayout();

if (resultType.canonicalizeStridedLayout() != canonicalResultType)
  return emitOpError("result type ")
         << " is not equivalent to the canonical transposed input type "
         << canonicalResultType;

if (getPermutation().isIdentity() && getType() == getIn().getType())

if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
      getPermutation().compose(otherTransposeOp.getPermutation());
  getInMutable().assign(otherTransposeOp.getIn());
  setPermutation(composedPermutation);
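// Sketch of the folds above (illustrative IR): an identity transpose folds to
// its input, and back-to-back transposes compose into one,
//   %1 = memref.transpose %0 (i, j) -> (j, i)
//        : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %2 = memref.transpose %1 (i, j) -> (j, i)
//        : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32>
// becomes a single transpose of %0 with the composed (here identity) map.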
FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");

auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());

if (!baseType.getLayout().isIdentity())
  return emitError("unsupported map for base memref type ") << baseType;

if (!viewType.getLayout().isIdentity())
  return emitError("unsupported map for result memref type ") << viewType;
if (baseType.getMemorySpace() != viewType.getMemorySpace())
  return emitError("different memory spaces specified for base memref "
                   "type ")
         << baseType << " and view memref type " << viewType;
unsigned numDynamicDims = viewType.getNumDynamicDims();
if (getSizes().size() != numDynamicDims)
  return emitError("incorrect number of size operands for type ") << viewType;

Value ViewOp::getViewSource() { return getSource(); }

MemRefType sourceMemrefType = getSource().getType();
MemRefType resultMemrefType = getResult().getType();

if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
  return getViewSource();
3596 LogicalResult matchAndRewrite(ViewOp viewOp,
3599 if (llvm::none_of(viewOp.getOperands(), [](
Value operand) {
3600 return matchPattern(operand, matchConstantIndex());
3605 auto memrefType = viewOp.getType();
3610 if (
failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3612 assert(oldOffset == 0 &&
"Expected 0 offset");
    SmallVector<Value, 4> newOperands;

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy the operand from the old
        // memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }
    // Create the new memref type with constant-folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create the new ViewOp.
    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
                                    viewOp.getOperand(0), viewOp.getByteShift(),
                                    newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};
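/// The intent of the next pattern, in our words (not the source's): if the
/// source of the view is a memref.cast of an allocation, rewrite the view to
/// operate on the original allocation directly so the cast becomes dead.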
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};

} // namespace
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
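// Illustrative effect of ViewOpShapeFolder (a sketch; SSA names are ours):
//
//   %c16 = arith.constant 16 : index
//   %v = memref.view %buf[%off][%c16] : memref<?xi8> to memref<?xf32>
//
// canonicalizes to a statically shaped view plus a cast back to the original
// result type:
//
//   %v0 = memref.view %buf[%off][] : memref<?xi8> to memref<16xf32>
//   %v  = memref.cast %v0 : memref<16xf32> to memref<?xf32>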
3700 "expects the number of subscripts to be equal to memref rank");
3701 switch (getKind()) {
3702 case arith::AtomicRMWKind::addf:
3703 case arith::AtomicRMWKind::maximumf:
3704 case arith::AtomicRMWKind::minimumf:
3705 case arith::AtomicRMWKind::mulf:
3706 if (!llvm::isa<FloatType>(getValue().
getType()))
3707 return emitOpError() <<
"with kind '"
3708 << arith::stringifyAtomicRMWKind(getKind())
3709 <<
"' expects a floating-point type";
3711 case arith::AtomicRMWKind::addi:
3712 case arith::AtomicRMWKind::maxs:
3713 case arith::AtomicRMWKind::maxu:
3714 case arith::AtomicRMWKind::mins:
3715 case arith::AtomicRMWKind::minu:
3716 case arith::AtomicRMWKind::muli:
3717 case arith::AtomicRMWKind::ori:
3718 case arith::AtomicRMWKind::xori:
3719 case arith::AtomicRMWKind::andi:
3720 if (!llvm::isa<IntegerType>(getValue().
getType()))
3721 return emitOpError() <<
"with kind '"
3722 << arith::stringifyAtomicRMWKind(getKind())
3723 <<
"' expects an integer type";
3738 FailureOr<std::optional<SmallVector<Value>>>
3739 AtomicRMWOp::bubbleDownCasts(
OpBuilder &builder) {
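// Examples against the verifier above (illustrative): this is accepted,
//   %old = memref.atomic_rmw addf %val, %m[%i] : (f32, memref<8xf32>) -> f32
// whereas kind 'addi' with an f32 value would trip the integer-type check.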
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"