#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
  return arith::ConstantOp::materialize(builder, value, type, loc);
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getShape()[dim]);
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    if (ShapedType::isStatic(cstVal)) {
      // The constant value is known statically; use it directly.
      values[i] = builder.getIndexAttr(cstVal);
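/// Helper used by the `bubbleDownCasts` implementations below. Given a source
/// pointer-like value, it looks for a promotable memory-space cast producing
/// that value and, assuming one exists, clones the expected result type into
/// the cast's source and target memory spaces so the consuming op can be
/// recreated on the un-cast operand.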
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>

  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);

  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))

  return std::make_tuple(castOp, *tgtTy, *srcTy);
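/// Generic `bubbleDownCasts` implementation for ops that merely pass a memref
/// through: the op is recreated with the operands taken from before the
/// memory-space cast, and a new memory-space cast is re-applied to the result.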
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>

  llvm::append_range(operands, op->getOperands());

  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(

  return std::optional<SmallVector<Value>>(
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();
216 "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; keep the operand.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create the new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
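/// Erase alloc-like ops whose result is used only by dealloc ops and by
/// stores *into* the allocated buffer: the buffer contents are never read, so
/// the allocation and all of its users are dead.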
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source memref and the result memref should be in the same memory space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source memref and the result memref should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
  bool printBlockTerminators = false;

  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }

  p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);

  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);
void AllocaScopeOp::getSuccessorRegions(

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

  if (isa<SideEffects::AutomaticAllocationScopeResource>(
          effect->getResource()))

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

  if (isa<SideEffects::AutomaticAllocationScopeResource>(
          effect->getResource()))
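/// The patterns below inline memref.alloca_scope regions that contain no
/// potential automatic allocations, and hoist scoped allocations out to the
/// closest surrounding scope where lengthening their lifetime is safe, using
/// the two allocation-effect queries above.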
    bool hasPotentialAlloca =

    if (hasPotentialAlloca) {

    if (!lastParentWithoutScope ||

      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

      return containingRegion->isAncestor(v.getParentRegion());

      toHoist.push_back(alloc);

    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();

  if (source.getAlignment() != getAlignment())

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");

  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

LogicalResult DistinctObjectsOp::inferReturnTypes(

  setNameFn(getResult(), "cast");
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  // If the cast is towards a more static offset, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))
      return false;

  // If the cast is towards more static strides along any dimension, don't
  // fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }
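/// Two types are cast-compatible when both are memrefs with the same element
/// type and memory space whose shape/layout information differs only in how
/// static it is; an unranked and a ranked memref of matching element type and
/// memory space are also compatible.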
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // An offset/stride pair is compatible if either value is dynamic or
      // both are equal.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  }

  // If one of the types is unranked, element types and memory spaces must
  // still agree.
  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}
FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {

  results.add<FoldEmptyCopy, FoldSelfCopy>(context);

  for (OpOperand &operand : op->getOpOperands()) {
      operand.set(castOp.getOperand());
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {

  setNameFn(getResult(), "dim");

  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return success();

  setResultRange(getResult(),

static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
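/// Compute which dimensions of `originalType` were dropped to obtain
/// `reducedType`: unit dims implied by `sizes` are candidates, and when the
/// rank difference alone does not decide the matter, strides are matched by
/// multiplicity to find a consistent assignment. Fails if none exists.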
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  int64_t originalOffset, candidateOffset;
  SmallVector<int64_t> originalStrides, candidateStrides;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep it as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced type cannot have a stride that
      // was not in the original type.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out-of-bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an alloc, alloca, view, or subview.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }
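/// Fold dim of a memref reshape operation to a load from the reshape's shape
/// operand. A sketch of the rewrite:
///
///   %0 = memref.reshape %src(%shape) : ...
///   %1 = memref.dim %0, %idx
///
/// becomes
///
///   %1 = memref.load %shape[%idx]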
  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // The fold is legal if dim.getIndex() dominates the reshape. Instead of
    // running (costly) dominance analysis, cheaply check that either
    // dim.getIndex() is defined in the same block before the reshape, or in a
    // parent region of the reshape.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex() is a block argument to reshape's block.
    }

    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }

  results.add<DimOfMemRefReshape>(context);
                        Value elementsPerStride) {

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (types.size() != 3)
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check the types of the operands. The order of these calls is important:
  // the later calls rely on earlier type properties to compute operand
  // positions.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must either both be present or both be
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // Only the first value in a variadic result pack can carry a pretty name.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &builder, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         "expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {

    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = arith::ConstantIndexOp::create(
        builder, loc,
        llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;
    }

  return success(atLeastOneReplacement);
}
SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getOffset());
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

      return emitOpError("body of 'memref.generic_atomic_rmw' should contain "
                         "only operations with no side effects");
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
  }

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;

  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the initial value's type is compatible with the type of the
    // global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      // The element types must match.
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      // The shapes must match: memref globals are statically shaped, and an
      // elements literal always has a static shape.
      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;
    }
  }

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
    return emitOpError("incorrect number of indices for load, expected ")
           << getMemRefType().getRank() << " but got " << getIndices().size();

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT)
    return uaT.getElementType() == ubT.getElementType();
  return false;
}

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
  // t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return {};
}

bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                               PtrLikeTypeInterface src) {
  return isa<BaseMemRefType>(tgt) &&
         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
}

MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(

  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}

bool MemorySpaceCastOp::isSourcePromotable() {
  return getDest().getType().getMemorySpace() == nullptr;
}
void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
}

ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  IntegerAttr localityHint;
  StringRef readOrWrite, cacheType;

  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

    return emitOpError("too few indices");

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
1840 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1852 auto type = getOperand().getType();
1853 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1854 if (shapedType && shapedType.hasRank())
1856 return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

  auto sourceType = cast<BaseMemRefType>(source.getType());

  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,

  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues =
      llvm::to_vector<4>(llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,

  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in the result memref type and in the static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  // Match the offset and strides against the static_offsets and static_strides
  // attributes. The result type must have a strided layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match the offset in the result memref type and the static_offsets
  // attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  // Match strides in the result memref type and the static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if the subview's
    // offsets are all zero.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getOffsets());
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check whether the reinterpret_cast reconstructs a memref with the exact
    // same properties as the extract_strided_metadata source.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {

      op.getSourceMutable().assign(extractStridedMetadata.getSource());

    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getResult().getType(),
                                          extractStridedMetadata.getSource());

struct ReinterpretCastOpConstantFolder
    : public OpRewritePattern<ReinterpretCastOp> {
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    unsigned srcStaticCount = llvm::count_if(
        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                   op.getMixedStrides()),

    if (srcStaticCount ==
        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),

    auto newReinterpretCast = ReinterpretCastOp::create(
        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);

  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
              ReinterpretCastOpConstantFolder>(context);

FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: the number of dimensions among all reassociation groups must
    // match the number of expanded dimensions.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());

  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
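/// Compute the layout map after expanding a given source MemRef type with the
/// specified reassociation indices: within each reassociation group the
/// innermost dimension inherits the source stride, and each dimension further
/// out multiplies that stride by the sizes of the dimensions inside of it.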
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
2373 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2376 if (srcType.getLayout().isIdentity()) {
2379 MemRefLayoutAttrInterface layout;
2381 srcType.getMemorySpace());
2385 FailureOr<StridedLayoutAttr> computedLayout =
2387 if (
failed(computedLayout))
2389 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2390 srcType.getMemorySpace());
2393 FailureOr<SmallVector<OpFoldResult>>
2395 MemRefType expandedType,
2398 std::optional<SmallVector<OpFoldResult>> outputShape =
2403 return *outputShape;
2410 auto [staticOutputShape, dynamicOutputShape] =
2412 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2414 dynamicOutputShape, staticOutputShape);
2422 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2423 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2424 builder, result.
location, memrefResultTy, reassociation, inputShape);
2427 assert(succeeded(outputShape) &&
"unable to infer output shape");
2428 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2435 auto srcType = llvm::cast<MemRefType>(src.
getType());
2436 FailureOr<MemRefType> resultType =
2437 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2440 assert(succeeded(resultType) &&
"could not compute layout");
2441 build(builder, result, *resultType, src, reassociation);
2449 auto srcType = llvm::cast<MemRefType>(src.
getType());
2450 FailureOr<MemRefType> resultType =
2451 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2454 assert(succeeded(resultType) &&
"could not compute layout");
2455 build(builder, result, *resultType, src, reassociation, outputShape);
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify the result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute the expected result type (including the layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check the actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }

FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
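/// Compute the layout map after collapsing a given source MemRef type with
/// the specified reassociation indices. The strides within each reassociation
/// group must be contiguous (modulo size-1 dimensions); otherwise the dims
/// are not collapsible and failure is returned. With `strict` set,
/// collapsibility must be provable statically: dynamic sizes or strides that
/// merely *may* be contiguous are rejected as well.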
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of its last
  // entry; dims of size 1 are skipped because their strides are meaningless.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to have size 1 at runtime, so the
      // result stride cannot be statically determined and must be dynamic.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
    for (int64_t idx : llvm::reverse(trailingReassocs)) {

      // Strict mode: dynamic (saturated) strides cannot be proven contiguous.
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      // Dims of size 1 are skipped: their strides may be arbitrary.
      if (srcShape[idx - 1] == 1)
        continue;

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    auto groupSize = SaturatedInteger::wrap(1);
    for (int64_t srcDim : group)
      groupSize =
          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asInteger());
  }

  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // The source may not be fully contiguous. Compute the layout map.
  // Note: dimensions that are collapsed into a single dim are assumed to be
  // contiguous.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);

  build(b, result, resultType, src, attrs);

LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify the result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute the expected result type (including the layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // The source may not be fully contiguous. Compute the layout map.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
    } else {
      Value newOp =
          CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
                                  op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getResultType(), newOp);
    }

              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                        memref::DimOp, MemRefType>,

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}

FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
    return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  // Compute the target offset whose value is
  //   sourceOffset + sum_i(staticOffset_i * sourceStrides_i).
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  // Compute target strides whose values are sourceStrides_i * staticStrides_i.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}

MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {

  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&

  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {

  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {

  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  }

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {

  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

Value SubViewOp::getViewSource() { return getSource(); }
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes), but got "
           << op.getType();
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError(
               "expected result and source memory spaces to match, but got ")
           << op.getType();
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout), but "
              "got "
           << op.getType();
  }
  llvm_unreachable("unexpected subview verification result");
}
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify all properties of a shaped type: rank, element type and dimension
  // sizes. This takes potential rank reductions into account.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Strides must match, modulo dropped dimensions.
  if (!haveCompatibleStrides(expectedType, subViewType, unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes and strides do not run out-of-bounds with
  // respect to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(nonRankReducedType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);

FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // All offsets must be 0:
        return !intValue || intValue.value() != 0;

  // All strides must be 1:
        return !intValue || intValue.value() != 1;

  // All sizes must match the source shape:
      if (!intValue || *intValue != sourceShape[size.index()])
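/// Pattern to rewrite a subview op with memref.cast arguments when the cast
/// can fold into the consumer (see `canFoldIntoConsumerOp`). A sketch of the
/// rewrite:
///
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] : memref<?x?xf32> to ...
///
/// becomes a subview of %V directly, followed by a memref.cast of the result
/// back to the original subview type.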
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is constant, bail and let the constant folder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the cast source operand type to infer the result type and the current
    // SubViewOp source operand type to compute the dropped dimensions.
    MemRefType resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = SubViewOp::create(
        rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    // Infer a memref type without taking rank reductions into account.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);

    MemRefType nonReducedType = resTy;

    // Take the strides and offset from the non-rank-reduced type.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from the shape and strides.
    SmallVector<int64_t> targetShape, targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());

              SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)) when both subviews have the same sizes and the
  // second subview's offsets are all zero and strides are all one (i.e., the
  // second subview is a no-op).
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}

FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
void SubViewOp::inferStridedMetadataRanges(

  auto isUninitialized =

  // Bail early if any of the operand metadata is not ready.
  if (llvm::any_of(offsetOperands, isUninitialized))
    return;

  if (llvm::any_of(sizeOperands, isUninitialized))
    return;

  if (llvm::any_of(stridesOperands, isUninitialized))
    return;

      ranges[getSourceMutable().getOperandNumber()];

  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
    bool dropped = droppedDims.test(i);

      sizes.push_back(sizeOperands[i].getValue());

  setMetadata(getResult(),
              StridedMetadataRange::getRanked(std::move(offsets),
                                              std::move(sizes),
                                              std::move(strides)));
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() ==
         static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p << " : " << getIn().getType() << " to " << getType();
}
  MemRefType srcType, dstType;

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));

LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // First check for an identity permutation: it can be folded away if the
  // input and result types already match.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose ops into one by composing their
  // permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}

FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3649 void ViewOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
3650 setNameFn(getResult(),
"view");
3654 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3658 if (!baseType.getLayout().isIdentity())
3659 return emitError("unsupported map for base memref type ") << baseType;
3662 if (!viewType.getLayout().isIdentity())
3663 return emitError("unsupported map for result memref type ") << viewType;
3666 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3667 return emitError("different memory spaces specified for base memref "
3669 << baseType << " and view memref type " << viewType;
3672 unsigned numDynamicDims = viewType.getNumDynamicDims();
3673 if (getSizes().size() != numDynamicDims)
3674 return emitError("incorrect number of size operands for type ") << viewType;
3679 Value ViewOp::getViewSource() { return getSource(); }
3682 MemRefType sourceMemrefType = getSource().getType();
3683 MemRefType resultMemrefType = getResult().getType();
3685 if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3686 return getViewSource();
3696 LogicalResult matchAndRewrite(ViewOp viewOp,
3699 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3700 return matchPattern(operand, matchConstantIndex());
3705 auto memrefType = viewOp.getType();
3710 if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3712 assert(oldOffset == 0 && "Expected 0 offset");
3720 newShapeConstants.reserve(memrefType.getRank());
3722 unsigned dynamicDimPos = 0;
3723 unsigned rank = memrefType.getRank();
3724 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3725 int64_t dimSize = memrefType.getDimSize(dim);
3727 if (ShapedType::isStatic(dimSize)) {
3728 newShapeConstants.push_back(dimSize);
3731 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3732 if (auto constantIndexOp =
3733 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3735 newShapeConstants.push_back(constantIndexOp.value());
3738 newShapeConstants.push_back(dimSize);
3739 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3745 MemRefType newMemRefType =
3748 if (newMemRefType == memrefType)
3752 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3753 viewOp.getOperand(0), viewOp.getByteShift(),
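Hypothetical before/after for this pattern: when a size operand is defined by a constant, the dimension becomes static in the result type and the operand is dropped, with a memref.cast re-establishing the original result type for existing uses (a sketch; the replacement code itself is elided in this listing):

%n = arith.constant 16 : index
%v = memref.view %buf[%c0][%n] : memref<2048xi8> to memref<?x4xf32>
// becomes
%v1 = memref.view %buf[%c0][] : memref<2048xi8> to memref<16x4xf32>
%v2 = memref.cast %v1 : memref<16x4xf32> to memref<?x4xf32>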
3764 LogicalResult matchAndRewrite(ViewOp viewOp,
3766 Value memrefOperand = viewOp.getOperand(0);
3767 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3770 Value allocOperand = memrefCastOp.getOperand();
3775 viewOp.getByteShift(),
3785 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
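Sketch of the cast-folding pattern just registered (hypothetical IR): a memref.cast feeding the view's source can be bypassed, since the view only needs the underlying byte buffer:

%c = memref.cast %alloc : memref<2048xi8> to memref<?xi8>
%v = memref.view %c[%off][%n] : memref<?xi8> to memref<?x4xf32>
// rewrites to
%v = memref.view %alloc[%off][%n] : memref<2048xi8> to memref<?x4xf32>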
3788 FailureOr<std::optional<SmallVector<Value>>>
3789 ViewOp::bubbleDownCasts(OpBuilder &builder) {
3800 "expects the number of subscripts to be equal to memref rank");
3801 switch (getKind()) {
3802 case arith::AtomicRMWKind::addf:
3803 case arith::AtomicRMWKind::maximumf:
3804 case arith::AtomicRMWKind::minimumf:
3805 case arith::AtomicRMWKind::mulf:
3806 if (!llvm::isa<FloatType>(getValue().getType()))
3807 return emitOpError() << "with kind '"
3808 << arith::stringifyAtomicRMWKind(getKind())
3809 << "' expects a floating-point type";
3811 case arith::AtomicRMWKind::addi:
3812 case arith::AtomicRMWKind::maxs:
3813 case arith::AtomicRMWKind::maxu:
3814 case arith::AtomicRMWKind::mins:
3815 case arith::AtomicRMWKind::minu:
3816 case arith::AtomicRMWKind::muli:
3817 case arith::AtomicRMWKind::ori:
3818 case arith::AtomicRMWKind::xori:
3819 case arith::AtomicRMWKind::andi:
3820 if (!llvm::isa<IntegerType>(getValue().getType()))
3821 return emitOpError() << "with kind '"
3822 << arith::stringifyAtomicRMWKind(getKind())
3823 << "' expects an integer type";
3838 FailureOr<std::optional<SmallVector<Value>>>
3839 AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3848 #define GET_OP_CLASSES
3849 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static LogicalResult FoldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non-terminator op in a region.
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map whose keys are the elements of vals and whose values are the number of occurrences of each.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
IRValueT get() const
Return the current value being used by this operand.
This lattice value represents the integer range of an SSA value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.