#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

  return arith::ConstantOp::materialize(builder, value, type, loc);
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }

  return success(folded);
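// Illustrative use (a sketch, not necessarily verbatim from elsewhere in this
// file): an op's fold hook can delegate to foldMemRefCast to look through
// producing memref.cast ops, e.g.
//
//   LogicalResult LoadOp::fold(FoldAdaptor adaptor,
//                              SmallVectorImpl<OpFoldResult> &results) {
//     /// load(memref.cast(v)) -> load(v)
//     return foldMemRefCast(*this);
//   }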
  if (auto memref = llvm::dyn_cast<MemRefType>(type))

  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))

  auto memrefType = llvm::cast<MemRefType>(value.getType());

  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  auto memrefType = llvm::cast<MemRefType>(value.getType());

  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    int64_t constValue = it.value();
    if (!isDynamic(constValue))

    if (auto attr = dyn_cast<Attribute>(ofr)) {

      ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());

    std::optional<int64_t> maybeConstant =

  LogicalResult hasStaticInformation =
      memrefType.getStridesAndOffset(strides, offset);
  if (failed(hasStaticInformation))

  LogicalResult hasStaticInformation =
      memrefType.getStridesAndOffset(strides, offset);
  if (failed(hasStaticInformation))
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();
229 "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
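// Illustrative effect of this pattern (a sketch): a constant dynamic size is
// folded into the allocated type, with a cast back to the original type:
//
//   %c4 = arith.constant 4 : index
//   %0 = memref.alloc(%c4) : memref<?xf32>
// becomes
//   %0 = memref.alloc() : memref<4xf32>
//   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>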
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);
    rewriter.eraseOp(alloc);
    return success();
  }
};
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
  bool printBlockTerminators = false;

  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }

  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);

  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);
void AllocaScopeOp::getSuccessorRegions(

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

    bool hasPotentialAlloca =

    if (hasPotentialAlloca) {
    if (!lastParentWithoutScope ||

    lastParentWithoutScope = lastParentWithoutScope->getParentOp();
    if (!lastParentWithoutScope ||

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

      return containingRegion->isAncestor(v.getParentRegion());

      toHoist.push_back(alloc);

    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");

  setNameFn(getResult(), "cast");
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  if (!sourceType || !resultType)
    return false;

  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  if (sourceType.getRank() != resultType.getRank())
    return false;

  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
      return false;
  }

  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))
      return false;

  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
      return false;
  }
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  }

  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
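// A couple of casts accepted by the compatibility rules above (illustrative):
//
//   memref.cast %m : memref<4x?xf32> to memref<?x4xf32>  // ranked <-> ranked
//   memref.cast %m : memref<4xf32> to memref<*xf32>      // ranked <-> unranked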
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getSourceMutable().assign(castOp.getSource());
          modified = true;
        }
      }
    }

    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getTargetMutable().assign(castOp.getSource());
          modified = true;
        }
      }
    }

    return success(modified);
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);

  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {

      operand.set(castOp.getOperand());

  return success(folded);
}

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  setNameFn(getResult(), "dim");

  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)

  setResultRange(getResult(),
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();

  return unusedDims;
}
llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  unsigned unsignedIndex = index.getValue().getZExtValue();

  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }
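// Sketch of what these folds achieve (illustrative IR):
//
//   %c0 = arith.constant 0 : index
//   %c1 = arith.constant 1 : index
//   %m = memref.alloc(%size) : memref<?x8xf32>
//   %d0 = memref.dim %m, %c0 : memref<?x8xf32>  // folds to %size
//   %d1 = memref.dim %m, %c1 : memref<?x8xf32>  // folds to a constant 8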
  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      }
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }

  results.add<DimOfMemRefReshape>(context);
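// Net effect (illustrative): a dim of a reshape result is answered by loading
// from the shape operand instead:
//
//   %r = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %idx : memref<?x?xf32>
// becomes
//   %d = memref.load %shape[%idx] : memref<2xindex>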
                       Value elementsPerStride) {

  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (types.size() != 3)
  unsigned numOperands = getNumOperands();

  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {

  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  inferredReturnTypes.push_back(memrefType);
  inferredReturnTypes.push_back(indexType);
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}
LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();

  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

      return emitOpError("body of 'memref.generic_atomic_rmw' should contain "
                         "only operations with no side effects");
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';

  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;

  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      Type initType = elementsAttr.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}
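// Globals accepted by the verifier above (illustrative):
//
//   memref.global "private" constant @cst : memref<2xf32> =
//       dense<[1.0, 2.0]>
//   memref.global @uninit : memref<4xi32> = uninitialized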
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();

    return emitOpError("incorrect number of indices for load, expected ")
           << getMemRefType().getRank();
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }

  if (uaT && ubT)
    return uaT.getElementType() == ubT.getElementType();
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}
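// The fold collapses chains of memory-space casts (illustrative):
//
//   %1 = memref.memory_space_cast %0 : memref<4xf32> to memref<4xf32, 1>
//   %2 = memref.memory_space_cast %1 : memref<4xf32, 1> to memref<4xf32, 2>
// folds %2's source to %0 directly.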
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});

  IntegerAttr localityHint;
  StringRef readOrWrite, cacheType;

  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

    return emitOpError("too few indices");

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(type.getContext()),
                            shapedType.getRank());
  return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

  auto sourceType = cast<BaseMemRefType>(source.getType());

  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  build(b, result, resultType, source, b.getIndexAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ArrayRef<Value> sizes, ArrayRef<Value> strides,
                              ArrayRef<NamedAttribute> attrs) {

  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  int64_t resultOffset;
  SmallVector<int64_t> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  int64_t expectedOffset = getStaticOffsets().front();
  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }

  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
                      ShapedType::isDynamic);

  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
                      ShapedType::isDynamic);

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
                      ShapedType::isDynamic);
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    SmallVector<OpFoldResult> extractStridesOfr =
        extractStridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> reinterpretStridesOfr =
        op.getConstifiedMixedStrides();
    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
      return failure();

    unsigned rank = op.getType().getRank();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
        return failure();
    }

    assert(extractStridedMetadata.getSizes().size() ==
               op.getMixedSizes().size() &&
           "Strides and sizes rank must match");
    SmallVector<OpFoldResult> extractSizesOfr =
        extractStridedMetadata.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> reinterpretSizesOfr =
        op.getConstifiedMixedSizes();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
        return failure();
    }
    assert(op.getMixedOffsets().size() == 1 &&
           "reinterpret_cast with more than one offset should have been "
           "rejected by the verifier");
    OpFoldResult extractOffsetOfr =
        extractStridedMetadata.getConstifiedMixedOffset();
    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
    if (extractOffsetOfr != reinterpretOffsetOfr)
      return failure();

    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());
    return success();
  }
};

  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());

  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);

  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);

  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);

  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()

    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      if (srcShape[idx - 1] == 1)
        continue;

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(
      computeCollapsedLayoutMap(srcType, reassociation, /*strict=*/true));
}

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    auto groupSize = SaturatedInteger::wrap(1);
    for (int64_t srcDim : group)
      groupSize =
          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asInteger());
  }

  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);

  build(b, result, resultType, src, attrs);
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
    } else {
      Value newOp = rewriter.create<CollapseShapeOp>(
          op->getLoc(), cast.getSource(), op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                        memref::DimOp, MemRefType>,
              CollapseShapeOpMemRefCastFolder>(context);

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
    return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}

MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected ");
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,

  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getIndexAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

Value SubViewOp::getViewSource() { return getSource(); }
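// A rank-preserving subview with one dynamic offset (illustrative):
//
//   %sv = memref.subview %m[%off, 4] [4, 4] [1, 1]
//       : memref<8x16xf32> to memref<4x4xf32, strided<[16, 1], offset: ?>>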
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            Operation *op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());

  SliceVerificationResult result = isRankReducedType(
      expectedType, subViewType);
  if (result != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(result, *this, expectedType);

  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(baseType, subViewType, getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  return success();
}
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);

  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  return true;
}
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = resTy;

    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());

  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(
        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(
        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}

  p << " " << getIn() << " " << getPermutation();

  p << " : " << getIn().getType() << " to " << getType();

  MemRefType srcType, dstType;

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();

  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
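// Composition in IR form (illustrative): a transpose of a transpose becomes a
// single transpose with the composed permutation:
//
//   %t0 = memref.transpose %m (d0, d1) -> (d1, d0)
//       : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %t1 = memref.transpose %t0 (d0, d1) -> (d1, d0)
//       : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32>
// folds so %t1 reads directly from %m with the identity permutation.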
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}

Value ViewOp::getViewSource() { return getSource(); }
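// A view carving a typed buffer out of raw bytes (illustrative):
//
//   %v = memref.view %buf[%offset][%size]
//       : memref<2048xi8> to memref<?x4xf32>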
  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = viewOp.getType();

    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    if (newMemRefType == memrefType)
      return failure();

    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }

  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3575 "expects the number of subscripts to be equal to memref rank");
3576 switch (getKind()) {
3577 case arith::AtomicRMWKind::addf:
3578 case arith::AtomicRMWKind::maximumf:
3579 case arith::AtomicRMWKind::minimumf:
3580 case arith::AtomicRMWKind::mulf:
3581 if (!llvm::isa<FloatType>(getValue().
getType()))
3582 return emitOpError() <<
"with kind '"
3583 << arith::stringifyAtomicRMWKind(getKind())
3584 <<
"' expects a floating-point type";
3586 case arith::AtomicRMWKind::addi:
3587 case arith::AtomicRMWKind::maxs:
3588 case arith::AtomicRMWKind::maxu:
3589 case arith::AtomicRMWKind::mins:
3590 case arith::AtomicRMWKind::minu:
3591 case arith::AtomicRMWKind::muli:
3592 case arith::AtomicRMWKind::ori:
3593 case arith::AtomicRMWKind::andi:
3594 if (!llvm::isa<IntegerType>(getValue().
getType()))
3595 return emitOpError() <<
"with kind '"
3596 << arith::stringifyAtomicRMWKind(getKind())
3597 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.