#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
return arith::ConstantOp::materialize(builder, value, type, loc);
auto cast = operand.get().getDefiningOp<CastOp>();
if (cast && operand.get() != inner &&
    !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
  operand.set(cast.getOperand());
  folded = true;
}

return success(folded);
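// For example (names illustrative), a consumer of
//   %0 = memref.cast %arg : memref<4xf32> to memref<?xf32>
// is rewritten to consume %arg directly; unranked cast sources are skipped
// because most consumers require a ranked operand.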
if (auto memref = llvm::dyn_cast<MemRefType>(type))
  return RankedTensorType::get(memref.getShape(), memref.getElementType());
if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
  return UnrankedTensorType::get(memref.getElementType());
auto memrefType = llvm::cast<MemRefType>(value.getType());
if (memrefType.isDynamicDim(dim))
  return builder.createOrFold<memref::DimOp>(loc, value, dim);
auto memrefType = llvm::cast<MemRefType>(value.getType());
for (int64_t i = 0; i < memrefType.getRank(); ++i)
int64_t constValue = it.value();
if (!isDynamic(constValue))

    llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
std::optional<int64_t> maybeConstant =
LogicalResult hasStaticInformation =
    getStridesAndOffset(memrefType, strides, offset);
if (failed(hasStaticInformation))

LogicalResult hasStaticInformation =
    getStridesAndOffset(memrefType, strides, offset);
if (failed(hasStaticInformation))
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();
229 "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operand is a non-negative constant.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy the operand.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
    }

    // Create a new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
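// For example (an IR sketch; SSA names are illustrative):
//   %c4 = arith.constant 4 : index
//   %0 = memref.alloc(%c4, %n) : memref<?x?xf32>
// becomes an allocation of memref<4x?xf32> with only %n as a dynamic size,
// cast back to the original memref<?x?xf32> type for existing users.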
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))

results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
    context);
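// SimplifyDeadAlloc erases allocations whose only uses are deallocations or
// stores *into* the buffer, e.g. (sketch):
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
// all three ops are removed; a store only blocks the rewrite when the
// allocated value is the stored *value* rather than the destination.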
auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
MemRefType resultType = getType();

if (!sourceType.getLayout().isIdentity())
  return emitError("unsupported layout for source memref type ")

if (!resultType.getLayout().isIdentity())
  return emitError("unsupported layout for result memref type ")

if (sourceType.getMemorySpace() != resultType.getMemorySpace())
  return emitError("different memory spaces specified for source memref "
                   "type ")
         << sourceType << " and result memref type " << resultType;

if (sourceType.getElementType() != resultType.getElementType())
  return emitError("different element types specified for source memref "
                   "type ")
         << sourceType << " and result memref type " << resultType;

if (resultType.getNumDynamicDims() && !getDynamicResultSize())
  return emitError("missing dimension operand for result type ")
if (!resultType.getNumDynamicDims() && getDynamicResultSize())
  return emitError("unnecessary dimension operand for result type ")

results.add<SimplifyDeadAlloc<ReallocOp>>(context);
bool printBlockTerminators = false;

if (!getResults().empty()) {
  p << " -> (" << getResultTypes() << ")";
  printBlockTerminators = true;
}

    printBlockTerminators);

AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),

void AllocaScopeOp::getSuccessorRegions(
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

if (isa<SideEffects::AutomaticAllocationScopeResource>(
        effect->getResource()))

MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

if (isa<SideEffects::AutomaticAllocationScopeResource>(
        effect->getResource()))
bool hasPotentialAlloca =

if (hasPotentialAlloca) {
if (!lastParentWithoutScope ||

lastParentWithoutScope = lastParentWithoutScope->getParentOp();
if (!lastParentWithoutScope ||

Region *containingRegion = nullptr;
for (auto &r : lastParentWithoutScope->getRegions()) {
  if (r.isAncestor(op->getParentRegion())) {
    assert(containingRegion == nullptr &&
           "only one region can contain the op");
    containingRegion = &r;
  }
}
assert(containingRegion && "op must be contained in a region");

return containingRegion->isAncestor(v.getParentRegion());

toHoist.push_back(alloc);

for (auto *op : toHoist) {
  auto *cloned = rewriter.clone(*op);
  rewriter.replaceOp(op, cloned->getResults());
}
if (!llvm::isPowerOf2_32(getAlignment()))
  return emitOpError("alignment must be power of 2");

setNameFn(getResult(), "cast");
MemRefType sourceType =
    llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

// Requires ranked MemRefType.
if (!sourceType || !resultType)

// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())

// Requires same rank.
if (sourceType.getRank() != resultType.getRank())

int64_t sourceOffset, resultOffset;

// If cast is towards more static sizes along any dimension, don't fold.
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
  auto ss = std::get<0>(it), st = std::get<1>(it);
  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
    return false;
}

// If cast is towards a more static offset, don't fold.
if (sourceOffset != resultOffset)
  if (ShapedType::isDynamic(sourceOffset) &&
      !ShapedType::isDynamic(resultOffset))
    return false;

// If cast is towards more static strides along any dimension, don't fold.
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
  auto ss = std::get<0>(it), st = std::get<1>(it);
  if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
    return false;
}
if (inputs.size() != 1 || outputs.size() != 1)
  return false;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<MemRefType>(a);
auto bT = llvm::dyn_cast<MemRefType>(b);

auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

if (aT && bT) {
  if (aT.getElementType() != bT.getElementType())
    return false;
  if (aT.getLayout() != bT.getLayout()) {
    int64_t aOffset, bOffset;
        aStrides.size() != bStrides.size())
      return false;

    // Strides along a dimension/offset are compatible if the value in the
    // source memref is static and the value in the target memref is the same
    // or dynamic.
    auto checkCompatible = [](int64_t a, int64_t b) {
      return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
    };
    if (!checkCompatible(aOffset, bOffset))
      return false;
    for (const auto &aStride : enumerate(aStrides))
      if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
        return false;
  }
  if (aT.getMemorySpace() != bT.getMemorySpace())
    return false;

  // They must have the same rank, and any specified dimensions must match.
  if (aT.getRank() != bT.getRank())
    return false;

  for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
    int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
    if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
        aDim != bDim)
      return false;
  }
  return true;
}

auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
if (aEltType != bEltType)
  return false;

auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
return aMemSpace == bMemSpace;
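// Compatibility is agnostic to the static-vs-dynamic direction, e.g.
// (sketches):
//   memref.cast %m : memref<4x?xf32> to memref<?x8xf32>   // ranked<->ranked
//   memref.cast %n : memref<4xf32> to memref<*xf32>       // ranked<->unranked
// Ranked-to-ranked casts require equal rank, element type, memory space, and
// pairwise-compatible shapes/offsets/strides; ranked<->unranked casts only
// require matching element type and memory space.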
LogicalResult matchAndRewrite(CopyOp copyOp,
                              PatternRewriter &rewriter) const override {
  bool modified = false;

  // Check source.
  if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
    auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
    auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

    if (fromType && toType) {
      if (fromType.getShape() == toType.getShape() &&
          fromType.getElementType() == toType.getElementType()) {
        copyOp.getSourceMutable().assign(castOp.getSource());
      }
    }
  }

  // Check target.
  if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
    auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
    auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

    if (fromType && toType) {
      if (fromType.getShape() == toType.getShape() &&
          fromType.getElementType() == toType.getElementType()) {
        copyOp.getTargetMutable().assign(castOp.getSource());
      }
    }
  }

  return success(modified);
LogicalResult matchAndRewrite(CopyOp copyOp,
                              PatternRewriter &rewriter) const override {
  if (copyOp.getSource() != copyOp.getTarget())

LogicalResult matchAndRewrite(CopyOp copyOp,
                              PatternRewriter &rewriter) const override {
  if (isEmptyMemRef(copyOp.getSource().getType()) ||
      isEmptyMemRef(copyOp.getTarget().getType())) {

results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
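// The registered copy canonicalizations, on sketches like:
//   memref.copy %a, %a : memref<8xf32> to memref<8xf32>   // self-copy: erased
//   memref.copy %b, %c : memref<0xf32> to memref<0xf32>   // empty: erased
// and FoldCopyOfCast strips memref.cast producers whose shape and element
// type already match across the cast.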
LogicalResult CopyOp::fold(FoldAdaptor adaptor,

    operand.set(castOp.getOperand());

  return success(folded);

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,

  setNameFn(getResult(), "dim");

Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {

auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
if (!rankedSourceType)

static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())

    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())

  int64_t originalOffset, candidateOffset;

  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is accounted for in the candidate type, so it
      // is not dropped.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This cannot happen: the reduced type cannot contain a stride that the
      // original type did not have.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
MemRefType sourceType = getSourceType();
MemRefType resultType = getType();
FailureOr<llvm::SmallBitVector> unusedDims =
    computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
assert(succeeded(unusedDims) && "unable to find unused dims of subview");
// All forms of folding require a known index.
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());

// Out of bound indices produce undefined behavior but are still valid IR.
// Don't choke on them.
int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= memrefType.getRank())

// Fold if the shape extent along the given index is known.
if (!memrefType.isDynamicDim(index.getInt())) {
  Builder builder(getContext());
  return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
}

unsigned unsignedIndex = index.getValue().getZExtValue();

// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
Operation *definingOp = getSource().getDefiningOp();

if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
  return *(alloc.getDynamicSizes().begin() +
           memrefType.getDynamicDimIndex(unsignedIndex));

if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
  return *(alloca.getDynamicSizes().begin() +
           memrefType.getDynamicDimIndex(unsignedIndex));

if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
  return *(view.getDynamicSizes().begin() +
           memrefType.getDynamicDimIndex(unsignedIndex));

if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
  llvm::SmallBitVector unusedDims = subview.getDroppedDims();
  unsigned resultIndex = 0;
  unsigned sourceRank = subview.getSourceType().getRank();
  unsigned sourceIndex = 0;
  for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
    if (unusedDims.test(i))
      continue;
    if (resultIndex == unsignedIndex) {
      sourceIndex = i;
      break;
    }
    resultIndex++;
  }
  assert(subview.isDynamicSize(sourceIndex) &&
         "expected dynamic subview size");
  return subview.getDynamicSize(sourceIndex);
}

if (auto sizeInterface =
        dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
  assert(sizeInterface.isDynamicSize(unsignedIndex) &&
         "Expected dynamic subview size");
  return sizeInterface.getDynamicSize(unsignedIndex);
}
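// For example (sketch), folding a dim of a dynamic allocation:
//   %alloc = memref.alloc(%s0, %s1) : memref<?x?xf32>
//   %d = memref.dim %alloc, %c1 : memref<?x?xf32>
// folds %d to %s1, the size operand backing dynamic dimension 1; a static
// dimension folds to an index constant instead.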
LogicalResult matchAndRewrite(DimOp dim,
                              PatternRewriter &rewriter) const override {
  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    return rewriter.notifyMatchFailure(
        dim, "Dim op is not defined by a reshape op.");

  if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
    if (auto *definingOp = dim.getIndex().getDefiningOp()) {
      if (reshape->isBeforeInBlock(definingOp)) {
        return rewriter.notifyMatchFailure(
            dim,
            "dim.getIndex is not defined before reshape in the same block.");
      }
    }
  } else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
    return rewriter.notifyMatchFailure(
        dim, "dim.getIndex does not dominate reshape.");
  }

  Value load =
      rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
  if (load.getType() != dim.getType())
    load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);

results.add<DimOfMemRefReshape>(context);
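// DimOfMemRefReshape rewrites (sketch):
//   %r = memref.reshape %src(%shape)
//        : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i : memref<?x?xf32>
// into a load of the shape operand, %d = memref.load %shape[%i], inserting an
// arith.index_cast when the loaded element type is not index.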
                       Value elementsPerStride) {

p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
  << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
  << ", " << getTagMemRef() << '[' << getTagIndices() << ']';

p << ", " << getStride() << ", " << getNumElementsPerStride();

p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
  << ", " << getTagMemRef().getType();
bool isStrided = strideInfo.size() == 2;
if (!strideInfo.empty() && !isStrided) {
  return parser.emitError(parser.getNameLoc(),
                          "expected two stride related operands");
}

if (types.size() != 3)
unsigned numOperands = getNumOperands();

if (numOperands < 4)
  return emitOpError("expected at least 4 operands");

if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
  return emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4)
  return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                       << " operands";
if (!getSrcIndices().empty() &&
    !llvm::all_of(getSrcIndices().getTypes(),
                  [](Type t) { return t.isIndex(); }))
  return emitOpError("expected source indices to be of index type");

if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
  return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
  return emitOpError() << "expected at least " << numExpectedOperands
                       << " operands";
if (!getDstIndices().empty() &&
    !llvm::all_of(getDstIndices().getTypes(),
                  [](Type t) { return t.isIndex(); }))
  return emitOpError("expected destination indices to be of index type");

  return emitOpError("expected num elements to be of index type");

if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
  return emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
  return emitOpError() << "expected at least " << numExpectedOperands
                       << " operands";
if (!getTagIndices().empty() &&
    !llvm::all_of(getTagIndices().getTypes(),
                  [](Type t) { return t.isIndex(); }))
  return emitOpError("expected tag indices to be of index type");

// Optional stride-related operands must be either both present or both
// absent.
if (numOperands != numExpectedOperands &&
    numOperands != numExpectedOperands + 2)
  return emitOpError("incorrect number of operands");

if (!getStride().getType().isIndex() ||
    !getNumElementsPerStride().getType().isIndex())
  return emitOpError(
      "expected stride and num elements per stride to be of type index");
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,

unsigned numTagIndices = getTagIndices().size();
unsigned tagMemRefRank = getTagMemRefRank();
if (numTagIndices != tagMemRefRank)
  return emitOpError() << "expected tagIndices to have the same number of "
                          "elements as the tagMemRef rank, expected "
                       << tagMemRefRank << ", but got " << numTagIndices;
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType = MemRefType::get({}, sourceType.getElementType(),
                                    MemRefLayoutAttrInterface{},
                                    sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");

  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         "expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {

    assert(maybeConstant.template is<Attribute>() &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
                 .getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,

      getConstifiedMixedOffset());
      getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
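// Rather than folding the op away, this fold materializes arith.constant ops
// for any offset/size/stride result whose value is statically known from the
// source type and redirects those uses. For example (sketch), with a source
// of type memref<4x5xf32> the two size results become constants 4 and 5 and
// the stride results become 5 and 1.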
ExtractStridedMetadataOp::getConstifiedMixedStrides() {

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {

if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
  Type elementType = memrefType.getElementType();
auto &body = getRegion();
if (body.getNumArguments() != 1)
  return emitOpError("expected single number of entry block arguments");

if (getResult().getType() != body.getArgument(0).getType())
  return emitOpError("expected block argument of the same type result type");

    "body of 'memref.generic_atomic_rmw' should contain "
    "only operations with no side effects");
p << ' ' << getMemref() << "[" << getIndices()
  << "] : " << getMemref().getType() << ' ';

Type parentType = (*this)->getParentOp()->getResultTypes().front();
Type resultType = getResult().getType();
if (parentType != resultType)
  return emitOpError() << "types mismatch between yield op: " << resultType
                       << " and its parent: " << parentType;
if (!op.isExternal()) {

  if (op.isUninitialized())
    p << "uninitialized";

auto memrefType = llvm::dyn_cast<MemRefType>(type);
if (!memrefType || !memrefType.hasStaticShape())
  return parser.emitError(parser.getNameLoc())
         << "type should be static shaped memref, but got " << type;

if (!llvm::isa<ElementsAttr>(initialValue))
  return parser.emitError(parser.getNameLoc())
         << "initial value should be a unit or elements attribute";
auto memrefType = llvm::dyn_cast<MemRefType>(getType());
if (!memrefType || !memrefType.hasStaticShape())
  return emitOpError("type should be static shaped memref, but got ")

// Verify that the initial value, if present, is either a unit attribute or
// an elements attribute.
if (getInitialValue().has_value()) {
  Attribute initValue = getInitialValue().value();
  if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
    return emitOpError("initial value should be a unit or elements "
                       "attribute, but got ")

  // Check that the type of the initial value is compatible with the type of
  // the global variable.
  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
    Type initType = elementsAttr.getType();
    Type tensorType = getTensorTypeFromMemRefType(memrefType);
    if (initType != tensorType)
      return emitOpError("initial value expected to be of type ")
             << tensorType << ", but was of type " << initType;
  }
}

if (std::optional<uint64_t> alignAttr = getAlignment()) {
  uint64_t alignment = *alignAttr;

  if (!llvm::isPowerOf2_64(alignment))
    return emitError() << "alignment attribute value " << alignment
                       << " is not a power of 2";
}

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return emitOpError("'")
         << getName() << "' does not reference a valid global memref";

Type resultType = getResult().getType();
if (global.getType() != resultType)
  return emitOpError("result type ")
         << resultType << " does not match type " << global.getType()
         << " of the global memref @" << getName();

  return emitOpError("incorrect number of indices for load, expected ")
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

if (inputs.size() != 1 || outputs.size() != 1)
  return false;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<MemRefType>(a);
auto bT = llvm::dyn_cast<MemRefType>(b);

auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT.getElementType() != bT.getElementType())
    return false;
  if (aT.getLayout() != bT.getLayout())
    return false;
  if (aT.getShape() != bT.getShape())
    return false;

  return uaT.getElementType() == ubT.getElementType();
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
  // t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
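// For example (sketch), a chain of memory-space casts collapses:
//   %0 = memref.memory_space_cast %src : memref<4xf32, 1> to memref<4xf32, 2>
//   %1 = memref.memory_space_cast %0 : memref<4xf32, 2> to memref<4xf32, 3>
// %1 is rewritten to cast directly from %src.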
p << " " << getMemref() << '[';
p << ']' << ", " << (getIsWrite() ? "write" : "read");
p << ", locality<" << getLocalityHint();
p << ">, " << (getIsDataCache() ? "data" : "instr");
    (*this)->getAttrs(),
    {"localityHint", "isWrite", "isDataCache"});

IntegerAttr localityHint;
StringRef readOrWrite, cacheType;

if (readOrWrite != "read" && readOrWrite != "write")
  return parser.emitError(parser.getNameLoc(),
                          "rw specifier has to be 'read' or 'write'");
result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),

if (cacheType != "data" && cacheType != "instr")
  return parser.emitError(parser.getNameLoc(),
                          "cache type has to be 'data' or 'instr'");
result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),

  return emitOpError("too few indices");

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

    MemRefType resultType, Value source,
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

auto sourceType = cast<BaseMemRefType>(source.getType());
auto stridedLayout = StridedLayoutAttr::get(
    b.getContext(), staticOffsets.front(), staticStrides);
auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                  stridedLayout, sourceType.getMemorySpace());
build(b, result, resultType, source, offset, sizes, strides, attrs);

    MemRefType resultType, Value source,
    llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
    llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
    strideValues, attrs);

    MemRefType resultType, Value source, Value offset,
build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
auto resultType = llvm::cast<MemRefType>(getType());
if (srcType.getMemorySpace() != resultType.getMemorySpace())
  return emitError("different memory spaces specified for source type ")
         << srcType << " and result memref type " << resultType;
if (srcType.getElementType() != resultType.getElementType())
  return emitError("different element types specified for source type ")
         << srcType << " and result memref type " << resultType;

for (auto [idx, resultSize, expectedSize] :
  if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
    return emitError("expected result type with size = ")
           << (ShapedType::isDynamic(expectedSize)
                   ? std::string("dynamic")
                   : std::to_string(expectedSize))
           << " instead of " << resultSize << " in dim = " << idx;

int64_t resultOffset;
  return emitError("expected result type to have strided layout but found ")

int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
  return emitError("expected result type with offset = ")
         << (ShapedType::isDynamic(expectedOffset)
                 ? std::string("dynamic")
                 : std::to_string(expectedOffset))
         << " instead of " << resultOffset;

for (auto [idx, resultStride, expectedStride] :
  if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
    return emitError("expected result type with stride = ")
           << (ShapedType::isDynamic(expectedStride)
                   ? std::string("dynamic")
                   : std::to_string(expectedStride))
           << " instead of " << resultStride << " in dim = " << idx;
Value src = getSource();
auto getPrevSrc = [&]() -> Value {
  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    return prev.getSource();

  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    return prev.getSource();

  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if the subview
  // offsets are all zero.
  if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
        return isConstantIntValue(val, 0);
      }))
    return prev.getSource();
};

if (auto prevSrc = getPrevSrc()) {
  getSourceMutable().assign(prevSrc);
}

    ShapedType::isDynamic);

    ShapedType::isDynamic);

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
      ShapedType::isDynamic);
struct ReinterpretCastOpExtractStridedMetadataFolder

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)

        extractStridedMetadata.getConstifiedMixedStrides();
        op.getConstifiedMixedStrides();
    if (extractStridesOfr.size() != reinterpretStridesOfr.size())

    unsigned rank = op.getType().getRank();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
    }

    assert(extractStridedMetadata.getSizes().size() ==
               op.getMixedSizes().size() &&
           "Strides and sizes rank must match");
        extractStridedMetadata.getConstifiedMixedSizes();
        op.getConstifiedMixedSizes();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
    }
    assert(op.getMixedOffsets().size() == 1 &&
           "reinterpret_cast with more than one offset should have been "
           "rejected by the verifier");
        extractStridedMetadata.getConstifiedMixedOffset();
    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
    if (extractOffsetOfr != reinterpretOffsetOfr)

    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());

        extractStridedMetadata.getSource());

results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

reifiedResultShapes = {
    getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
               << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // A collapsed dim must be dynamic if and only if its group has a dynamic
    // dim.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
             << ") must be dynamic if and only if reassociation group is "

    // If all dims in the group are static, the collapsed dim must be the
    // product of the group.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    for (int64_t d : expandedShape)
          "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim

      getReassociationIndices());

      getReassociationIndices());
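// For example (sketch), reassociation [[0, 1], [2]] relates an expanded
// memref<3x4x5xf32> to a collapsed memref<12x5xf32>: the indices must be
// contiguous (0, 1, 2), each fully static group must multiply out to the
// collapsed size (3 * 4 == 12), and a collapsed dim is dynamic iff its group
// contains a dynamic dim.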
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {

  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // Iterate over all reassociation groups from the back, computing the
  // strides of the expanded dims.
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
        srcType.getMemorySpace());
  }

  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
  std::optional<SmallVector<OpFoldResult>> outputShape =

  return *outputShape;

auto [staticOutputShape, dynamicOutputShape] =
build(builder, result, llvm::cast<MemRefType>(resultType), src,
      dynamicOutputShape, staticOutputShape);

MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
    builder, result.location, memrefResultTy, reassociation, inputShape);

assert(succeeded(outputShape) && "unable to infer output shape");
build(builder, result, memrefResultTy, src, reassociation, *outputShape);

auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
    ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);

assert(succeeded(resultType) && "could not compute layout");
build(builder, result, *resultType, src, reassociation);

auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
    ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);

assert(succeeded(resultType) && "could not compute layout");
build(builder, result, *resultType, src, reassociation, outputShape);
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();

if (srcType.getRank() > resultType.getRank()) {
  auto r0 = srcType.getRank();
  auto r1 = resultType.getRank();
  return emitOpError("has source rank ")
         << r0 << " and result rank " << r1 << ". This is not an expansion ("
         << r0 << " > " << r1 << ").";
}

    resultType.getShape(),
    getReassociationIndices(),

FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
    srcType, resultType.getShape(), getReassociationIndices());
if (failed(expectedResultType))
  return emitOpError("invalid source layout map");

if (*expectedResultType != resultType)
  return emitOpError("expected expanded type to be ")
         << *expectedResultType << " but found " << resultType;

if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
  return emitOpError("expected number of static shape bounds to be equal to "
                     "the output rank (")
         << resultType.getRank() << ") but found "
         << getStaticOutputShape().size() << " inputs instead";

if ((int64_t)getOutputShape().size() !=
    llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
  return emitOpError("mismatch in dynamic dims in output_shape and "
                     "static_output_shape: static_output_shape has ")
         << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
         << " dynamic dims while output_shape has " << getOutputShape().size()

  if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
    return emitOpError("invalid output shape provided at pos ") << pos;
  }
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {

  auto srcShape = srcType.getShape();

  resultStrides.reserve(reassociation.size());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
      // so the corresponding stride may have to be skipped.
      resultStrides.push_back(ShapedType::kDynamic);
    }

  unsigned resultStrideIndex = resultStrides.size() - 1;
    for (int64_t idx : llvm::reverse(trailingReassocs)) {

      if (strict && (stride.saturated || srcStride.saturated))

      if (srcShape[idx - 1] == 1)

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
bool CollapseShapeOp::isGuaranteedCollapsible(

  if (srcType.getLayout().isIdentity())

MemRefType CollapseShapeOp::computeCollapsedType(

  resultShape.reserve(reassociation.size());
    for (int64_t srcDim : group)

    resultShape.push_back(groupSize.asInteger());

  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
        srcType.getMemorySpace());
  }

  FailureOr<StridedLayoutAttr> computedLayout =
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());

auto srcType = llvm::cast<MemRefType>(src.getType());
MemRefType resultType =
    CollapseShapeOp::computeCollapsedType(srcType, reassociation);

build(b, result, resultType, src, attrs);
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();

if (srcType.getRank() < resultType.getRank()) {
  auto r0 = srcType.getRank();
  auto r1 = resultType.getRank();
  return emitOpError("has source rank ")
         << r0 << " and result rank " << r1 << ". This is not a collapse ("
         << r0 << " < " << r1 << ").";
}

    srcType.getShape(), getReassociationIndices(),

MemRefType expectedResultType;
if (srcType.getLayout().isIdentity()) {
  // If the source is contiguous (i.e., no layout map specified), so is the
  // result.
  MemRefLayoutAttrInterface layout;
  expectedResultType =
      MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                      srcType.getMemorySpace());
} else {
  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
    return emitOpError(
        "invalid source layout map or collapsing non-contiguous dims");
  expectedResultType =
      MemRefType::get(resultType.getShape(), srcType.getElementType(),
                      *computedLayout, srcType.getMemorySpace());
}

if (expectedResultType != resultType)
  return emitOpError("expected collapsed type to be ")
         << expectedResultType << " but found " << resultType;
auto cast = op.getOperand().getDefiningOp<CastOp>();

Type newResultType = CollapseShapeOp::computeCollapsedType(
    llvm::cast<MemRefType>(cast.getOperand().getType()),
    op.getReassociationIndices());

if (newResultType == op.getResultType()) {
    op, [&]() { op.getSrcMutable().assign(cast.getSource()); });

    op->getLoc(), cast.getSource(), op.getReassociationIndices());

    memref::DimOp, MemRefType>,

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
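// foldReshapeOp cancels inverse pairs, e.g. (sketch):
//   %1 = memref.collapse_shape %0 [[0, 1]]
//        : memref<4x5xf32> into memref<20xf32>
//   %2 = memref.expand_shape %1 [[0, 1]] output_shape [4, 5]
//        : memref<20xf32> into memref<4x5xf32>
// folds %2 to %0 when the round-trip type matches.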
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

Type operandType = getSource().getType();
Type resultType = getResult().getType();

Type operandElementType =
    llvm::cast<ShapedType>(operandType).getElementType();
Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
if (operandElementType != resultElementType)
  return emitOpError("element types of source and destination memref "
                     "types should be the same");

if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
  if (!operandMemRefType.getLayout().isIdentity())
    return emitOpError("source memref type should have identity affine map");

auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
if (resultMemRefType) {
  if (!resultMemRefType.getLayout().isIdentity())
    return emitOpError("result memref type should have identity affine map");
  if (shapeSize == ShapedType::kDynamic)
    return emitOpError("cannot use shape operand with dynamic length to "
                       "reshape to statically-ranked memref type");
  if (shapeSize != resultMemRefType.getRank())
    return emitOpError(
        "length of shape operand differs from the result's memref rank");
}

  return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,

  unsigned rank = sourceMemRefType.getRank();

  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);

  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,

  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
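// Worked example (sketch): a subview with offsets [2, 4], sizes [4, 4], and
// strides [2, 1] of memref<8x16xf32> (source strides [16, 1], offset 0)
// yields
//   targetOffset = 0 + 2*16 + 4*1 = 36
//   targetStrides = [16*2, 1*1] = [32, 1]
// i.e. memref<4x4xf32, strided<[32, 1], offset: 36>>.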
    MemRefType sourceRankedTensorType,

  auto inferredType = llvm::cast<MemRefType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&

  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }

                         inferredLayout.getOffset(),
                         rankReducedStrides),
      inferredType.getMemorySpace());

    MemRefType sourceRankedTensorType,

  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
    MemRefType resultType, Value source,

  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());

  resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
      sourceMemRefType, staticOffsets, staticSizes, staticStrides));

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

    llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
    llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
    llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

    MemRefType resultType, Value source,
    llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
    llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
    llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
build(b, result, resultType, source, offsetValues, sizeValues, strideValues,

build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

Value SubViewOp::getViewSource() { return getSource(); }
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  if (failed(res1) || failed(res2))
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (t1Strides[i] != t2Strides[j])
auto memrefType = llvm::cast<ShapedType>(expectedType);

  return op->emitError("expected result rank to be smaller or equal to ")
         << "the source rank. ";

  return op->emitError("expected result type to be ")
         << expectedType
         << " or a rank-reduced version. (mismatch of result sizes) ";

  return op->emitError("expected result element type to be ")
         << memrefType.getElementType();

  return op->emitError("expected result and source memory spaces to match.");

  return op->emitError("expected result type to be ")
         << expectedType
         << " or a rank-reduced version. (mismatch of result layout) ";

llvm_unreachable("unexpected subview verification result");
MemRefType baseType = getSourceType();
MemRefType subViewType = getType();

// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
  return emitError("different memory spaces specified for base memref "
                   "type ")
         << baseType << " and subview memref type " << subViewType;

  return emitError("base type ") << baseType << " is not strided";

// Compute the expected result type, assuming that there are no rank
// reductions.
auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
    baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));

    expectedType, subViewType);

if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    *this, expectedType);

    *this, expectedType);

if (failed(unusedDims))
    *this, expectedType);

    *this, expectedType);
return os << "range " << range.offset << ":" << range.size << ":"
          << range.stride;
std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");

unsigned rank = ranks[0];

for (unsigned idx = 0; idx < rank; ++idx) {
      op.isDynamicOffset(idx)
          ? op.getDynamicOffset(idx)
      op.isDynamicSize(idx)
          ? op.getDynamicSize(idx)
      op.isDynamicStride(idx)
          ? op.getDynamicStride(idx)
  res.emplace_back(Range{offset, size, stride});
}
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides));
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());

  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

                         layout.getOffset(), strides),
      nonRankReducedType.getMemorySpace());
auto memrefType = llvm::cast<MemRefType>(memref.getType());
unsigned rank = memrefType.getRank();

    llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
        targetShape, memrefType, offsets, sizes, strides));
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,

auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
assert(sourceMemrefType && "not a ranked memref type");
auto sourceShape = sourceMemrefType.getShape();
if (sourceShape.equals(desiredShape))
auto maybeRankReductionMask =
if (!maybeRankReductionMask)
if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())

auto mixedOffsets = subViewOp.getMixedOffsets();
auto mixedSizes = subViewOp.getMixedSizes();
auto mixedStrides = subViewOp.getMixedStrides();

// For the subview to be a no-op, all offsets must be zero...
  return !intValue || intValue.value() != 0;

// ...all strides must be one...
  return !intValue || intValue.value() != 1;

// ...and all sizes must match the (static) source shape.
  if (!intValue || *intValue != sourceShape[size.index()])
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();

        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());

        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
LogicalResult matchAndRewrite(SubViewOp subViewOp,
                              PatternRewriter &rewriter) const override {

  if (subViewOp.getSourceType() == subViewOp.getType()) {
    rewriter.replaceOp(subViewOp, subViewOp.getSource());

      subViewOp.getSource());
auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
                                        mixedSizes, mixedStrides);

MemRefType nonReducedType = cast<MemRefType>(resTy);

// Directly return the non-rank reduced type if there are no dropped dims.
llvm::SmallBitVector droppedDims = op.getDroppedDims();
if (droppedDims.none())
  return nonReducedType;

// Take the strides and offset from the non-rank reduced type.
for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
  if (droppedDims.test(i))
    continue;
  targetStrides.push_back(nonReducedStrides[i]);
  targetShape.push_back(nonReducedType.getDimSize(i));
}

        offset, targetStrides),
    nonReducedType.getMemorySpace());

    SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
MemRefType sourceMemrefType = getSource().getType();
MemRefType resultMemrefType = getResult().getType();
auto resultLayout =
    dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

if (resultMemrefType == sourceMemrefType &&
    resultMemrefType.hasStaticShape() &&
    (!resultLayout || resultLayout.hasStaticLayout())) {
  return getViewSource();
}

// Fold subview(subview(x)) to subview(x) if the result types match.
if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
  auto srcSizes = srcSubview.getMixedSizes();
  auto offsets = getMixedOffsets();
  bool allOffsetsZero = llvm::all_of(
  auto strides = getMixedStrides();
  bool allStridesOne = llvm::all_of(
  bool allSizesSame = llvm::equal(sizes, srcSizes);
  if (allOffsetsZero && allStridesOne && allSizesSame &&
      resultMemrefType == sourceMemrefType)
    return getViewSource();
}
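// For example (sketch), a full-extent, stride-1, offset-0 subview whose
// result type equals the source type folds away:
//   %0 = memref.subview %src[0, 0] [8, 16] [1, 1]
//        : memref<8x16xf32> to memref<8x16xf32>
// %0 folds to %src.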
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

                        AffineMapAttr permutation,
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);

p << " " << getIn() << " " << getPermutation();
p << " : " << getIn().getType() << " to " << getType();
MemRefType srcType, dstType;

result.addAttribute(TransposeOp::getPermutationAttrStrName(),

  return emitOpError("expected a permutation map");
if (getPermutation().getNumDims() != getIn().getType().getRank())
  return emitOpError("expected a permutation map of same rank as the input");

auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto resultType = llvm::cast<MemRefType>(getType());

  return emitOpError("result type ")
         << resultType
         << " is not equivalent to the canonical transposed input type "
         << canonicalResultType;

if (getPermutation().isIdentity() && getType() == getIn().getType())

// transpose(transpose(x)) -> transpose(x) with the composed permutation.
if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
  AffineMap composedPermutation =
      getPermutation().compose(otherTransposeOp.getPermutation());
  getInMutable().assign(otherTransposeOp.getIn());
  setPermutation(composedPermutation);
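// For example (sketch), two stacked transposes fold into one:
//   %0 = memref.transpose %m (i, j) -> (j, i) : ...
//   %1 = memref.transpose %0 (i, j) -> (j, i) : ...
// the composed permutation is the identity, and an identity transpose with
// matching types folds to its input.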
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());

if (!baseType.getLayout().isIdentity())
  return emitError("unsupported map for base memref type ") << baseType;

if (!viewType.getLayout().isIdentity())
  return emitError("unsupported map for result memref type ") << viewType;

if (baseType.getMemorySpace() != viewType.getMemorySpace())
  return emitError("different memory spaces specified for base memref "
                   "type ")
         << baseType << " and view memref type " << viewType;

unsigned numDynamicDims = viewType.getNumDynamicDims();
if (getSizes().size() != numDynamicDims)
  return emitError("incorrect number of size operands for type ") << viewType;

Value ViewOp::getViewSource() { return getSource(); }
LogicalResult matchAndRewrite(ViewOp viewOp,
                              PatternRewriter &rewriter) const override {
  // Return if none of the operands are constants.
  if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
        return matchPattern(operand, matchConstantIndex());
      }))

  auto memrefType = viewOp.getType();

  assert(oldOffset == 0 && "Expected 0 offset");

  newShapeConstants.reserve(memrefType.getRank());

  unsigned dynamicDimPos = 0;
  unsigned rank = memrefType.getRank();
  for (unsigned dim = 0, e = rank; dim < e; ++dim) {
    int64_t dimSize = memrefType.getDimSize(dim);
    // If this is already a static dimension, keep it.
    if (!ShapedType::isDynamic(dimSize)) {
      newShapeConstants.push_back(dimSize);
      continue;
    }
    auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
    if (auto constantIndexOp =
            dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
      // Dynamic shape dimension will be folded.
      newShapeConstants.push_back(constantIndexOp.value());
    } else {
      // Dynamic shape dimension not folded; copy the operand.
      newShapeConstants.push_back(dimSize);
      newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
    }
  }

  // Create a new memref type with the folded shape and an alias of the old
  // view op if the type has not changed.
  MemRefType newMemRefType =
  if (newMemRefType == memrefType)

  auto newViewOp = rewriter.create<ViewOp>(
      viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
      viewOp.getByteShift(), newOperands);
LogicalResult matchAndRewrite(ViewOp viewOp,
                              PatternRewriter &rewriter) const override {
  Value memrefOperand = viewOp.getOperand(0);
  CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();

  Value allocOperand = memrefCastOp.getOperand();

      viewOp.getByteShift(),

results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3573 "expects the number of subscripts to be equal to memref rank");
3574 switch (getKind()) {
3575 case arith::AtomicRMWKind::addf:
3576 case arith::AtomicRMWKind::maximumf:
3577 case arith::AtomicRMWKind::minimumf:
3578 case arith::AtomicRMWKind::mulf:
3579 if (!llvm::isa<FloatType>(getValue().
getType()))
3580 return emitOpError() <<
"with kind '"
3581 << arith::stringifyAtomicRMWKind(getKind())
3582 <<
"' expects a floating-point type";
3584 case arith::AtomicRMWKind::addi:
3585 case arith::AtomicRMWKind::maxs:
3586 case arith::AtomicRMWKind::maxu:
3587 case arith::AtomicRMWKind::mins:
3588 case arith::AtomicRMWKind::minu:
3589 case arith::AtomicRMWKind::muli:
3590 case arith::AtomicRMWKind::ori:
3591 case arith::AtomicRMWKind::andi:
3592 if (!llvm::isa<IntegerType>(getValue().
getType()))
3593 return emitOpError() <<
"with kind '"
3594 << arith::stringifyAtomicRMWKind(getKind())
3595 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.