23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
34 return arith::ConstantOp::materialize(builder, value, type, loc);
47 auto cast = operand.get().getDefiningOp<CastOp>();
48 if (cast && operand.get() != inner &&
49 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50 operand.set(cast.getOperand());
54 return success(folded);
60 if (auto memref = llvm::dyn_cast<MemRefType>(type))
62 if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
69 auto memrefType = llvm::cast<MemRefType>(value.getType());
71 if (memrefType.isDynamicDim(dim))
72 return builder.createOrFold<memref::DimOp>(loc, value, dim);
79 auto memrefType = llvm::cast<MemRefType>(value.getType());
81 for (int64_t i = 0; i < memrefType.getRank(); ++i)
123 int64_t constValue = it.value();
124 if (!isDynamic(constValue))
143 llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
146 std::optional<int64_t> maybeConstant =
166 LogicalResult hasStaticInformation =
168 if (failed(hasStaticInformation))
179 LogicalResult hasStaticInformation =
181 if (failed(hasStaticInformation))
190 void AllocOp::getAsmResultNames(
192 setNameFn(getResult(), "alloc");
195 void AllocaOp::getAsmResultNames(
197 setNameFn(getResult(), "alloca");
200 template <typename AllocLikeOp>
202 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203 "applies to only alloc or alloca");
208 if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
209     memRefType.getNumDynamicDims())
210 return op.emitOpError("dimension operand count does not equal memref "
211                       "dynamic dimension count");
213 unsigned numSymbols = 0;
214 if (!memRefType.getLayout().isIdentity())
215 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
216 if (op.getSymbolOperands().size() != numSymbols)
217 return op.emitOpError("symbol operand count does not equal memref symbol "
219        << numSymbols << ", got " << op.getSymbolOperands().size();
230 "requires an ancestor op with AutomaticAllocationScope trait");
237 template <typename AllocLikeOp>
241 LogicalResult matchAndRewrite(AllocLikeOp alloc,
245 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
247 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
249 return constSizeArg.isNonNegative();
253 auto memrefType = alloc.getType();
258 newShapeConstants.reserve(memrefType.getRank());
261 unsigned dynamicDimPos = 0;
262 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
263 int64_t dimSize = memrefType.getDimSize(dim);
265 if (!ShapedType::isDynamic(dimSize)) {
266 newShapeConstants.push_back(dimSize);
269 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
272 constSizeArg.isNonNegative()) {
274 newShapeConstants.push_back(constSizeArg.getZExtValue());
277 newShapeConstants.push_back(ShapedType::kDynamic);
278 dynamicSizes.push_back(dynamicSize);
284 MemRefType newMemRefType =
286 assert(static_cast<int64_t>(dynamicSizes.size()) ==
287        newMemRefType.getNumDynamicDims());
290 auto newAlloc = rewriter.create<AllocLikeOp>(
291 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
292 alloc.getAlignmentAttr());
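// Illustrative sketch of the rewrite above (hypothetical types): a dynamic size
// that is a non-negative constant is folded into the result type, and existing
// uses are reconciled through a cast, e.g.
//   %c16 = arith.constant 16 : index
//   %a = memref.alloc(%c16) : memref<?xf32>
// canonicalizes to roughly
//   %a0 = memref.alloc() : memref<16xf32>
//   %a  = memref.cast %a0 : memref<16xf32> to memref<?xf32>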
300 template <typename T>
304 LogicalResult matchAndRewrite(T alloc,
306 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
307 if (auto storeOp = dyn_cast<StoreOp>(op))
308 return storeOp.getValue() == alloc;
309 return !isa<DeallocOp>(op);
313 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
324 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
329 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
338 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
339 MemRefType resultType = getType();
342 if (!sourceType.getLayout().isIdentity())
343 return emitError("unsupported layout for source memref type ")
347 if (!resultType.getLayout().isIdentity())
348 return emitError("unsupported layout for result memref type ")
352 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
353 return emitError("different memory spaces specified for source memref "
355        << sourceType << " and result memref type " << resultType;
358 if (sourceType.getElementType() != resultType.getElementType())
359 return emitError("different element types specified for source memref "
361        << sourceType << " and result memref type " << resultType;
364 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
365 return emitError("missing dimension operand for result type ")
367 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
368 return emitError("unnecessary dimension operand for result type ")
376 results.add<SimplifyDeadAlloc<ReallocOp>>(context);
384 bool printBlockTerminators = false;
387 if (!getResults().empty()) {
388 p << " -> (" << getResultTypes() << ")";
389 printBlockTerminators = true;
394 printBlockTerminators);
410 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
420 void AllocaScopeOp::getSuccessorRegions(
433 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
439 if (isa<SideEffects::AutomaticAllocationScopeResource>(
440     effect->getResource()))
456 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
462 if (isa<SideEffects::AutomaticAllocationScopeResource>(
463     effect->getResource()))
486 bool hasPotentialAlloca =
499 if (hasPotentialAlloca) {
532 if (!lastParentWithoutScope ||
545 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
546 if (!lastParentWithoutScope ||
553 Region *containingRegion = nullptr;
554 for (auto &r : lastParentWithoutScope->getRegions()) {
556 assert(containingRegion == nullptr &&
557        "only one region can contain the op");
558 containingRegion = &r;
561 assert(containingRegion && "op must be contained in a region");
571 return containingRegion->isAncestor(v.getParentRegion());
574 toHoist.push_back(alloc);
581 for (auto *op : toHoist) {
582 auto *cloned = rewriter.clone(*op);
599 if (!llvm::isPowerOf2_32(getAlignment()))
600 return emitOpError("alignment must be power of 2");
609 setNameFn(getResult(), "cast");
650 MemRefType sourceType =
651 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
652 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
655 if (!sourceType || !resultType)
659 if (sourceType.getElementType() != resultType.getElementType())
663 if (sourceType.getRank() != resultType.getRank())
667 int64_t sourceOffset, resultOffset;
674 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
675 auto ss = std::get<0>(it), st = std::get<1>(it);
677 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
682 if (sourceOffset != resultOffset)
683 if (ShapedType::isDynamic(sourceOffset) &&
684 !ShapedType::isDynamic(resultOffset))
688 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
689 auto ss = std::get<0>(it), st = std::get<1>(it);
691 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
699 if (inputs.size() != 1 || outputs.size() != 1)
701 Type a = inputs.front(), b = outputs.front();
702 auto aT = llvm::dyn_cast<MemRefType>(a);
703 auto bT = llvm::dyn_cast<MemRefType>(b);
705 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
706 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
709 if (aT.getElementType() != bT.getElementType())
711 if (aT.getLayout() != bT.getLayout()) {
712 int64_t aOffset, bOffset;
716 aStrides.size() != bStrides.size())
723 auto checkCompatible = [](int64_t a, int64_t b) {
724 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
726 if (!checkCompatible(aOffset, bOffset))
728 for (const auto &aStride : enumerate(aStrides))
729 if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
732 if (aT.getMemorySpace() != bT.getMemorySpace())
736 if (aT.getRank() != bT.getRank())
739 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
740 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
741 if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
755 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
756 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
757 if (aEltType != bEltType)
760 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
761 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
762 return aMemSpace == bMemSpace;
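// Illustrative example (standard memref.cast syntax): compatible casts may add
// or remove static information but must agree on element type, e.g.
//   %0 = memref.cast %arg : memref<4x?xf32> to memref<?x?xf32>   // erase a size
//   %1 = memref.cast %0 : memref<?x?xf32> to memref<4x8xf32>     // refine sizes
// Changing the element type is rejected; ranked<->unranked casts take the
// UnrankedMemRefType path above.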
783 LogicalResult matchAndRewrite(CopyOp copyOp,
785 bool modified = false;
788 if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
789 auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
790 auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
792 if (fromType && toType) {
793 if (fromType.getShape() == toType.getShape() &&
794 fromType.getElementType() == toType.getElementType()) {
796 copyOp.getSourceMutable().assign(castOp.getSource());
804 if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
805 auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
806 auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
808 if (fromType && toType) {
809 if (fromType.getShape() == toType.getShape() &&
810 fromType.getElementType() == toType.getElementType()) {
812 copyOp.getTargetMutable().assign(castOp.getSource());
819 return success(modified);
827 LogicalResult matchAndRewrite(CopyOp copyOp,
829 if (copyOp.getSource() != copyOp.getTarget())
842 llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
845 LogicalResult matchAndRewrite(CopyOp copyOp,
847 if (isEmptyMemRef(copyOp.getSource().getType()) ||
848 isEmptyMemRef(copyOp.getTarget().getType())) {
860 results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
863 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
871 operand.set(castOp.getOperand());
875 return success(folded);
882 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
893 setNameFn(getResult(), "dim");
899 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
900 build(builder, result, source, indexValue);
903 std::optional<int64_t> DimOp::getConstantIndex() {
912 auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
913 if (!rankedSourceType)
927 std::map<int64_t, unsigned> numOccurences;
928 for (auto val : vals)
929 numOccurences[val]++;
930 return numOccurences;
940 static FailureOr<llvm::SmallBitVector>
943 llvm::SmallBitVector unusedDims(originalType.getRank());
944 if (originalType.getRank() == reducedType.getRank())
948 if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
949 if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
950 unusedDims.set(dim.index());
954 if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
955     originalType.getRank())
959 int64_t originalOffset, candidateOffset;
975 std::map<int64_t, unsigned> currUnaccountedStrides =
977 std::map<int64_t, unsigned> candidateStridesNumOccurences =
979 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
980 if (!unusedDims.test(dim))
982 int64_t originalStride = originalStrides[dim];
983 if (currUnaccountedStrides[originalStride] >
984 candidateStridesNumOccurences[originalStride]) {
986 currUnaccountedStrides[originalStride]--;
989 if (currUnaccountedStrides[originalStride] ==
990 candidateStridesNumOccurences[originalStride]) {
992 unusedDims.reset(dim);
995 if (currUnaccountedStrides[originalStride] <
996 candidateStridesNumOccurences[originalStride]) {
1003 if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1004 originalType.getRank())
1010 MemRefType sourceType = getSourceType();
1011 MemRefType resultType = getType();
1012 FailureOr<llvm::SmallBitVector> unusedDims =
1014 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1020 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1025 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1031 int64_t indexVal = index.getInt();
1032 if (indexVal < 0 || indexVal >= memrefType.getRank())
1036 if (!memrefType.isDynamicDim(index.getInt())) {
1038 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1042 unsigned unsignedIndex = index.getValue().getZExtValue();
1045 Operation *definingOp = getSource().getDefiningOp();
1047 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1048 return *(alloc.getDynamicSizes().begin() +
1049 memrefType.getDynamicDimIndex(unsignedIndex));
1051 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1052 return *(alloca.getDynamicSizes().begin() +
1053 memrefType.getDynamicDimIndex(unsignedIndex));
1055 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1056 return *(view.getDynamicSizes().begin() +
1057 memrefType.getDynamicDimIndex(unsignedIndex));
1059 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1060 llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1061 unsigned resultIndex = 0;
1062 unsigned sourceRank = subview.getSourceType().getRank();
1063 unsigned sourceIndex = 0;
1064 for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1065 if (unusedDims.test(i))
1067 if (resultIndex == unsignedIndex) {
1073 assert(subview.isDynamicSize(sourceIndex) &&
1074 "expected dynamic subview size");
1075 return subview.getDynamicSize(sourceIndex);
1078 if (auto sizeInterface =
1079         dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1080 assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1081 "Expected dynamic subview size");
1082 return sizeInterface.getDynamicSize(unsignedIndex);
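// Illustrative example: with a constant index, the fold above reads the size
// straight from the type or from the defining alloc/view/subview, e.g.
//   %c0 = arith.constant 0 : index
//   %d = memref.dim %m, %c0 : memref<8x?xf32>   // folds to 'arith.constant 8 : index'
// while a dim of a dynamic dimension is forwarded to the corresponding
// dynamic size operand of the defining op.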
1098 LogicalResult matchAndRewrite(DimOp dim,
1100 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1104 dim, "Dim op is not defined by a reshape op.");
1115 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1116 if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1117 if (reshape->isBeforeInBlock(definingOp)) {
1120 "dim.getIndex is not defined before reshape in the same block.");
1125 else if (dim->getBlock() != reshape->getBlock() &&
1126 !dim.getIndex().getParentRegion()->isProperAncestor(
1127 reshape->getParentRegion())) {
1132 dim, "dim.getIndex does not dominate reshape.");
1140 rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1141 if (load.getType() != dim.getType())
1142 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1152 results.add<DimOfMemRefReshape>(context);
1163 Value elementsPerStride) {
1175 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1176   << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1177   << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1179 p << ", " << getStride() << ", " << getNumElementsPerStride();
1182 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1183   << ", " << getTagMemRef().getType();
1224 bool isStrided = strideInfo.size() == 2;
1225 if (!strideInfo.empty() && !isStrided) {
1227 "expected two stride related operands");
1232 if (types.size() != 3)
1255 unsigned numOperands = getNumOperands();
1259 if (numOperands < 4)
1260 return emitOpError("expected at least 4 operands");
1265 if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1266 return emitOpError("expected source to be of memref type");
1267 if (numOperands < getSrcMemRefRank() + 4)
1268 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1270 if (!getSrcIndices().empty() &&
1271     !llvm::all_of(getSrcIndices().getTypes(),
1273 return emitOpError("expected source indices to be of index type");
1276 if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1277 return emitOpError("expected destination to be of memref type");
1278 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1279 if (numOperands < numExpectedOperands)
1280 return emitOpError() << "expected at least " << numExpectedOperands
1282 if (!getDstIndices().empty() &&
1283     !llvm::all_of(getDstIndices().getTypes(),
1285 return emitOpError("expected destination indices to be of index type");
1289 return emitOpError("expected num elements to be of index type");
1292 if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1293 return emitOpError("expected tag to be of memref type");
1294 numExpectedOperands += getTagMemRefRank();
1295 if (numOperands < numExpectedOperands)
1296 return emitOpError() << "expected at least " << numExpectedOperands
1298 if (!getTagIndices().empty() &&
1299     !llvm::all_of(getTagIndices().getTypes(),
1301 return emitOpError("expected tag indices to be of index type");
1305 if (numOperands != numExpectedOperands &&
1306     numOperands != numExpectedOperands + 2)
1307 return emitOpError("incorrect number of operands");
1311 if (!getStride().getType().isIndex() ||
1312     !getNumElementsPerStride().getType().isIndex())
1314 "expected stride and num elements per stride to be of type index");
1320 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1330 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1338 unsigned numTagIndices = getTagIndices().size();
1339 unsigned tagMemRefRank = getTagMemRefRank();
1340 if (numTagIndices != tagMemRefRank)
1341 return emitOpError() << "expected tagIndices to have the same number of "
1342                         "elements as the tagMemRef rank, expected "
1343                      << tagMemRefRank << ", but got " << numTagIndices;
1351 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1353 setNameFn(getResult(), "intptr");
1362 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1363 MLIRContext *context, std::optional<Location> location,
1364 ExtractStridedMetadataOp::Adaptor adaptor,
1366 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1370 unsigned sourceRank = sourceType.getRank();
1374 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1376 inferredReturnTypes.push_back(memrefType);
1378 inferredReturnTypes.push_back(indexType);
1380 for (unsigned i = 0; i < sourceRank * 2; ++i)
1381 inferredReturnTypes.push_back(indexType);
1385 void ExtractStridedMetadataOp::getAsmResultNames(
1387 setNameFn(getBaseBuffer(), "base_buffer");
1388 setNameFn(getOffset(), "offset");
1391 if (!getSizes().empty()) {
1392 setNameFn(getSizes().front(), "sizes");
1393 setNameFn(getStrides().front(), "strides");
1400 template <typename Container>
1404 assert(values.size() == maybeConstants.size() &&
1405 " expected values and maybeConstants of the same size");
1406 bool atLeastOneReplacement = false;
1407 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1412 assert(maybeConstant.template is<Attribute>() &&
1413 "The constified value should be either unchanged (i.e., == result) "
1415 Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1416     loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1418 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1422 atLeastOneReplacement = true;
1425 return atLeastOneReplacement;
1429 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1435 getConstifiedMixedOffset());
1437 getConstifiedMixedSizes());
1439 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1441 return success(atLeastOneReplacement);
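// Illustrative example (hedged): for a rank-2 source the op yields a base
// buffer, an offset, and one size/stride per dimension, e.g.
//   %base, %off, %sz:2, %str:2 = memref.extract_strided_metadata %m
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
// The fold above replaces any of these results that are statically known in
// the source type with arith.constant values.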
1452 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1459 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1477 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1478 Type elementType = memrefType.getElementType();
1488 auto &body = getRegion();
1489 if (body.getNumArguments() != 1)
1490 return emitOpError("expected single number of entry block arguments");
1492 if (getResult().getType() != body.getArgument(0).getType())
1493 return emitOpError("expected block argument of the same type result type");
1500 "body of 'memref.generic_atomic_rmw' should contain "
1501 "only operations with no side effects");
1531 p << ' ' << getMemref() << "[" << getIndices()
1532   << "] : " << getMemref().getType() << ' ';
1542 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1543 Type resultType = getResult().getType();
1544 if (parentType != resultType)
1545 return emitOpError() << "types mismatch between yield op: " << resultType
1546                      << " and its parent: " << parentType;
1558 if (!op.isExternal()) {
1560 if (op.isUninitialized())
1561 p << "uninitialized";
1574 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1575 if (!memrefType || !memrefType.hasStaticShape())
1577 <<
"type should be static shaped memref, but got " << type;
1591 if (!llvm::isa<ElementsAttr>(initialValue))
1593 <<
"initial value should be a unit or elements attribute";
1598 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1599 if (!memrefType || !memrefType.hasStaticShape())
1600 return emitOpError("type should be static shaped memref, but got ")
1605 if (getInitialValue().has_value()) {
1606 Attribute initValue = getInitialValue().value();
1607 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1608 return emitOpError("initial value should be a unit or elements "
1609                    "attribute, but got ")
1614 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1615 Type initType = elementsAttr.getType();
1617 if (initType != tensorType)
1618 return emitOpError("initial value expected to be of type ")
1619        << tensorType << ", but was of type " << initType;
1623 if (std::optional<uint64_t> alignAttr = getAlignment()) {
1624 uint64_t alignment = *alignAttr;
1626 if (!llvm::isPowerOf2_64(alignment))
1627 return emitError() << "alignment attribute value " << alignment
1628                    << " is not a power of 2";
1635 ElementsAttr GlobalOp::getConstantInitValue() {
1636 auto initVal = getInitialValue();
1637 if (getConstant() && initVal.has_value())
1638 return llvm::cast<ElementsAttr>(initVal.value());
1653 return emitOpError("'")
1654        << getName() << "' does not reference a valid global memref";
1656 Type resultType = getResult().getType();
1657 if (global.getType() != resultType)
1658 return emitOpError("result type ")
1659        << resultType << " does not match type " << global.getType()
1660        << " of the global memref @" << getName();
1670 return emitOpError("incorrect number of indices for load, expected ")
1687 void MemorySpaceCastOp::getAsmResultNames(
1689 setNameFn(getResult(), "memspacecast");
1693 if (inputs.size() != 1 || outputs.size() != 1)
1695 Type a = inputs.front(), b = outputs.front();
1696 auto aT = llvm::dyn_cast<MemRefType>(a);
1697 auto bT = llvm::dyn_cast<MemRefType>(b);
1699 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1700 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1703 if (aT.getElementType() != bT.getElementType())
1705 if (aT.getLayout() != bT.getLayout())
1707 if (aT.getShape() != bT.getShape())
1712 return uaT.getElementType() == ubT.getElementType();
1717 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1720 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1721 getSourceMutable().assign(parentCast.getSource());
1732 p << " " << getMemref() << '[';
1734 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1735 p << ", locality<" << getLocalityHint();
1736 p << ">, " << (getIsDataCache() ? "data" : "instr");
1738 (*this)->getAttrs(),
1739 {"localityHint", "isWrite", "isDataCache"});
1746 IntegerAttr localityHint;
1748 StringRef readOrWrite, cacheType;
1765 if (readOrWrite != "read" && readOrWrite != "write")
1767 "rw specifier has to be 'read' or 'write'");
1768 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1771 if (cacheType != "data" && cacheType != "instr")
1773 "cache type has to be 'data' or 'instr'");
1775 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1783 return emitOpError("too few indices");
1788 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1800 auto type = getOperand().getType();
1801 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1802 if (shapedType && shapedType.hasRank())
1804 return IntegerAttr();
1811 void ReinterpretCastOp::getAsmResultNames(
1813 setNameFn(getResult(), "reinterpret_cast");
1820 MemRefType resultType, Value source,
1830 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1837 MemRefType resultType, Value source,
1842 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1846 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1850 strideValues, attrs);
1854 MemRefType resultType, Value source, Value offset,
1861 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1868 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1869 auto resultType = llvm::cast<MemRefType>(getType());
1870 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1871 return emitError("different memory spaces specified for source type ")
1872        << srcType << " and result memref type " << resultType;
1873 if (srcType.getElementType() != resultType.getElementType())
1874 return emitError("different element types specified for source type ")
1875        << srcType << " and result memref type " << resultType;
1878 for (auto [idx, resultSize, expectedSize] :
1880 if (!ShapedType::isDynamic(resultSize) &&
1881 !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1882 return emitError("expected result type with size = ")
1883        << expectedSize << " instead of " << resultSize
1884        << " in dim = " << idx;
1890 int64_t resultOffset;
1893 return emitError("expected result type to have strided layout but found ")
1897 int64_t expectedOffset = getStaticOffsets().front();
1898 if (!ShapedType::isDynamic(resultOffset) &&
1899 !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1900 return emitError("expected result type with offset = ")
1901        << expectedOffset << " instead of " << resultOffset;
1904 for (auto [idx, resultStride, expectedStride] :
1906 if (!ShapedType::isDynamic(resultStride) &&
1907 !ShapedType::isDynamic(expectedStride) &&
1908 resultStride != expectedStride)
1909 return emitError("expected result type with stride = ")
1910        << expectedStride << " instead of " << resultStride
1911        << " in dim = " << idx;
1918 Value src = getSource();
1919 auto getPrevSrc = [&]() -> Value {
1922 return prev.getSource();
1926 return prev.getSource();
1931 if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1932 return isConstantIntValue(val, 0);
1934 return prev.getSource();
1939 if (auto prevSrc = getPrevSrc()) {
1940 getSourceMutable().assign(prevSrc);
1956 ShapedType::isDynamic);
1963 ShapedType::isDynamic);
1967 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1969 assert(values.size() == 1 &&
1970 "reinterpret_cast must have one and only one offset");
1972 ShapedType::isDynamic);
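// Illustrative example (hedged): the offset, sizes, and strides become the
// strided layout of the result type, and the verifier above checks that any
// static values agree, e.g.
//   %r = memref.reinterpret_cast %src to offset: [0], sizes: [10, 10],
//        strides: [10, 1] : memref<?x?xf32> to memref<10x10xf32, strided<[10, 1]>>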
2014 struct ReinterpretCastOpExtractStridedMetadataFolder
2019 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2021 auto extractStridedMetadata =
2022 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2023 if (!extractStridedMetadata)
2030 extractStridedMetadata.getConstifiedMixedStrides();
2032 op.getConstifiedMixedStrides();
2033 if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2036 unsigned rank = op.getType().getRank();
2037 for (unsigned i = 0; i < rank; ++i) {
2038 if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2043 assert(extractStridedMetadata.getSizes().size() ==
2044 op.getMixedSizes().size() &&
2045 "Strides and sizes rank must match");
2047 extractStridedMetadata.getConstifiedMixedSizes();
2049 op.getConstifiedMixedSizes();
2050 for (unsigned i = 0; i < rank; ++i) {
2051 if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2055 assert(op.getMixedOffsets().size() == 1 &&
2056 "reinterpret_cast with more than one offset should have been "
2057 "rejected by the verifier");
2059 extractStridedMetadata.getConstifiedMixedOffset();
2060 OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2061 if (extractOffsetOfr != reinterpretOffsetOfr)
2069 Type srcTy = extractStridedMetadata.getSource().getType();
2071 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2074 extractStridedMetadata.getSource());
2083 results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2090 void CollapseShapeOp::getAsmResultNames(
2092 setNameFn(getResult(), "collapse_shape");
2095 void ExpandShapeOp::getAsmResultNames(
2097 setNameFn(getResult(), "expand_shape");
2102 reifiedResultShapes = {
2103 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2112 static LogicalResult
2116 bool allowMultipleDynamicDimsPerGroup) {
2118 if (collapsedShape.size() != reassociation.size())
2119 return op->emitOpError("invalid number of reassociation groups: found ")
2120        << reassociation.size() << ", expected " << collapsedShape.size();
2124 int64_t nextDim = 0;
2127 int64_t collapsedDim = it.index();
2129 bool foundDynamic = false;
2130 for (int64_t expandedDim : group) {
2131 if (expandedDim != nextDim++)
2132 return op->emitOpError("reassociation indices must be contiguous");
2134 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2136 << expandedDim << " is out of bounds";
2139 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2140 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2142 "at most one dimension in a reassociation group may be dynamic");
2143 foundDynamic = true;
2148 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2151 <<
") must be dynamic if and only if reassociation group is "
2156 if (!foundDynamic) {
2157 int64_t groupSize = 1;
2158 for (int64_t expandedDim : group)
2159 groupSize *= expandedShape[expandedDim];
2160 if (groupSize != collapsedShape[collapsedDim])
2162 << collapsedShape[collapsedDim]
2163 <<
") must equal reassociation group size (" << groupSize <<
")";
2167 if (collapsedShape.empty()) {
2169 for (int64_t d : expandedShape)
2172 "rank 0 memrefs can only be extended/collapsed with/from ones");
2173 }
else if (nextDim !=
static_cast<int64_t
>(expandedShape.size())) {
2177 << expandedShape.size()
2178 <<
") inconsistent with number of reassociation indices (" << nextDim
2191 getReassociationIndices());
2200 getReassociationIndices());
2205 static FailureOr<StridedLayoutAttr>
2212 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2227 reverseResultStrides.reserve(resultShape.size());
2228 unsigned shapeIndex = resultShape.size() - 1;
2229 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2231 int64_t currentStrideToExpand = std::get<1>(it);
2232 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2233 reverseResultStrides.push_back(currentStrideToExpand);
2234 currentStrideToExpand =
2240 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2241 resultStrides.resize(resultShape.size(), 1);
2245 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2248 if (srcType.getLayout().isIdentity()) {
2251 MemRefLayoutAttrInterface layout;
2253 srcType.getMemorySpace());
2257 FailureOr<StridedLayoutAttr> computedLayout =
2259 if (failed(computedLayout))
2261 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2262 srcType.getMemorySpace());
2265 FailureOr<SmallVector<OpFoldResult>>
2267 MemRefType expandedType,
2270 std::optional<SmallVector<OpFoldResult>> outputShape =
2275 return *outputShape;
2282 auto [staticOutputShape, dynamicOutputShape] =
2284 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2286 dynamicOutputShape, staticOutputShape);
2294 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2295 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2296     builder, result.location, memrefResultTy, reassociation, inputShape);
2299 assert(succeeded(outputShape) && "unable to infer output shape");
2300 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2307 auto srcType = llvm::cast<MemRefType>(src.getType());
2308 FailureOr<MemRefType> resultType =
2309 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2312 assert(succeeded(resultType) && "could not compute layout");
2313 build(builder, result, *resultType, src, reassociation);
2321 auto srcType = llvm::cast<MemRefType>(src.getType());
2322 FailureOr<MemRefType> resultType =
2323 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2326 assert(succeeded(resultType) && "could not compute layout");
2327 build(builder, result, *resultType, src, reassociation, outputShape);
2331 MemRefType srcType = getSrcType();
2332 MemRefType resultType = getResultType();
2334 if (srcType.getRank() > resultType.getRank()) {
2335 auto r0 = srcType.getRank();
2336 auto r1 = resultType.getRank();
2337 return emitOpError("has source rank ")
2338        << r0 << " and result rank " << r1 << ". This is not an expansion ("
2339        << r0 << " > " << r1 << ").";
2344 resultType.getShape(),
2345 getReassociationIndices(),
2350 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2351 srcType, resultType.getShape(), getReassociationIndices());
2352 if (failed(expectedResultType))
2353 return emitOpError("invalid source layout map");
2356 if (*expectedResultType != resultType)
2357 return emitOpError("expected expanded type to be ")
2358        << *expectedResultType << " but found " << resultType;
2360 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2361 return emitOpError("expected number of static shape bounds to be equal to "
2362                    "the output rank (")
2363        << resultType.getRank() << ") but found "
2364        << getStaticOutputShape().size() << " inputs instead";
2366 if ((int64_t)getOutputShape().size() !=
2367 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2368 return emitOpError("mismatch in dynamic dims in output_shape and "
2369                    "static_output_shape: static_output_shape has ")
2370        << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2371        << " dynamic dims while output_shape has " << getOutputShape().size()
2378 if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2379 return emitOpError("invalid output shape provided at pos ") << pos;
2400 static FailureOr<StridedLayoutAttr>
2403 bool strict = false) {
2406 auto srcShape = srcType.getShape();
2416 resultStrides.reserve(reassociation.size());
2419 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2420 ref = ref.drop_back();
2421 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2422 resultStrides.push_back(srcStrides[ref.back()]);
2428 resultStrides.push_back(ShapedType::kDynamic);
2433 unsigned resultStrideIndex = resultStrides.size() - 1;
2437 for (int64_t idx : llvm::reverse(trailingReassocs)) {
2449 if (strict && (stride.saturated || srcStride.saturated))
2452 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2459 bool CollapseShapeOp::isGuaranteedCollapsible(
2462 if (srcType.getLayout().isIdentity())
2469 MemRefType CollapseShapeOp::computeCollapsedType(
2472 resultShape.reserve(reassociation.size());
2475 for (int64_t srcDim : group)
2478 resultShape.push_back(groupSize.asInteger());
2481 if (srcType.getLayout().isIdentity()) {
2484 MemRefLayoutAttrInterface layout;
2486 srcType.getMemorySpace());
2492 FailureOr<StridedLayoutAttr> computedLayout =
2494 assert(succeeded(computedLayout) &&
2495 "invalid source layout map or collapsing non-contiguous dims");
2496 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2497 srcType.getMemorySpace());
2503 auto srcType = llvm::cast<MemRefType>(src.getType());
2504 MemRefType resultType =
2505 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2508 build(b, result, resultType, src, attrs);
2512 MemRefType srcType = getSrcType();
2513 MemRefType resultType = getResultType();
2515 if (srcType.getRank() < resultType.getRank()) {
2516 auto r0 = srcType.getRank();
2517 auto r1 = resultType.getRank();
2518 return emitOpError("has source rank ")
2519        << r0 << " and result rank " << r1 << ". This is not a collapse ("
2520        << r0 << " < " << r1 << ").";
2525 srcType.getShape(), getReassociationIndices(),
2530 MemRefType expectedResultType;
2531 if (srcType.getLayout().isIdentity()) {
2534 MemRefLayoutAttrInterface layout;
2535 expectedResultType =
2536 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2537 srcType.getMemorySpace());
2542 FailureOr<StridedLayoutAttr> computedLayout =
2544 if (failed(computedLayout))
2546 "invalid source layout map or collapsing non-contiguous dims");
2547 expectedResultType =
2549 *computedLayout, srcType.getMemorySpace());
2552 if (expectedResultType != resultType)
2553 return emitOpError("expected collapsed type to be ")
2554        << expectedResultType << " but found " << resultType;
2573 Type newResultType = CollapseShapeOp::computeCollapsedType(
2574 llvm::cast<MemRefType>(cast.getOperand().getType()),
2575 op.getReassociationIndices());
2577 if (newResultType == op.getResultType()) {
2579 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2582 op->getLoc(), cast.getSource(), op.getReassociationIndices());
2594 memref::DimOp, MemRefType>,
2598 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2599 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2600                                                      adaptor.getOperands());
2603 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2604 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2605                                                      adaptor.getOperands());
2612 void ReshapeOp::getAsmResultNames(
2614 setNameFn(getResult(), "reshape");
2618 Type operandType = getSource().getType();
2619 Type resultType = getResult().getType();
2621 Type operandElementType =
2622 llvm::cast<ShapedType>(operandType).getElementType();
2623 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2624 if (operandElementType != resultElementType)
2625 return emitOpError("element types of source and destination memref "
2626                    "types should be the same");
2628 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2629 if (!operandMemRefType.getLayout().isIdentity())
2630 return emitOpError("source memref type should have identity affine map");
2634 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2635 if (resultMemRefType) {
2636 if (!resultMemRefType.getLayout().isIdentity())
2637 return emitOpError("result memref type should have identity affine map");
2638 if (shapeSize == ShapedType::kDynamic)
2639 return emitOpError("cannot use shape operand with dynamic length to "
2640                    "reshape to statically-ranked memref type");
2641 if (shapeSize != resultMemRefType.getRank())
2643 "length of shape operand differs from the result's memref rank");
2654 return emitOpError("store index operand count not equal to memref rank");
2659 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2669 void SubViewOp::getAsmResultNames(
2671 setNameFn(getResult(), "subview");
2677 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2681 unsigned rank = sourceMemRefType.getRank();
2683 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2684 assert(staticSizes.size() == rank && "staticSizes length mismatch");
2685 assert(staticStrides.size() == rank && "staticStrides length mismatch");
2692 int64_t targetOffset = sourceOffset;
2693 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2694 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2704 targetStrides.reserve(staticOffsets.size());
2705 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2706 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2713 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2715 targetOffset, targetStrides),
2716 sourceMemRefType.getMemorySpace());
2719 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2734 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2735 staticSizes, staticStrides);
2739 MemRefType sourceRankedTensorType,
2743 auto inferredType = llvm::cast<MemRefType>(
2744 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2745 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2747 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2748 return inferredType;
2751 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2753 assert(dimsToProject.has_value() && "invalid rank reduction");
2756 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2758 rankReducedStrides.reserve(resultShape.size());
2759 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2760 if (!dimsToProject->contains(idx))
2761 rankReducedStrides.push_back(value);
2765 inferredLayout.getOffset(),
2766 rankReducedStrides),
2767 inferredType.getMemorySpace());
2771 MemRefType sourceRankedTensorType,
2780 return SubViewOp::inferRankReducedResultType(
2781 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2788 MemRefType resultType, Value source,
2798 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2801 resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2802 sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2805 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2818 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2827 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2831 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2835 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2838 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2844 MemRefType resultType, Value source,
2849 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2853 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2857 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2860 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2876 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2883 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2887 Value SubViewOp::getViewSource() { return getSource(); }
2892 int64_t t1Offset, t2Offset;
2896 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2903 const llvm::SmallBitVector &droppedDims) {
2904 assert(size_t(t1.getRank()) == droppedDims.size() &&
2905        "incorrect number of bits");
2906 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2907        "incorrect number of dropped dims");
2908 int64_t t1Offset, t2Offset;
2912 if (failed(res1) || failed(res2))
2914 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2917 if (t1Strides[i] != t2Strides[j])
2926 auto memrefType = llvm::cast<ShapedType>(expectedType);
2931 return op->emitError("expected result rank to be smaller or equal to ")
2932        << "the source rank. ";
2934 return op->emitError("expected result type to be ")
2936        << " or a rank-reduced version. (mismatch of result sizes) ";
2938 return op->emitError("expected result element type to be ")
2939        << memrefType.getElementType();
2941 return op->emitError("expected result and source memory spaces to match.");
2943 return op->emitError("expected result type to be ")
2945        << " or a rank-reduced version. (mismatch of result layout) ";
2947 llvm_unreachable("unexpected subview verification result");
2952 MemRefType baseType = getSourceType();
2953 MemRefType subViewType = getType();
2956 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2957 return emitError("different memory spaces specified for base memref "
2959        << baseType << " and subview memref type " << subViewType;
2963 return emitError("base type ") << baseType << " is not strided";
2967 auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2968 baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2973 expectedType, subViewType);
2978 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
2980 *
this, expectedType);
2985 *
this, expectedType);
2993 if (failed(unusedDims))
2995 *
this, expectedType);
3000 *
this, expectedType);
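// Illustrative example (hedged): offsets/sizes/strides are folded into a
// strided layout on the result type; a rank-reducing form drops unit sizes:
//   %s = memref.subview %m[0, 1][4, 4][1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 1>>
//   %r = memref.subview %m[2, 0][1, 8][1, 1]
//       : memref<8x8xf32> to memref<8xf32, strided<[1], offset: 16>>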
3006 return os << "range " << range.offset << ":" << range.size << ":"
3015 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3016 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3017 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3019 unsigned rank = ranks[0];
3021 for (unsigned idx = 0; idx < rank; ++idx) {
3023 op.isDynamicOffset(idx)
3024 ? op.getDynamicOffset(idx)
3027 op.isDynamicSize(idx)
3028 ? op.getDynamicSize(idx)
3031 op.isDynamicStride(idx)
3032 ? op.getDynamicStride(idx)
3034 res.emplace_back(Range{offset, size, stride});
3047 MemRefType currentResultType, MemRefType currentSourceType,
3050 auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3051 sourceType, mixedOffsets, mixedSizes, mixedStrides));
3053 currentSourceType, currentResultType, mixedSizes);
3054 if (failed(unusedDims))
3057 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3059 unsigned numDimsAfterReduction =
3060 nonRankReducedType.getRank() - unusedDims->count();
3061 shape.reserve(numDimsAfterReduction);
3062 strides.reserve(numDimsAfterReduction);
3063 for (const auto &[idx, size, stride] :
3064      llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3065                nonRankReducedType.getShape(), layout.getStrides())) {
3066 if (unusedDims->test(idx))
3068 shape.push_back(size);
3069 strides.push_back(stride);
3074 layout.getOffset(), strides),
3075 nonRankReducedType.getMemorySpace());
3080 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3081 unsigned rank = memrefType.getRank();
3086 llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3087 targetShape, memrefType, offsets, sizes, strides));
3088 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3095 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3096 assert(sourceMemrefType && "not a ranked memref type");
3097 auto sourceShape = sourceMemrefType.getShape();
3098 if (sourceShape.equals(desiredShape))
3100 auto maybeRankReductionMask =
3102 if (!maybeRankReductionMask)
3112 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3115 auto mixedOffsets = subViewOp.getMixedOffsets();
3116 auto mixedSizes = subViewOp.getMixedSizes();
3117 auto mixedStrides = subViewOp.getMixedStrides();
3122 return !intValue || intValue.value() != 0;
3129 return !intValue || intValue.value() != 1;
3137 if (!intValue || *intValue != sourceShape[size.index()])
3161 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3165 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3169 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3170 return matchPattern(operand, matchConstantIndex());
3174 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3186 subViewOp.getType(), subViewOp.getSourceType(),
3187 llvm::cast<MemRefType>(castOp.getSource().getType()),
3188 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3189 subViewOp.getMixedStrides());
3194 subViewOp.getLoc(), resultType, castOp.getSource(),
3195 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3196 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3197 subViewOp.getStaticStrides());
3210 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3214 if (subViewOp.getSourceType() == subViewOp.getType()) {
3215 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3219 subViewOp.getSource());
3231 auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3232 mixedSizes, mixedStrides);
3235 MemRefType nonReducedType = cast<MemRefType>(resTy);
3238 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3239 if (droppedDims.none())
3240 return nonReducedType;
3248 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3249 if (droppedDims.test(i))
3251 targetStrides.push_back(nonReducedStrides[i]);
3252 targetShape.push_back(nonReducedType.getDimSize(i));
3257 offset, targetStrides),
3258 nonReducedType.getMemorySpace());
3274 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3278 auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
3279 auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
3281 if (resultShapedType.hasStaticShape() &&
3282 resultShapedType == sourceShapedType) {
3283 return getViewSource();
3289 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3290 auto srcSizes = srcSubview.getMixedSizes();
3292 auto offsets = getMixedOffsets();
3293 bool allOffsetsZero = llvm::all_of(
3295 auto strides = getMixedStrides();
3296 bool allStridesOne = llvm::all_of(
3298 bool allSizesSame = llvm::equal(sizes, srcSizes);
3299 if (allOffsetsZero && allStridesOne && allSizesSame &&
3300 resultShapedType == sourceShapedType)
3301 return getViewSource();
3311 void TransposeOp::getAsmResultNames(
3313 setNameFn(getResult(), "transpose");
3319 auto originalSizes = memRefType.getShape();
3321 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3324 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3325 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3334 AffineMapAttr permutation,
3336 auto permutationMap = permutation.getValue();
3337 assert(permutationMap);
3339 auto memRefType = llvm::cast<MemRefType>(in.getType());
3343 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3344 build(b, result, resultType, in, attrs);
3349 p << " " << getIn() << " " << getPermutation();
3351 p << " : " << getIn().getType() << " to " << getType();
3357 MemRefType srcType, dstType;
3366 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3373 return emitOpError("expected a permutation map");
3374 if (getPermutation().getNumDims() != getIn().getType().getRank())
3375 return emitOpError("expected a permutation map of same rank as the input");
3377 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3378 auto resultType = llvm::cast<MemRefType>(getType());
3383 return emitOpError("result type ")
3385        << " is not equivalent to the canonical transposed input type "
3386        << canonicalResultType;
3393 if (getPermutation().isIdentity() && getType() == getIn().getType())
3397 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3399 getPermutation().compose(otherTransposeOp.getPermutation());
3400 getInMutable().assign(otherTransposeOp.getIn());
3401 setPermutation(composedPermutation);
3411 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3412 setNameFn(getResult(), "view");
3416 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3420 if (!baseType.getLayout().isIdentity())
3421 return emitError("unsupported map for base memref type ") << baseType;
3424 if (!viewType.getLayout().isIdentity())
3425 return emitError("unsupported map for result memref type ") << viewType;
3428 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3429 return emitError("different memory spaces specified for base memref "
3431        << baseType << " and view memref type " << viewType;
3434 unsigned numDynamicDims = viewType.getNumDynamicDims();
3435 if (getSizes().size() != numDynamicDims)
3436 return emitError("incorrect number of size operands for type ") << viewType;
3441 Value ViewOp::getViewSource() { return getSource(); }
3448 LogicalResult matchAndRewrite(ViewOp viewOp,
3451 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3452 return matchPattern(operand, matchConstantIndex());
3457 auto memrefType = viewOp.getType();
3464 assert(oldOffset == 0 && "Expected 0 offset");
3472 newShapeConstants.reserve(memrefType.getRank());
3474 unsigned dynamicDimPos = 0;
3475 unsigned rank = memrefType.getRank();
3476 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3477 int64_t dimSize = memrefType.getDimSize(dim);
3479 if (!ShapedType::isDynamic(dimSize)) {
3480 newShapeConstants.push_back(dimSize);
3483 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3484 if (auto constantIndexOp =
3485         dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3487 newShapeConstants.push_back(constantIndexOp.value());
3490 newShapeConstants.push_back(dimSize);
3491 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3497 MemRefType newMemRefType =
3500 if (newMemRefType == memrefType)
3504 auto newViewOp = rewriter.create<ViewOp>(
3505 viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3506 viewOp.getByteShift(), newOperands);
3516 LogicalResult matchAndRewrite(ViewOp viewOp,
3518 Value memrefOperand = viewOp.getOperand(0);
3519 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3522 Value allocOperand = memrefCastOp.getOperand();
3527 viewOp.getByteShift(),
3537 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3547 "expects the number of subscripts to be equal to memref rank");
3548 switch (getKind()) {
3549 case arith::AtomicRMWKind::addf:
3550 case arith::AtomicRMWKind::maximumf:
3551 case arith::AtomicRMWKind::minimumf:
3552 case arith::AtomicRMWKind::mulf:
3553 if (!llvm::isa<FloatType>(getValue().getType()))
3554 return emitOpError() << "with kind '"
3555                      << arith::stringifyAtomicRMWKind(getKind())
3556                      << "' expects a floating-point type";
3558 case arith::AtomicRMWKind::addi:
3559 case arith::AtomicRMWKind::maxs:
3560 case arith::AtomicRMWKind::maxu:
3561 case arith::AtomicRMWKind::mins:
3562 case arith::AtomicRMWKind::minu:
3563 case arith::AtomicRMWKind::muli:
3564 case arith::AtomicRMWKind::ori:
3565 case arith::AtomicRMWKind::andi:
3566 if (!llvm::isa<IntegerType>(getValue().getType()))
3567 return emitOpError() << "with kind '"
3568                      << arith::stringifyAtomicRMWKind(getKind())
3569                      << "' expects an integer type";
3588 #define GET_OP_CLASSES
3589 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(std::vector< PermutationTy > permutation)
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< int64_t > getConstantOffset(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the offset and conforms to the function signatur...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, MemRefType memRefTy, MLIRContext *ctxt, llvm::function_ref< SmallVector< int64_t >(MemRefType)> getAttributes, llvm::function_ref< bool(int64_t)> isDynamic)
Helper function that infers the constant values from a list of values, a memRefTy,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, Operation *op, Type expectedType)
static SmallVector< int64_t > getConstantStrides(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the strides and conforms to the function signatu...
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static SmallVector< int64_t > getConstantSizes(MemRefType memRefTy)
Wrapper around getShape that conforms to the function signature expected for getAttributes in constif...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non-terminating op in a region.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map whose keys are the elements of vals and whose values are the number of occurrences of each element.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static int64_t getNumElements(ShapedType type)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing comma-separated SSA operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
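As an illustration (a sketch; MyLoadLikeOp is a hypothetical op, not one defined in this file), an op's fold hook commonly delegates to this utility so that operands produced by memref.cast are replaced by the cast's source when doing so is safe:

OpFoldResult MyLoadLikeOp::fold(FoldAdaptor adaptor) {
  // foldMemRefCast rewrites the operands in place; report a successful
  // in-place fold by returning an existing result.
  return succeeded(memref::foldMemRefCast(*this)) ? getResult()
                                                  : OpFoldResult();
}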
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
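A short usage sketch (assuming an OpBuilder b, a Location loc, and a Value mem of MemRefType are in scope):

// Static extents come back as IntegerAttrs; dynamic extents come back as the
// Values produced by freshly created memref.dim ops.
SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, mem);
OpFoldResult dim0 = memref::getMixedSize(b, loc, mem, /*dim=*/0);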
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
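A small sketch of the two OpFoldResult forms this handles (assuming a Builder b and an index-typed Value v are in scope):

OpFoldResult attrForm = b.getIndexAttr(4); // attribute-backed
OpFoldResult valueForm = v;                // SSA-value-backed
std::optional<int64_t> a = getConstantIntValue(attrForm);  // holds 4
std::optional<int64_t> c = getConstantIntValue(valueForm);
// `c` is populated only if `v` is defined by a constant-like op.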
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
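For example (a sketch; memrefType is assumed to be a MemRefType in scope), an identity-layout memref<4x8xf32> has row-major strides {8, 1} and offset 0, while a non-strided layout makes the call fail:

SmallVector<int64_t> strides;
int64_t offset;
LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
// For memref<4x8xf32>: succeeded(res), strides == {8, 1}, offset == 0.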
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
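A usage sketch (assuming mixedSizes is an ArrayRef<OpFoldResult> in scope): the helper splits the mixed list into the two parallel arrays that op storage typically uses.

SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
// staticSizes[i] is ShapedType::kDynamic exactly where mixedSizes[i] carried
// a Value; that Value was appended to dynamicSizes.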
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but with every element for which ShapedType::isDynamic is true replaced by the corresponding entry of dynamicValues.
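The inverse direction, as a sketch reusing the staticSizes/dynamicSizes names from the split above and a Builder b:

SmallVector<OpFoldResult> mixed = getMixedValues(staticSizes, dynamicSizes, b);
// Each ShapedType::kDynamic entry of staticSizes is replaced by the next
// Value from dynamicSizes; all other entries become index IntegerAttrs.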
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
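Concretely (a sketch), reducing a 4x1x8x1 shape to 4x8 drops the unit dimensions at positions 1 and 3:

std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask({4, 1, 8, 1}, {4, 8});
// mask has a value here and contains {1, 3}.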
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.