#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
/// Fold memref.cast ops feeding an op's operands into the op itself
/// (someop(memrefcast) -> someop).
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
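// Illustrative IR example (added for exposition, not from the original
// source): foldMemRefCast lets a consumer of a cast use the cast's source
// directly, e.g.
//   %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
// folds to
//   memref.dealloc %m : memref<8xf32>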
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
  // (constifyIndexValues) Fold statically-known entries of `values` into
  // index attributes using the memref type's static information.
    int64_t constValue = it.value();
    if (!isDynamic(constValue))
      values[it.index()] = builder.getIndexAttr(constValue);
  // ...
    if (auto attr = dyn_cast<Attribute>(ofr)) {
      // Maintain the invariant that attributes are index attributes.
      ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
    }
    std::optional<int64_t> maybeConstant =
        getConstantIntValue(cast<Value>(ofr));
  // (getConstantOffset)
  LogicalResult hasStaticInformation =
      getStridesAndOffset(memrefType, strides, offset);
  if (failed(hasStaticInformation))
    return SmallVector<int64_t>();

  // (getConstantStrides)
  LogicalResult hasStaticInformation =
      getStridesAndOffset(memrefType, strides, offset);
  if (failed(hasStaticInformation))
    return SmallVector<int64_t>();
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}
228 "requires an ancestor op with AutomaticAllocationScope trait");
/// Fold constant dimension operands into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check if any dynamic size operand is a non-negative constant; if none
    // is, there is nothing to fold.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();
    auto memrefType = alloc.getType();

    // Collect the constant dimension sizes and the remaining dynamic size
    // operands, tracking the result memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // This dynamic dimension will be folded to a constant.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Not folded; keep the dimension dynamic and carry the operand over.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }
    // Create the new memref type, which will have fewer dynamic dimensions.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create the alloc op for the new memref, then cast back to the original
    // type so users are unaffected.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
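// Illustrative IR example (added for exposition, not from the original
// source): with a constant dynamic size,
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8) : memref<?xf32>
// SimplifyAllocConst rewrites this to
//   %1 = memref.alloc() : memref<8xf32>
//   %0 = memref.cast %1 : memref<8xf32> to memref<?xf32>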
/// Erase alloc-like ops whose results are only written to or deallocated,
/// never read.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);
    rewriter.eraseOp(alloc);
    return success();
  }
};
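// Illustrative IR example (added for exposition, not from the original
// source): an allocation whose only uses are stores into it and a dealloc is
// dead, so all three ops below are erased:
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>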
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}
LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source and result memrefs should be in the same memory space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source and result memrefs should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that the dynamic dimension operand is present exactly when needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;
  return success();
}

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}
void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

// (AllocaScopeOp::parse) After parsing the body region, make sure it has a
// terminator.
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);
void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // ...
}

/// Given an operation, return whether this op is guaranteed to allocate an
/// AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  // ...
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))
      return true;
  // ...
}

/// Given an operation, return whether this op itself could allocate an
/// AutomaticAllocationScopeResource.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  // ...
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))
      return true;
  // ...
}
// (AllocaScopeInliner) Inline an AllocaScopeOp if nothing in its body can
// allocate in the automatic scope.
    bool hasPotentialAlloca = /* ... walk the body for potential allocas */;
    if (hasPotentialAlloca) {
      // ...
    }

// (AllocaScopeHoister) Hoist allocations out of intermediate regions into the
// closest ancestor that provides an automatic allocation scope. Walk up while
// the parent does not itself provide a scope:
    if (!lastParentWithoutScope /* || ... */)
      return failure();
    // ...
    lastParentWithoutScope = lastParentWithoutScope->getParentOp();
    if (!lastParentWithoutScope /* || ... */)
      return failure();
    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    // Only hoist allocations whose operands are all defined outside the
    // region being left, checked per operand with:
    //   return containingRegion->isAncestor(v.getParentRegion());
    // ...
      toHoist.push_back(alloc);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  // ...

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If the cast is towards a more static offset, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))
      return false;

  // If the cast is towards more static strides along any dimension, don't
  // fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }
  return true;
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same or dynamic.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  }

  // At least one side is unranked: the element types and memory spaces must
  // match.
  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}
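// Illustrative IR examples (added for exposition, not from the original
// source): casts are compatible as long as no statically-known dimension
// disagrees, and ranked memrefs may be cast to unranked ones:
//   %1 = memref.cast %0 : memref<8x?xf32> to memref<?x4xf32>
//   %2 = memref.cast %0 : memref<8x?xf32> to memref<*xf32>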
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check the source.
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getSourceMutable().assign(castOp.getSource());
          modified = true;
        }
      }
    }

    // Check the target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getTargetMutable().assign(castOp.getSource());
          modified = true;
        }
      }
    }

    return success(modified);
  }
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }
    return failure();
  }
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}
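// Illustrative IR example (added for exposition, not from the original
// source): a self-copy has no observable effect and is erased by
// FoldSelfCopy:
//   memref.copy %m, %m : memref<?xf32> to memref<?xf32>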
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}
std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

// (DimOp::getSpeculatability) A dim op is speculatable only when the constant
// index is known to be in bounds for a ranked source.
  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;
/// Return a map whose key is an element of `vals` and whose value is its
/// number of occurrences.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
/// Given `originalType` and a `reducedType` whose shape is assumed to be a
/// subset of `originalType` with some unit dims erased, return the bit vector
/// of original dimensions that were dropped.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the
  // number of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  // Otherwise, disambiguate which unit dims were actually dropped by matching
  // strides against the candidate reduced type.
  int64_t originalOffset, candidateOffset;
  SmallVector<int64_t> originalStrides, candidateStrides;
  // ...
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is fully accounted for; it is not dropped.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced type cannot use a stride more
      // often than the original type.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}
llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out-of-bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument of an AllocOp, AllocaOp, ViewOp, or
  // SubViewOp.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
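// Illustrative IR example (added for exposition, not from the original
// source): a dim of a known dynamic size folds to the corresponding size
// operand. With
//   %alloc = memref.alloc(%n) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d = memref.dim %alloc, %c1 : memref<4x?xf32>
// %d folds directly to %n.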
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // The fold is legal only if dim.getIndex() dominates the reshape: it is
    // either defined in the same block before the reshape, or in a proper
    // ancestor region.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex() is a block argument to reshape's block.
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
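// Illustrative IR example (added for exposition, not from the original
// source): asking for a dim of a reshape result becomes a load from the
// shape operand:
//   %r = memref.reshape %src(%shape)
//       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i : memref<?x?xf32>
// canonicalizes to
//   %d = memref.load %shape[%i] : memref<2xindex>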
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  // ...
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  // ...
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // The mandatory non-variadic operands are: src memref, dst memref, tag
  // memref and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check the operand types. The order of these calls matters: later checks
  // use type properties to compute operand positions.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // The optional stride operands must be both present or both absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}
void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result packs, only the first value in the pack can be given a
  // pretty name.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
/// Replace constant uses of `values` by a constant materialized from
/// `maybeConstants`. Returns true if at least one use was replaced.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant when the value is unchanged or unused.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}
LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  // ...
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  // ...
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  // ...
}
// (GenericAtomicRMWOp::build) The result type matches the memref element
// type.
  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    // ...
  }
LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  // The body must be side-effect free; otherwise the walk emits:
  //   "body of 'memref.generic_atomic_rmw' should contain "
  //   "only operations with no side effects"
  // ...
  return success();
}
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

// (parseGlobalMemrefOpTypeAndInitialValue)
  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  // ...
  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      Type initType = elementsAttr.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  return success();
}

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}
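// Illustrative IR example (added for exposition, not from the original
// source): a constant global with a dense initializer and a reference to it:
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
//   %0 = memref.get_global @cst : memref<2xf32>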
// (GetGlobalOp::verifySymbolUses) The symbol must resolve to a GlobalOp with
// a matching type.
  if (!global)
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
  return success();
LogicalResult LoadOp::verify() {
  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank())
    return emitOpError("incorrect number of indices for load, expected ")
           << getMemRefType().getRank() << " but got " << getIndices().size();
  return success();
}
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT)
    return uaT.getElementType() == ubT.getElementType();
  return false;
}
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
  // t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}
void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}
// (PrefetchOp::parse)
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;
  // ...
  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));
LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");
  return success();
}

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant-fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
}
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// (Result-type-inferring overload) Derive a strided layout from the static
// offset and strides.
  auto sourceType = cast<BaseMemRefType>(source.getType());
  // ...
  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

// (int64_t overload) Wrap static sizes and strides as integer attributes.
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues =
      llvm::to_vector<4>(llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  // ...

// (Value overload) Wrap SSA sizes and strides as OpFoldResults.
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space and have
  // the same element type.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in the result memref type and in the static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  // The result type must have a strided layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match the offset in the result memref type and in the static_offsets
  // attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  // Match strides in the result memref type and in the static_strides
  // attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }

  return success();
}
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if the subview
    // offsets are all zero.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }
  return nullptr;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getMixedSizes();
  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
                      ShapedType::isDynamic);
  return values;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getMixedStrides();
  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
                      ShapedType::isDynamic);
  return values;
}
OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
                      ShapedType::isDynamic);
  return values[0];
}
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check that the reinterpret cast reconstructs a memref with exactly the
    // same properties as the extract strided metadata. First, the strides.
    SmallVector<OpFoldResult> extractStridesOfr =
        extractStridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> reinterpretStridesOfr =
        op.getConstifiedMixedStrides();
    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
      return failure();

    unsigned rank = op.getType().getRank();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
        return failure();
    }

    // Second, the sizes.
    assert(extractStridedMetadata.getSizes().size() ==
               op.getMixedSizes().size() &&
           "Strides and sizes rank must match");
    SmallVector<OpFoldResult> extractSizesOfr =
        extractStridedMetadata.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> reinterpretSizesOfr =
        op.getConstifiedMixedSizes();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
        return failure();
    }

    // Finally, the offset.
    assert(op.getMixedOffsets().size() == 1 &&
           "reinterpret_cast with more than one offset should have been "
           "rejected by the verifier");
    OpFoldResult extractOffsetOfr =
        extractStridedMetadata.getConstifiedMixedOffset();
    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
    if (extractOffsetOfr != reinterpretOffsetOfr)
      return failure();

    // The round-trip is a noop, but the result type may still differ (e.g., a
    // dimension may have changed from static to dynamic); add a cast when
    // necessary.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());
    return success();
  }
};

void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

LogicalResult ExpandShapeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
  return success();
}
/// Helper function for verifying the shapes of ExpandShapeOp and
/// CollapseShapeOp results and operands. Layout maps are verified separately.
/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
/// allowed in a reassociation group.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index while iterating over the
  // reassociation indices.
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // An expand/collapse may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the group are static, the size of the collapsed dim can
    // be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: the number of dimensions among all reassociation groups must
    // match the expanded rank.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}
// (ExpandShapeOp and CollapseShapeOp both derive their reassociation
// expressions from the reassociation indices:)
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
/// Compute the layout map after expanding a given source MemRef type with the
/// specified reassociation indices.
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // Iterate over the reassociation groups from the back: each source stride
  // seeds the innermost expanded dim of its group and is multiplied up by the
  // result sizes for the outer dims of the group.
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // The source may not be contiguous. Compute the layout map.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

// (ExpandShapeOp::build overloads)
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

// Infer the output shape when only the result type is given.
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
}

// Compute the result type from a static result shape.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}

// Same, with an explicit output shape.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
}
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify the result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute the expected result type (including the layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check the actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the provided output shapes match the result type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
  }

  return success();
}
/// Compute the layout map after collapsing a given source MemRef type with
/// the specified reassociation indices. All collapsed dims in a group must be
/// contiguous; when that cannot be proven statically, the collapse is
/// accepted unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  // The result stride of a group is the stride of its last entry, skipping
  // trailing unit dims whose strides are meaningless.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // A dynamically-sized dim may turn out to be a unit dim at runtime, so
      // be conservative about its stride.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();
      // Skip unit dims: their strides could be anything.
      if (srcShape[idx - 1] == 1)
        continue;
      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    auto groupSize = SaturatedInteger::wrap(1);
    for (int64_t srcDim : group)
      groupSize =
          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asInteger());
  }

  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // The source may not be fully contiguous. Compute the layout map; dims that
  // are collapsed into a single dim are assumed to be contiguous.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  // ...
  build(b, result, resultType, src, attrs);
}
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify the result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute the expected result type (including the layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // The source may not be fully contiguous. Compute the layout map.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
// (CollapseShapeOpMemRefCastFolder) Fold collapse_shape(cast) when the cast
// only changes static information.
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
    } else {
      Value newOp = rewriter.create<CollapseShapeOp>(
          op->getLoc(), cast.getSource(), op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();

// (CollapseShapeOp::getCanonicalizationPatterns registers, among others,
// ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
//                           memref::DimOp, MemRefType>,
// and CollapseShapeOpMemRefCastFolder.)
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
LogicalResult StoreOp::verify() {
  if (getNumOperands() != 2 + getMemRefType().getRank())
    return emitOpError("store index operand count not equal to memref rank");
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}
void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> staticOffsets,
                                ArrayRef<int64_t> staticSizes,
                                ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract the source offset and strides.
  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);

  // Compute the target offset:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`,
  // saturating to "dynamic" when any term is dynamic.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  // Compute the target strides: `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
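// Worked example (added for exposition, not from the original source): for a
// source memref<8x16xf32> (identity strides [16, 1], offset 0) with static
// offsets [2, 3], sizes [4, 4], and strides [1, 2]:
//   targetOffset  = 0 + 2 * 16 + 3 * 1 = 35
//   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
// so the inferred type is memref<4x4xf32, strided<[16, 2], offset: 35>>, as in
//   %sv = memref.subview %m[2, 3][4, 4][1, 2]
//       : memref<8x16xf32> to memref<4x4xf32, strided<[16, 2], offset: 35>>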
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<int64_t> offsets,
                                           ArrayRef<int64_t> sizes,
                                           ArrayRef<int64_t> strides) {
  auto inferredType = llvm::cast<MemRefType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected result rank to be at least the target rank");
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and the result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}

Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes,
                                           ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
// Build a SubViewOp with mixed static and dynamic entries; infer the result
// type when none is provided.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  if (!resultType)
    resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
        sourceMemRefType, staticOffsets, staticSizes, staticStrides));
  // ...
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Convenience overloads: without an explicit result type, and with int64_t-
// or Value-based offsets/sizes/strides; these wrap their arguments as index
// attributes or OpFoldResults and delegate to the builder above.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
// ... (remaining overloads map values with b.getI64IntegerAttr and forward
//      in the same way)
Value SubViewOp::getViewSource() { return getSource(); }
/// Return true if `t1` and `t2` have equal offsets (both dynamic or of the
/// same static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Return true if `t1` and `t2` have equal strides (both dynamic or of the
/// same static value). Dimensions of `t1` marked in `droppedDims` are
/// skipped.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            Operation *op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming no rank reductions.
  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));

  // Verify all properties of a shaped type: rank, element type and dimension
  // sizes, accounting for potential rank reductions.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Only the strides are left to verify. First compute the dropped dims; if
  // they cannot be computed, the layouts are incompatible.
  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
                                                   getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Strides must match.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  return success();
}
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
/// Return the list of Range (i.e. offset, size, stride). Each Range entry
/// contains either the dynamic value or a ConstantIndexOp constructed with
/// `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
/// Compute the canonical result type of a SubViewOp: infer the result type
/// for `sourceType`, then reduce its rank if `currentResultType` is lower
/// rank than `currentSourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides));
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  auto targetType =
      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
          targetShape, memrefType, offsets, sizes, strides));
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}
FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}
/// Return true if a subview is trivially a no-op: rank preserved, all offsets
/// zero, all strides one, and sizes equal to the source shape.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Any offset that is not statically zero disqualifies the subview.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Likewise for any stride that is not statically one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // All sizes must be static and match the source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  return true;
}
/// Pattern to rewrite a subview whose source is a memref.cast, when the cast
/// can be folded into the consuming subview.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, let the constant-argument folder kick in
    // first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the cast. Use the cast
    // source operand type to infer the result type and the current SubViewOp
    // source operand type to compute the dropped dimensions when the
    // operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
    // Infer a memref type without taking into account any rank reductions.
    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
                                            mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = cast<MemRefType>(resTy);

    // Directly return the non-rank-reduced type if no dims are dropped.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank-reduced type, then drop
    // the unused dims from the shape and strides.
    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
    SmallVector<int64_t> targetShape, targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<
      // ... (generic constant-argument folders)
      SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)) when the inner subview has the same sizes and
  // the outer subview is a no-op (zero offsets, unit strides, same sizes).
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(
        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(
        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute the result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  // ...
  p << " : " << getIn().getType() << " to " << getType();
}
// (TransposeOp::parse)
  MemRefType srcType, dstType;
  // ...
  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType = canonicalizeStridedLayout(
      inferTransposeResultType(srcType, getPermutation()));

  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // Fold an identity permutation when the input and result types already
  // match.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose ops into one by composing their
  // permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
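// Illustrative IR example (added for exposition, not from the original
// source): transposing swaps both sizes and strides in the result type:
//   %t = memref.transpose %m (i, j) -> (j, i)
//       : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>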
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}
LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}
Value ViewOp::getViewSource() { return getSource(); }
  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get the result memref type.
    auto memrefType = viewOp.getType();

    // Get the offset from the old memref view type.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Construct the new shape with constant values folded in.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // This dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Not folded; keep the dimension dynamic and carry the operand over.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create the new memref type with constant-folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new; don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create the new ViewOp and cast back to the original type.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    // ...
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
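// Illustrative IR example (added for exposition, not from the original
// source): a view reinterprets a raw byte buffer at a dynamic byte offset:
//   %v = memref.view %buf[%offset][%n]
//       : memref<2048xi8> to memref<?x4xf32>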
3571 "expects the number of subscripts to be equal to memref rank");
3572 switch (getKind()) {
3573 case arith::AtomicRMWKind::addf:
3574 case arith::AtomicRMWKind::maximumf:
3575 case arith::AtomicRMWKind::minimumf:
3576 case arith::AtomicRMWKind::mulf:
3577 if (!llvm::isa<FloatType>(getValue().
getType()))
3578 return emitOpError() <<
"with kind '"
3579 << arith::stringifyAtomicRMWKind(getKind())
3580 <<
"' expects a floating-point type";
3582 case arith::AtomicRMWKind::addi:
3583 case arith::AtomicRMWKind::maxs:
3584 case arith::AtomicRMWKind::maxu:
3585 case arith::AtomicRMWKind::mins:
3586 case arith::AtomicRMWKind::minu:
3587 case arith::AtomicRMWKind::muli:
3588 case arith::AtomicRMWKind::ori:
3589 case arith::AtomicRMWKind::andi:
3590 if (!llvm::isa<IntegerType>(getValue().
getType()))
3591 return emitOpError() <<
"with kind '"
3592 << arith::stringifyAtomicRMWKind(getKind())
3593 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.