#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
// ...

/// Materialize a single constant operation from a given attribute value and
/// the desired result type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// Fold the source of any memref.cast feeding into an op's operands directly
/// into the op ("someop(memref.cast) -> someop").
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);
  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

/// Helper function that sets values[i] to constValues[i] if the latter is a
/// static value, as indicated by ShapedType::kDynamic.
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
                                ArrayRef<int64_t> constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    if (!ShapedType::isDynamic(cstVal)) {
      // The value is statically known: materialize it as an index attribute.
      // ...
    }
  }
}

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

// ...

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}
namespace {
/// Fold constant dimension operands of an alloc-like op into its result type.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // We have one or more constant operands. Collect the non-constant ones
    // and keep track of the resulting memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; keep the operand.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create the new memref type, which will have fewer dynamic dimensions.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
    // Insert a cast so the result still has the old type.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
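// For illustration (hedged, not part of the original file): SimplifyAllocConst
// rewrites
//
//   %c4 = arith.constant 4 : index
//   %0  = memref.alloc(%c4) : memref<?xf32>
//
// into
//
//   %0 = memref.alloc() : memref<4xf32>
//   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
//
// so that consumers of the original SSA value still see the `?` type.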
/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}
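// Sketch of the effect (illustrative): an allocation whose only uses are
// stores into it and a matching dealloc is dead, so
//
//   %0 = memref.alloc() : memref<16xf32>
//   memref.store %v, %0[%i] : memref<16xf32>
//   memref.dealloc %0 : memref<16xf32>
//
// is erased entirely. Note the guard above: if the alloc'd value is itself
// the *stored value* (not the store target), the pattern bails out.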
LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source and result memrefs should be in the same memory space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source and result memrefs should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  return success();
}

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}
void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // ...
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);
  // ...
  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // ...
}

/// Given an operation, return whether this op is guaranteed to allocate an
/// AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could allocate an
/// AutomaticAllocationScopeResource. Conservative: ops without a
/// memory-effect interface are assumed to potentially allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

namespace {
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            // (walk body elided: interrupt on a potential allocation, skip
            // nested allocation scopes)
          }).wasInterrupted();

    // Only legal to inline when no allocation would change scope, or the
    // parent is itself a scope and this op is the last non-terminator.
    if (hasPotentialAlloca) {
      // ...
    }
    // ... (splice the body into the parent and erase the scope)
  }
};

/// Move allocations into an allocation scope, if it is legal to move them
/// (e.g. their operands are available at the location the op would be moved
/// to).
struct AllocaScopeHoister : OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    Operation *lastParentWithoutScope = op->getParentOp();
    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Walk up until the parent is an allocation scope, requiring each
    // intermediate op to be the last non-terminator in its region.
    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    // Collect guaranteed allocations whose operands are all defined outside
    // the containing region; those are safe to hoist.
    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};
} // namespace
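// Sketch of the hoisting effect (illustrative, under the legality conditions
// checked above): an alloca that does not depend on values defined inside
// the scope can be cloned in front of it,
//
//   memref.alloca_scope {
//     %a = memref.alloca() : memref<16xf32>
//     "use"(%a) : (memref<16xf32>) -> ()
//   }
//
// becomes
//
//   %a = memref.alloca() : memref<16xf32>
//   memref.alloca_scope {
//     "use"(%a) : (memref<16xf32>) -> ()
//   }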
LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
/// Determines whether a memref.cast casts to a more dynamic version of the
/// source memref. This is useful to fold the cast into a consuming op and
/// implement canonicalization patterns for such ops.
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If the cast is towards a more static offset, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))
      return false;

  // If the cast is towards more static strides along any dimension, don't
  // fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  return true;
}
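// Illustrative example (not from the file): a cast that only *erases* static
// information can be folded into a consumer,
//
//   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
//
// whereas the reverse cast (dynamic to static) adds information and must be
// kept, since the consumer may rely on the asserted static shape.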
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same or dynamic.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  }

  // At least one side is unranked.
  // ...
  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}
namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the
/// shape and element type, the cast can be skipped.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      // Note: compare against the cast's *result* type (the original text
      // repeated the source type here, which would compare a type to itself).
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getSourceMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getTargetMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }
    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};

/// Fold copies involving zero-element memrefs.
struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }
    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
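// Illustrative fold (assumption for exposition): given
//
//   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
//   memref.copy %1, %dst : memref<?xf32> to memref<?xf32>
//
// canFoldIntoConsumerOp succeeds, so the copy is rewritten to read directly
// from %0 and the cast becomes dead.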
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
/// Given the `originalType` and a `reducedType` whose shape is assumed to be
/// a subset of `originalType` with some unit entries erased, return the set
/// of indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape. This accounts for cases where there
/// are multiple unit-dims, but only a subset of them are dropped.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the
  // number of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // For memrefs, a dimension is truly dropped only when its stride is also
  // dropped. Match strides by number of occurrences: the stride of a unit dim
  // may also remain in the reduced type, in which case that dim is kept.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped; keep the dim.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced type cannot contain a stride
      // that was not in the original type.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}
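// Worked example (illustrative): reducing a contiguous memref<1x4x1xf32>
// (strides [4, 1, 1]) to memref<4xf32> marks dims 0 and 2 as unit candidates
// from `sizes`; since 2 dropped dims + rank 1 equals the original rank 3,
// the early exit returns the mask {0, 2} without stride matching.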
llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out-of-bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument of an alloc-like or view-like producer.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
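// Illustrative fold (assumption): with
//
//   %c1 = arith.constant 1 : index
//   %a  = memref.alloc(%sz) : memref<4x?xf32>
//   %d  = memref.dim %a, %c1 : memref<4x?xf32>
//
// the dim folds directly to %sz, the allocation's dynamic-size operand.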
namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //   1. dim.getIndex() is defined in the same block as reshape but before
    //      reshape.
    //   2. dim.getIndex() is defined in a parent block of reshape.

    // Check condition 1.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape's block and
        // dominates reshape.
    } // Check condition 2.
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  // ... (operands are appended in the order printed below)
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  // ...
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  // ...
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
  // ...
  return success();
}
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the
  // later calls rely on some type properties to compute operand positions.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

/// The number and type of the results are inferred from the shape of the
/// source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}
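// Illustrative use (assumption for exposition): for a rank-2 source,
//
//   %base, %offset, %sizes:2, %strides:2 =
//     memref.extract_strided_metadata %m
//     : memref<4x?xf32, strided<[?, 1]>>
//       -> memref<f32>, index, index, index, index, index
//
// i.e. one rank-0 base buffer, one offset, then rank sizes and rank strides.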
void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result packs (`x:3`), only the first value in the pack can be
  // given a pretty name.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}

/// Helper function to perform the replacement of all constant uses of
/// `values` by a materialized constant extracted from `maybeConstants`.
/// Returns true if at least one constant was replaced.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // ... (skip values without uses or without a known constant)
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType().getShape());
  return values;
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = {getAsOpFoldResult(getOffset())};
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    builder.createBlock(bodyRegion, {}, elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

// (GenericAtomicRMWOp::parse elided.)

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // ... (parse optional `= uninitialized` or `= <elements>`)
  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type
    // of the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      Type initType = elementsAttr.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  return success();
}
ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type matches the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
  if (!global)
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
  return success();
}

LogicalResult LoadOp::verify() {
  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank())
    return emitOpError("incorrect number of indices for load, expected ")
           << getMemRefType().getRank() << " but got " << getIndices().size();
  return success();
}

void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}
bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT)
    return uaT.getElementType() == ubT.getElementType();
  return false;
}

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) ->
  // memory_space_cast(v, t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}
void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}

ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  IntegerAttr localityHint;
  StringRef readOrWrite, cacheType;
  // ... (parse memref, indices, rw specifier, locality, cache type)

  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}

LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");

  return success();
}
LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
}
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

/// Build a ReinterpretCastOp with mixed static/dynamic entries; the static
/// arrays are filled with sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

/// Build overload that infers the result type from a strided layout.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              Value source, OpFoldResult offset,
                              ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

/// Build overload taking int64_t sizes and strides.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

/// Build overload taking Value offset, sizes and strides.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in the result memref type and in the static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  // The result type must have a strided layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match the offset in the result type and the static_offsets attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  // Match strides in the result type and the static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }

  return success();
}
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if the subview
    // offsets are all 0.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }
  // ...
  return nullptr;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getMixedSizes();
  constifyIndexValues(values, getType().getShape());
  return values;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getMixedStrides();
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
namespace {
/// Fold the pair extract_strided_metadata + reinterpret_cast when the cast
/// merely reconstructs the metadata it extracted.
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check if the reinterpret cast reconstructs a memref with the exact
    // same properties as the extract strided metadata.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {
      // The pair cannot be completely folded, but the reinterpret_cast fully
      // overrides its source's metadata, so it can read through the
      // extract_strided_metadata to its source directly.
      rewriter.modifyOpInPlace(op, [&]() {
        op.getSourceMutable().assign(extractStridedMetadata.getSource());
      });
      return success();
    }

    // The roundtrip is a noop, but the result type may still differ (e.g.,
    // promotion to a static shape); insert a cast in that case.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());
    return success();
  }
};
} // namespace

void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

LogicalResult ExpandShapeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
  return success();
}
/// Helper function for verifying the shape of ExpandShapeOp and
/// CollapseShapeOp result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
/// allowed in a reassociation group.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize
               << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: number of dims among all reassociation groups must match
    // the expanded rank.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
/// Compute the layout map after expanding a given source MemRef type with the
/// specified reassociation indices.
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // Iterate in reverse so that, within each group, each expanded stride is
  // the previous one multiplied by the corresponding result dim size.
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (no layout map specified), so is the result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be contiguous. Compute the layout map.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  // Failure of this assertion usually indicates presence of multiple dynamic
  // dimensions in the same reassociation group.
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
}
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Verify result type.
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check the actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify explicit output shapes against the result type's static dims.
  ArrayRef<int64_t> staticOutputShapes = getStaticOutputShape();
  for (auto [pos, shape] : llvm::enumerate(resultType.getShape())) {
    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
  }

  return success();
}
/// Compute the layout map after collapsing a given source MemRef type with
/// the specified reassociation indices.
///
/// Note: All collapsed dims in a reassociation group must be contiguous. If
/// non-contiguity cannot be checked statically, the collapse is assumed valid
/// (and thus accepted) unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of the last
  // entry of the reassociation. Dimensions of size 1 are skipped, because
  // their strides are meaningless and could have any arbitrary value.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
      // so the result stride cannot be statically determined.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);

      // In strict mode, reject anything that cannot be decided statically.
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      // Dimensions of size 1 are skipped (see above).
      if (srcShape[idx - 1] == 1)
        continue;

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    auto groupSize = SaturatedInteger::wrap(1);
    for (int64_t srcDim : group)
      groupSize =
          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asInteger());
  }

  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (no layout map specified), so is the result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be fully contiguous. Compute the layout map. Dimensions
  // that are collapsed into a single dim are assumed to be contiguous.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}

LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Verify result type.
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (no layout map specified), so is the result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        layout, srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
namespace {
/// Fold a memref.cast feeding a collapse_shape when the cast is foldable.
struct CollapseShapeOpMemRefCastFolder
    : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp op,
                                PatternRewriter &rewriter) const override {
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
    } else {
      Value newOp = rewriter.create<CollapseShapeOp>(
          op->getLoc(), cast.getSource(), op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  }
};
} // namespace

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                memref::DimOp, MemRefType>,
      CollapseShapeOpMemRefCastFolder>(context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

LogicalResult StoreOp::verify() {
  if (getNumOperands() != 2 + getMemRefType().getRank())
    return emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}

MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
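// Worked example (illustrative): for a source memref<8x16xf32> (identity
// layout, strides [16, 1], offset 0), a subview with offsets [2, 4],
// sizes [4, 4] and strides [1, 2] infers
//
//   targetOffset  = 0 + 2*16 + 4*1 = 36
//   targetStrides = [16*1, 1*2]    = [16, 2]
//
// i.e. memref<4x4xf32, strided<[16, 2], offset: 36>>.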
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected result shape of equal or smaller rank");
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// (The remaining SubViewOp::build overloads forward to the one above; they
// convert int64_t/Value offsets, sizes and strides into OpFoldResults, e.g.:)
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// ...

Value SubViewOp::getViewSource() { return getSource(); }
/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
/// static value). Dimensions of `t1` that are dropped in `t2` (as indicated
/// by `droppedDims`) are skipped.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            Operation *op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}
/// Verifier for SubViewOp.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  ArrayRef<int64_t> staticStrides = getStaticStrides();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify all properties of a shaped type: rank, element type and dimension
  // sizes. This takes into account potential rank reductions.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Compute the unused dimensions due to rank reductions, then check strides.
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(expectedType, subViewType,
                                     getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes and strides do not run out-of-bounds with
  // respect to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
/// Return the list of Range (i.e. offset, size, stride). Each Range entry
/// contains either the dynamic value or a ConstantIndexOp constructed with
/// `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`, to preserve rank reductions from the original
/// result type.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(nonRankReducedType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}

FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}

/// Helper method to check if a `subview` operation is trivially a no-op: all
/// offsets are zero, all strides are 1, and the sizes match the source shape.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met: the subview is foldable as a no-op.
  return true;
}
namespace {
/// Pattern to rewrite a subview op with memref.cast arguments, pushing a
/// foldable cast past its consuming subview.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is constant, return to let the constant-argument folder
    // kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the cast. Use the cast
    // source type to infer the result type and the current source type to
    // compute the dropped dimensions if the op is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    MemRefType nonReducedType = resTy;

    // Directly return the non-rank-reduced type if no dims are dropped.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank-reduced type.
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
                  SubViewOp, SubViewReturnTypeCanonicalizer,
                  SubViewCanonicalizer>,
              SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)) where both subviews have the same size and the
  // second subview's offsets are all zero (i.e. the second subview is a
  // no-op).
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(
        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(
        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
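// Illustrative fold (assumption for exposition): a full-size, offset-0,
// stride-1 subview of a subview with identical sizes,
//
//   %s0 = memref.subview %m[%i] [4] [1]
//         : memref<?xf32> to memref<4xf32, strided<[1], offset: ?>>
//   %s1 = memref.subview %s0[0] [4] [1]
//         : memref<4xf32, strided<[1], offset: ?>>
//           to memref<4xf32, strided<[1], offset: ?>>
//
// folds %s1 away in favor of %s0 when the two types match.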
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() ==
         static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute the result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}

OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // An identity permutation folds away if input and result types already
  // match.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose ops into one by composing their
  // permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
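// Illustrative fold (assumption for exposition): two back-to-back transposes
// compose,
//
//   %t0 = memref.transpose %m (d0, d1) -> (d1, d0)
//         : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %t1 = memref.transpose %t0 (d0, d1) -> (d1, d0)
//         : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32>
//
// The second transpose is rewritten to read %m directly with the composed
// (here: identity) permutation, after which the identity case folds away.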
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return getSource(); }
namespace {
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get the result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from the old memref view type.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; keep the operand.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create the new memref type with constant-folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create the new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // Insert a cast so we keep the old result type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};
} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
3547 "expects the number of subscripts to be equal to memref rank");
3548 switch (getKind()) {
3549 case arith::AtomicRMWKind::addf:
3550 case arith::AtomicRMWKind::maximumf:
3551 case arith::AtomicRMWKind::minimumf:
3552 case arith::AtomicRMWKind::mulf:
3553 if (!llvm::isa<FloatType>(getValue().
getType()))
3554 return emitOpError() <<
"with kind '"
3555 << arith::stringifyAtomicRMWKind(getKind())
3556 <<
"' expects a floating-point type";
3558 case arith::AtomicRMWKind::addi:
3559 case arith::AtomicRMWKind::maxs:
3560 case arith::AtomicRMWKind::maxu:
3561 case arith::AtomicRMWKind::mins:
3562 case arith::AtomicRMWKind::minu:
3563 case arith::AtomicRMWKind::muli:
3564 case arith::AtomicRMWKind::ori:
3565 case arith::AtomicRMWKind::andi:
3566 if (!llvm::isa<IntegerType>(getValue().
getType()))
3567 return emitOpError() <<
"with kind '"
3568 << arith::stringifyAtomicRMWKind(getKind())
3569 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Result for slice bounds verification;.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.