#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
  return arith::ConstantOp::materialize(builder, value, type, loc);
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
  return success(folded);
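// Illustrative use of the cast-folding helper above (example IR, not from
// this file): ops whose folders call foldMemRefCast absorb their operands'
// memref.cast producers, e.g.
//   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
//   memref.dealloc %1
// folds to
//   memref.dealloc %0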
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
    if (ShapedType::isStatic(cstVal)) {
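// A minimal sketch of constifyIndexValues's effect (assumed from the
// helper's documented contract, not verbatim from the source): with
//   values      = [%dynSize, 4]
//   constValues = [10, ShapedType::kDynamic]
// the helper overwrites values[0] with the static 10 and leaves values[1]
// untouched, yielding [10, 4].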
void AllocOp::getAsmResultNames(
  setNameFn(getResult(), "alloc");
void AllocaOp::getAsmResultNames(
  setNameFn(getResult(), "alloca");
template <typename AllocLikeOp>
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
    return op.emitOpError("result must be a memref");
  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");
  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();
157 "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
  LogicalResult matchAndRewrite(AllocLikeOp alloc,
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
          return constSizeArg.isNonNegative();
    auto memrefType = alloc.getType();
    newShapeConstants.reserve(memrefType.getRank());
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
          constSizeArg.isNonNegative()) {
        newShapeConstants.push_back(constSizeArg.getZExtValue());
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
    MemRefType newMemRefType =
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
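// Illustrative before/after for SimplifyAllocConst (example IR, not taken
// from this file):
//   %c10 = arith.constant 10 : index
//   %0 = memref.alloc(%c10) : memref<?xf32>
// is rewritten to
//   %0 = memref.alloc() : memref<10xf32>
// followed by a memref.cast back to memref<?xf32> for the remaining uses.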
template <typename T>
  LogicalResult matchAndRewrite(T alloc,
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
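// Illustrative effect of SimplifyDeadAlloc: an allocation whose only uses
// are deallocations (or stores *into* it) is deleted together with those
// users, e.g.
//   %0 = memref.alloc() : memref<4xf32>
//   memref.dealloc %0
// is erased entirely. Note that a store of %0 *as the stored value* keeps
// the alloc alive.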
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
           << sourceType << " and result memref type " << resultType;
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
           << sourceType << " and result memref type " << resultType;
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
  bool printBlockTerminators = false;
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
                printBlockTerminators);
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
void AllocaScopeOp::getSuccessorRegions(
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))
  bool hasPotentialAlloca =
  if (hasPotentialAlloca) {
    if (!lastParentWithoutScope ||
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
    assert(containingRegion && "op must be contained in a region");
      return containingRegion->isAncestor(v.getParentRegion());
      toHoist.push_back(alloc);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
void AssumeAlignmentOp::getAsmResultNames(
  setNameFn(getResult(), "assume_align");
OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (source.getAlignment() != getAlignment())
  setNameFn(getResult(), "cast");
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
  if (!sourceType || !resultType)
  if (sourceType.getElementType() != resultType.getElementType())
  if (sourceType.getRank() != resultType.getRank())
  int64_t sourceOffset, resultOffset;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
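// Illustrative consequence of the checks above (example, not from this
// file): a cast that only *erases* static information can be folded into a
// consumer, so for
//   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
// a consuming op may be rewritten to use %0 directly, whereas a cast in the
// other direction (dynamic source, static result) must be kept.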
  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);
  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      if (!checkCompatible(aOffset, bOffset))
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
    if (aT.getMemorySpace() != bT.getMemorySpace())
    if (aT.getRank() != bT.getRank())
    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
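// Examples of cast compatibility implied by the checks above (illustrative):
//   memref.cast %a : memref<4xf32> to memref<?xf32>   // ok: static -> dynamic
//   memref.cast %a : memref<?xf32> to memref<4xf32>   // ok: dynamic -> static
//   memref.cast %a : memref<4xf32> to memref<*xf32>   // ok: ranked -> unranked
//   memref.cast %a : memref<4xf32> to memref<5xf32>   // invalid: static mismatch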
  LogicalResult matchAndRewrite(CopyOp copyOp,
    if (copyOp.getSource() != copyOp.getTarget())
  LogicalResult matchAndRewrite(CopyOp copyOp,
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
  for (OpOperand &operand : op->getOpOperands()) {
      operand.set(castOp.getOperand());
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
  setNameFn(getResult(), "dim");
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {
  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
  setResultRange(getResult(),
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
static FailureOr<llvm::SmallBitVector>
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
  int64_t originalOffset, candidateOffset;
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
  std::map<int64_t, unsigned> currUnaccountedStrides =
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      currUnaccountedStrides[originalStride]--;
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      unusedDims.reset(dim);
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
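// Worked example for computeMemRefRankReductionMask (illustrative): reducing
// memref<1x4x1x8xf32> to memref<4x8xf32> with sizes [1, 4, 1, 8] yields the
// bit vector {0, 2}: the two unit dimensions are the ones that were dropped.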
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
  if (!memrefType.isDynamicDim(index.getInt())) {
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  unsigned unsignedIndex = index.getValue().getZExtValue();
  Operation *definingOp = getSource().getDefiningOp();
  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));
  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));
  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));
  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
      if (resultIndex == unsignedIndex) {
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
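// Illustrative folds performed by DimOp::fold (example IR):
//   %d = memref.dim %m, %c1 : memref<4x8xf32>   // -> constant 8 (static dim)
//   %0 = memref.alloc(%n) : memref<4x?xf32>
//   %d = memref.dim %0, %c1                     // -> %n (dynamic size operand)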
  LogicalResult matchAndRewrite(DimOp dim,
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
          dim, "Dim op is not defined by a reshape op.");
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
              "dim.getIndex is not defined before reshape in the same block.");
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
          dim, "dim.getIndex does not dominate reshape.");
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
  results.add<DimOfMemRefReshape>(context);
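// Illustrative rewrite done by DimOfMemRefReshape:
//   %r = memref.reshape %src(%shape)
//        : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i
// becomes a direct read of the shape operand:
//   %d = memref.load %shape[%i]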
                        Value elementsPerStride) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
    p << ", " << getStride() << ", " << getNumElementsPerStride();
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
        "expected two stride related operands");
  if (types.size() != 3)
  unsigned numOperands = getNumOperands();
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
    return emitOpError("expected source indices to be of index type");
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
    return emitOpError("expected destination indices to be of index type");
    return emitOpError("expected num elements to be of index type");
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
    return emitOpError("expected tag indices to be of index type");
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");
      !getNumElementsPerStride().getType().isIndex())
        "expected stride and num elements per stride to be of type index");
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
  setNameFn(getResult(), "intptr");
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  unsigned sourceRank = sourceType.getRank();
      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  inferredReturnTypes.push_back(memrefType);
  inferredReturnTypes.push_back(indexType);
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
void ExtractStridedMetadataOp::getAsmResultNames(
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
template <typename Container>
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      atLeastOneReplacement = true;
  return atLeastOneReplacement;
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
      getConstifiedMixedOffset());
      getConstifiedMixedSizes());
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());
  return success(atLeastOneReplacement);
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");
  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");
        "body of 'memref.generic_atomic_rmw' should contain "
        "only operations with no side effects");
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  if (!op.isExternal()) {
    if (op.isUninitialized())
      p << "uninitialized";
  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
           << "type should be static shaped memref, but got " << type;
  if (!llvm::isa<ElementsAttr>(initialValue))
           << "initial value should be a unit or elements attribute";
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();
      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;
      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;
  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;
    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";
  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
    return emitOpError("incorrect number of indices for load, expected ")
void MemorySpaceCastOp::getAsmResultNames(
  setNameFn(getResult(), "memspacecast");
  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);
  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout())
    if (aT.getShape() != bT.getShape())
    return uaT.getElementType() == ubT.getElementType();
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
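// Illustrative fold: back-to-back memory-space casts collapse into one, e.g.
//   %1 = memref.memory_space_cast %0 : memref<4xf32> to memref<4xf32, 1>
//   %2 = memref.memory_space_cast %1 : memref<4xf32, 1> to memref<4xf32, 2>
// folds %2 so that it casts directly from %0.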
  p << " " << getMemref() << '[';
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
      (*this)->getAttrs(),
      {"localityHint", "isWrite", "isDataCache"});
  IntegerAttr localityHint;
  StringRef readOrWrite, cacheType;
  if (readOrWrite != "read" && readOrWrite != "write")
        "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
  if (cacheType != "data" && cacheType != "instr")
        "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
    return emitOpError("too few indices");
LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
  return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
  setNameFn(getResult(), "reinterpret_cast");
                               MemRefType resultType, Value source,
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
  auto sourceType = cast<BaseMemRefType>(source.getType());
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
                               MemRefType resultType, Value source,
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        strideValues, attrs);
                               MemRefType resultType, Value source, Value offset,
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;
  for (auto [idx, resultSize, expectedSize] :
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  int64_t resultOffset;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;
  for (auto [idx, resultStride, expectedStride] :
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
      return prev.getSource();
      return prev.getSource();
        return prev.getSource();
  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  assert(succeeded(status) && "could not get strides from type");
OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
struct ReinterpretCastOpExtractStridedMetadataFolder
  LogicalResult matchAndRewrite(ReinterpretCastOp op,
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
    auto isReinterpretCastNoop = [&]() -> bool {
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    if (!isReinterpretCastNoop()) {
      op.getSourceMutable().assign(extractStridedMetadata.getSource());
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
          extractStridedMetadata.getSource());
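// Illustrative no-op elimination performed by this pattern: when the
// reinterpret_cast reconstructs exactly the metadata of the original memref,
//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
//   %m = memref.reinterpret_cast %base to offset: [%offset],
//        sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1]
// %m is replaced by %src directly (with a memref.cast if the types differ).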
  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapse_shape");
void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expand_shape");
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
static LogicalResult
                      bool allowMultipleDynamicDimsPerGroup) {
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();
  int64_t nextDim = 0;
    int64_t collapsedDim = it.index();
    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");
      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
            << expandedDim << " is out of bounds";
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
          << ") must be dynamic if and only if reassociation group is "
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
            << collapsedShape[collapsedDim]
            << ") must equal reassociation group size (" << groupSize << ")";
  if (collapsedShape.empty()) {
    for (int64_t d : expandedShape)
          "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
        << expandedShape.size()
        << ") inconsistent with number of reassociation indices (" << nextDim
                              getReassociationIndices());
                              getReassociationIndices());
static FailureOr<StridedLayoutAttr>
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
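// Worked example for computeExpandedLayoutMap (illustrative): expanding
// memref<6x4xf32, strided<[4, 1]>> to shape [2, 3, 4] with reassociation
// [[0, 1], [2]] walks each group right-to-left, multiplying by the expanded
// sizes, and produces strides [12, 4, 1].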
FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
                           srcType.getMemorySpace());
  FailureOr<StridedLayoutAttr> computedLayout =
  if (failed(computedLayout))
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
FailureOr<SmallVector<OpFoldResult>>
                                    MemRefType expandedType,
  std::optional<SmallVector<OpFoldResult>> outputShape =
  return *outputShape;
  auto [staticOutputShape, dynamicOutputShape] =
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        dynamicOutputShape, staticOutputShape);
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();
  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
                                  resultType.getShape(),
                                  getReassociationIndices(),
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;
  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";
  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
static FailureOr<StridedLayoutAttr>
                          bool strict = false) {
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
  resultStrides.reserve(reassociation.size());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
      resultStrides.push_back(ShapedType::kDynamic);
  unsigned resultStrideIndex = resultStrides.size() - 1;
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      if (strict && (stride.saturated || srcStride.saturated))
      if (srcShape[idx - 1] == 1)
      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
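// Worked example for computeCollapsedLayoutMap (illustrative, the inverse of
// the expansion example above): collapsing
// memref<2x3x4xf32, strided<[12, 4, 1]>> with reassociation [[0, 1], [2]]
// succeeds because dims 0 and 1 are contiguous (12 == 4 * 3) and yields
// strides [4, 1]; non-contiguous groups make the computation fail.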
bool CollapseShapeOp::isGuaranteedCollapsible(
  if (srcType.getLayout().isIdentity())
MemRefType CollapseShapeOp::computeCollapsedType(
  resultShape.reserve(reassociation.size());
    for (int64_t srcDim : group)
    resultShape.push_back(groupSize.asInteger());
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
                           srcType.getMemorySpace());
  FailureOr<StridedLayoutAttr> computedLayout =
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  build(b, result, resultType, src, attrs);
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();
  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
                                  srcType.getShape(), getReassociationIndices(),
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
    FailureOr<StridedLayoutAttr> computedLayout =
    if (failed(computedLayout))
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
                        *computedLayout, srcType.getMemorySpace());
  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());
    if (newResultType == op.getResultType()) {
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
          op->getLoc(), cast.getSource(), op.getReassociationIndices());
                                memref::DimOp, MemRefType>,
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();
  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");
  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
          "length of shape operand differs from the result's memref rank");
    return emitOpError("store index operand count not equal to memref rank");
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
void SubViewOp::getAsmResultNames(
  setNameFn(getResult(), "subview");
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
  unsigned rank = sourceMemRefType.getRank();
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");
  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                                            targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
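// Worked example for inferResultType (illustrative): a subview of
// memref<8x16x4xf32> (strides [64, 4, 1], offset 0) with offsets [3, 4, 2],
// sizes [4, 4, 4], and strides [1, 1, 1] has
//   targetOffset = 3 * 64 + 4 * 4 + 2 * 1 = 210
// and result type memref<4x4x4xf32, strided<[64, 4, 1], offset: 210>>.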
MemRefType SubViewOp::inferRankReducedResultType(
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
  assert(dimsToProject.has_value() && "invalid rank reduction");
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
                                            inferredLayout.getOffset(),
                                            rankReducedStrides),
                         inferredType.getMemorySpace());
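// Illustrative rank reduction (continuing the example above): asking for
// result shape [4, 4] with sizes [4, 1, 4] drops the middle unit dimension
// and its stride, giving memref<4x4xf32, strided<[64, 1], offset: 210>>.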
MemRefType SubViewOp::inferRankReducedResultType(
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
                      MemRefType resultType, Value source,
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
                      MemRefType resultType, Value source,
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
Value SubViewOp::getViewSource() { return getSource(); }
  int64_t t1Offset, t2Offset;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (t1Strides[i] != t2Strides[j])
                                         SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
    return op->emitError("expected result type to be ")
           << " or a rank-reduced version. (mismatch of result sizes), but got "
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
    return op->emitError(
        "expected result and source memory spaces to match, but got ")
    return op->emitError("expected result type to be ")
           << " or a rank-reduced version. (mismatch of result layout), but "
  llvm_unreachable("unexpected subview verification result");
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
           << baseType << " and subview memref type " << subViewType;
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);
                                            expectedType, subViewType);
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
                                   *this, expectedType);
                                   *this, expectedType);
  if (failed(unusedDims))
                                   *this, expectedType);
                                   *this, expectedType);
                                          staticStrides, true);
    return getOperation()->emitError(boundsResult.errorMessage);
  return os << "range " << range.offset << ":" << range.size << ":"
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  unsigned rank = ranks[0];
  for (unsigned idx = 0; idx < rank; ++idx) {
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
    res.emplace_back(Range{offset, size, stride});
    MemRefType currentResultType, MemRefType currentSourceType,
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
    shape.push_back(size);
    strides.push_back(stride);
                                            layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
  auto maybeRankReductionMask =
  if (!maybeRankReductionMask)
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();
    return !intValue || intValue.value() != 0;
    return !intValue || intValue.value() != 1;
    if (!intValue || *intValue != sourceShape[size.index()])
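// A subview is trivial per the checks above when, for example,
//   memref.subview %m[0, 0] [8, 16] [1, 1]
//       : memref<8x16xf32> to memref<8x16xf32>
// i.e. all offsets are 0, all strides are 1, and the sizes match the source
// shape; such ops can be replaced by %m itself.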
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
  LogicalResult matchAndRewrite(SubViewOp subViewOp,
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
  LogicalResult matchAndRewrite(SubViewOp subViewOp,
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
        subViewOp.getSource());
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    MemRefType nonReducedType = resTy;
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
                                              offset, targetStrides),
                           nonReducedType.getMemorySpace());
              SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto offsets = getMixedOffsets();
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
void TransposeOp::getAsmResultNames(
  setNameFn(getResult(), "transpose");
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
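// Illustrative result of inferTransposeResultType: applying the permutation
// (d0, d1) -> (d1, d0) to memref<4x8xf32> (strides [8, 1]) yields
// memref<8x4xf32, strided<[1, 8]>>: sizes and strides are permuted together,
// so the op relabels dimensions without moving data.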
                        AffineMapAttr permutation,
  auto permutationMap = permutation.getValue();
  assert(permutationMap);
  auto memRefType = llvm::cast<MemRefType>(in.getType());
  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
  p << " " << getIn() << " " << getPermutation();
  p << " : " << getIn().getType() << " to " << getType();
  MemRefType srcType, dstType;
  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");
  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
          .canonicalizeStridedLayout();
  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  if (getPermutation().isIdentity() && getType() == getIn().getType())
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
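// Illustrative folds for TransposeOp: an identity permutation folds to the
// input, and transpose-of-transpose folds to a single transpose with the
// composed permutation, e.g. two successive (d0, d1) -> (d1, d0) transposes
// cancel out.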
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
           << baseType << " and view memref type " << viewType;
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;
Value ViewOp::getViewSource() { return getSource(); }
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
    return getViewSource();
  LogicalResult matchAndRewrite(ViewOp viewOp,
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
    auto memrefType = viewOp.getType();
    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
    assert(oldOffset == 0 && "Expected 0 offset");
    newShapeConstants.reserve(memrefType.getRank());
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        newShapeConstants.push_back(constantIndexOp.value());
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
    MemRefType newMemRefType =
    if (newMemRefType == memrefType)
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
  LogicalResult matchAndRewrite(ViewOp viewOp,
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    Value allocOperand = memrefCastOp.getOperand();
        viewOp.getByteShift(),
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3541 "expects the number of subscripts to be equal to memref rank");
3542 switch (getKind()) {
3543 case arith::AtomicRMWKind::addf:
3544 case arith::AtomicRMWKind::maximumf:
3545 case arith::AtomicRMWKind::minimumf:
3546 case arith::AtomicRMWKind::mulf:
3547 if (!llvm::isa<FloatType>(getValue().
getType()))
3548 return emitOpError() <<
"with kind '"
3549 << arith::stringifyAtomicRMWKind(getKind())
3550 <<
"' expects a floating-point type";
3552 case arith::AtomicRMWKind::addi:
3553 case arith::AtomicRMWKind::maxs:
3554 case arith::AtomicRMWKind::maxu:
3555 case arith::AtomicRMWKind::mins:
3556 case arith::AtomicRMWKind::minu:
3557 case arith::AtomicRMWKind::muli:
3558 case arith::AtomicRMWKind::ori:
3559 case arith::AtomicRMWKind::andi:
3560 if (!llvm::isa<IntegerType>(getValue().
getType()))
3561 return emitOpError() <<
"with kind '"
3562 << arith::stringifyAtomicRMWKind(getKind())
3563 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Check whether this block might have a terminator.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
Result for slice bounds verification;.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.