#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
// In MemRefDialect::materializeConstant: delegate to arith's constant op.
return arith::ConstantOp::materialize(builder, value, type, loc);

// In foldMemRefCast: fold a memref.cast feeding an operand, unless the cast
// comes from an unranked memref (which may change the rank).
auto cast = operand.get().getDefiningOp<CastOp>();
if (cast && operand.get() != inner &&
    !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
  operand.set(cast.getOperand());
  folded = true;
}
// ...
return success(folded);
// In getTensorTypeFromMemRefType: map a (un)ranked memref type to the
// equivalent tensor type.
if (auto memref = llvm::dyn_cast<MemRefType>(type))
  return RankedTensorType::get(memref.getShape(), memref.getElementType());
if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
  return UnrankedTensorType::get(memref.getElementType());
// In memref::getMixedSize: return the size of dimension `dim`, as a Value for
// dynamic dimensions and as an attribute otherwise.
auto memrefType = llvm::cast<MemRefType>(value.getType());
if (memrefType.isDynamicDim(dim))
  return builder.createOrFold<memref::DimOp>(loc, value, dim);
// ...

// In memref::getMixedSizes: collect the mixed size of every dimension.
auto memrefType = llvm::cast<MemRefType>(value.getType());
for (int64_t i = 0; i < memrefType.getRank(); ++i)
  // ...

// In constifyIndexValues: replace values[i] by a constant attribute whenever
// constValues[i] is statically known.
assert(constValues.size() == values.size() &&
       "incorrect number of const values");
// ...
if (ShapedType::isStatic(cstVal)) {
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}
157 "requires an ancestor op with AutomaticAllocationScope trait");
// Fold constant dimension operands of an alloc-like op into a more static
// result type.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, they can
    // be substituted into the result type and dropped.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy the operand.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create a new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
    // ...
  }
};
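// Illustrative rewrite (not part of the original listing):
//   %c4 = arith.constant 4 : index
//   %0 = memref.alloc(%c4) : memref<?xf32>
// becomes a static allocation plus a cast back to the original type:
//   %1 = memref.alloc() : memref<4xf32>
//   %0 = memref.cast %1 : memref<4xf32> to memref<?xf32>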
// Erase alloc-like ops whose result is only stored into or deallocated.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);
    rewriter.eraseOp(alloc);
    return success();
  }
};

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}
// In ReallocOp::verify:
auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
MemRefType resultType = getType();

// The source memref should have identity layout (or none).
if (!sourceType.getLayout().isIdentity())
  return emitError("unsupported layout for source memref type ")
         << sourceType;

// The result memref should have identity layout (or none).
if (!resultType.getLayout().isIdentity())
  return emitError("unsupported layout for result memref type ")
         << resultType;

// The source and result memrefs should be in the same memory space.
if (sourceType.getMemorySpace() != resultType.getMemorySpace())
  return emitError("different memory spaces specified for source memref "
                   "type ")
         << sourceType << " and result memref type " << resultType;

if (sourceType.getElementType() != resultType.getElementType())
  return emitError("different element types specified for source memref "
                   "type ")
         << sourceType << " and result memref type " << resultType;

// A dynamic result size requires exactly one dimension operand.
if (resultType.getNumDynamicDims() && !getDynamicResultSize())
  return emitError("missing dimension operand for result type ")
         << resultType;
if (!resultType.getNumDynamicDims() && getDynamicResultSize())
  return emitError("unnecessary dimension operand for result type ")
         << resultType;

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}
// In AllocaScopeOp::print:
bool printBlockTerminators = false;
// ...
if (!getResults().empty()) {
  p << " -> (" << getResultTypes() << ")";
  printBlockTerminators = true;
}
// ...
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
              printBlockTerminators);

// In AllocaScopeOp::parse: make sure the region has a terminator.
AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                result.location);
void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // ...
}

// In isGuaranteedAutomaticAllocation and
// isOpItselfPotentialAutomaticAllocation: query the op's memory effects for
// an automatic-allocation-scope resource.
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
// ...
if (isa<SideEffects::AutomaticAllocationScopeResource>(
        effect->getResource()))
// In AllocaScopeInliner::matchAndRewrite: only inline a scope whose body
// cannot itself allocate.
bool hasPotentialAlloca = /* walk the body looking for alloca-like ops */;
if (hasPotentialAlloca) {
  // ...
}

// In AllocaScopeHoister::matchAndRewrite: walk up to the last ancestor that
// is not itself an allocation scope.
if (!lastParentWithoutScope || /* ... */)
  return failure();
// ...
lastParentWithoutScope = lastParentWithoutScope->getParentOp();
if (!lastParentWithoutScope || /* ... */)
  return failure();

// Only hoist allocations whose operands are defined outside the region that
// contains this scope.
Region *containingRegion = nullptr;
for (auto &r : lastParentWithoutScope->getRegions()) {
  if (r.isAncestor(op->getParentRegion())) {
    assert(containingRegion == nullptr &&
           "only one region can contain the op");
    containingRegion = &r;
  }
}
assert(containingRegion && "op must be contained in a region");
// ...
return containingRegion->isAncestor(v.getParentRegion());
// ...
toHoist.push_back(alloc);
// ...
for (auto *op : toHoist) {
  auto *cloned = rewriter.clone(*op);
  rewriter.replaceOp(op, cloned->getResults());
}
// In AssumeAlignmentOp::verify:
if (!llvm::isPowerOf2_32(getAlignment()))
  return emitOpError("alignment must be power of 2");

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  // assume_alignment(assume_alignment(x, a), a) folds to the inner op.
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (!source)
    return {};
  if (source.getAlignment() != getAlignment())
    return {};
  return getMemref();
}

// In CastOp::getAsmResultNames:
setNameFn(getResult(), "cast");
// In CastOp::canFoldIntoConsumerOp:
MemRefType sourceType =
    llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

// Requires ranked MemRefType.
if (!sourceType || !resultType)
  return false;

// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())
  return false;

// Requires same rank.
if (sourceType.getRank() != resultType.getRank())
  return false;

// Only fold casts between strided memref forms.
int64_t sourceOffset, resultOffset;
SmallVector<int64_t, 4> sourceStrides, resultStrides;
if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
    failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
  return false;

// If cast is towards more static sizes along any dimension, don't fold.
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
  auto ss = std::get<0>(it), st = std::get<1>(it);
  if (ss != st)
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
      return false;
}

// If cast is towards a more static offset, don't fold.
if (sourceOffset != resultOffset)
  if (ShapedType::isDynamic(sourceOffset) &&
      ShapedType::isStatic(resultOffset))
    return false;

// If cast is towards more static strides along any dimension, don't fold.
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
  auto ss = std::get<0>(it), st = std::get<1>(it);
  if (ss != st)
    if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
      return false;
}
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same, or if either one is dynamic.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  }

  // One of the types is unranked: element types and memory spaces must match.
  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}
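// Illustrative casts (not part of the original listing):
//   %0 = memref.cast %a : memref<4xf32> to memref<?xf32>   // static -> dynamic
//   %1 = memref.cast %0 : memref<?xf32> to memref<4xf32>   // dynamic -> static
//   %2 = memref.cast %a : memref<4xf32> to memref<*xf32>   // ranked -> unranked
// Mismatched element types, ranks, or memory spaces are incompatible.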
  // FoldCopyOfCast: fold copy(cast(x)) to copy(x) when the cast is shape- and
  // element-type-preserving.
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      // Note: compare against the cast's result type (the original listing
      // re-read the source type here, which made the check vacuous).
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getSourceMutable().assign(castOp.getSource());
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getTargetMutable().assign(castOp.getSource());
          modified = true;
        }
      }
    }

    return success(modified);
  }
  // FoldSelfCopy: copying a memref onto itself is a no-op.
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();
    // ...
  }

  // FoldEmptyCopy: copies involving zero-element memrefs can be erased.
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      // ...
    }
  }

// In CopyOp::getCanonicalizationPatterns:
results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  // copy(memrefcast) -> copy
  bool folded = false;
  // ...
  operand.set(castOp.getOperand());
  // ...
  return success(folded);
}

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  // dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}
// In DimOp::getAsmResultNames:
setNameFn(getResult(), "dim");

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

// In DimOp::verify: a constant index must be in range for ranked sources.
auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
if (!rankedSourceType)
  return success();
// ...

// In DimOp::inferResultRanges:
setResultRange(getResult(), /*...*/);
// Return a map from each element of `vals` to its number of occurrences.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  // Dims with a static size of 1 are candidates for dropping.
  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the
  // number of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  int64_t originalOffset, candidateOffset;
  SmallVector<int64_t> originalStrides, candidateStrides;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // Match the strides of the reduced type against the original type,
  // resolving which unit dims were actually dropped.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep it.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // A stride in the reduced type that was not in the original one.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}
llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out-of-bound indices produce undefined behavior but are still valid IR.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument of an alloc, alloca, view, or subview.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }
  // ...
}
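// Illustrative folds (not part of the original listing):
//   %c0 = arith.constant 0 : index
//   %d0 = memref.dim %m, %c0 : memref<4x?xf32>   // folds to the constant 4
// For an alloc-defined source, a dynamic dim folds to the size operand:
//   %a = memref.alloc(%n) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d1 = memref.dim %a, %c1 : memref<4x?xf32>   // folds to %n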
  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape: either it is defined before the reshape in the same block, or
    // it is defined in a proper ancestor region.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex() is a block argument to reshape's block.
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
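// Illustrative rewrite (not part of the original listing): a dim of a reshape
// result becomes a load from the shape operand:
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %idx : memref<?x?xf32>
// becomes
//   %d = memref.load %shape[%idx] : memref<2xindex>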
// In DmaStartOp::build (trailing parameters of the strided overload):
//   ..., Value stride, Value elementsPerStride) {
// ...

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// In DmaStartOp::parse:
bool isStrided = strideInfo.size() == 2;
if (!strideInfo.empty() && !isStrided) {
  return parser.emitError(parser.getNameLoc(),
                          "expected two stride related operands");
}
// ...
if (types.size() != 3)
  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  // dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}
void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // With packed result syntax (e.g. `%sizes:2`), only the first value of each
  // variadic group can carry a pretty name.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
// Helper function to perform the replacement of all constant uses of `values`
// by a materialized constant extracted from `maybeConstants`.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         "expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Skip values that are unchanged or have no uses.
    // ...
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}
SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = {getAsOpFoldResult(getOffset())};
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
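// Illustrative fold (not part of the original listing): for a statically
// shaped source, the size and stride results are rewritten to constants:
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m : memref<4x8xf32>
//       -> memref<f32>, index, index, index, index, index
// Uses of %sizes#0 become arith.constant 4; uses of %strides#0 become 8.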
// In GenericAtomicRMWOp::build: the result type is the memref element type.
if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
  Type elementType = memrefType.getElementType();
  result.addTypes(elementType);
  // ...
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  // Walk the body; any side-effecting nested op emits the diagnostic
  //   "body of 'memref.generic_atomic_rmw' should contain "
  //   "only operations with no side effects"
  // ...
  return success();
}

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  // ...
}

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}
// In printGlobalMemrefOpTypeAndInitialValue:
if (!op.isExternal()) {
  p << " = ";
  if (op.isUninitialized())
    p << "uninitialized";
  // ...
}

// In parseGlobalMemrefOpTypeAndInitialValue:
auto memrefType = llvm::dyn_cast<MemRefType>(type);
if (!memrefType || !memrefType.hasStaticShape())
  return parser.emitError(parser.getNameLoc())
         << "type should be static shaped memref, but got " << type;
// ...
if (!llvm::isa<ElementsAttr>(initialValue))
  return parser.emitError(parser.getNameLoc())
         << "initial value should be a unit or elements attribute";
LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // The initial value, if present, must be a unit attribute or an elements
  // attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;
    }
  }

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;
    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  return success();
}

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}
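// Illustrative IR (not part of the original listing):
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @z : memref<2xf32> = uninitialized
// The verifier rejects, e.g., an i32 initializer for an f32 global, or a
// shape mismatch between the initializer and the memref type.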
// In GetGlobalOp::verifySymbolUses: the symbol must resolve to a GlobalOp.
return emitOpError("'")
       << getName() << "' does not reference a valid global memref";

Type resultType = getResult().getType();
if (global.getType() != resultType)
  return emitOpError("result type ")
         << resultType << " does not match type " << global.getType()
         << " of the global memref @" << getName();

// In LoadOp::verify:
return emitOpError("incorrect number of indices for load, expected ")
       << getMemRefType().getRank();
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT)
    return uaT.getElementType() == ubT.getElementType();
  return false;
}

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v, t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}
void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  // ...
}

// In PrefetchOp::parse:
IntegerAttr localityHint;
MemRefType type;
StringRef readOrWrite, cacheType;
// ...
if (readOrWrite != "read" && readOrWrite != "write")
  return parser.emitError(parser.getNameLoc(),
                          "rw specifier has to be 'read' or 'write'");
result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                    parser.getBuilder().getBoolAttr(readOrWrite == "write"));

if (cacheType != "data" && cacheType != "instr")
  return parser.emitError(parser.getNameLoc(),
                          "cache type has to be 'data' or 'instr'");
result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                    parser.getBuilder().getBoolAttr(cacheType == "data"));

// In PrefetchOp::verify:
return emitOpError("too few indices");

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}
// In RankOp::fold: constant-fold the rank for ranked operand types.
auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

// Build with mixed static/dynamic offset, sizes, and strides.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  // ...
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        /*...*/);
}

// Overload that infers the result type from a strided layout computed off the
// operands.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              Value source, OpFoldResult offset,
                              ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  // ...
  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

// Overload taking static integers, wrapped into I64 attributes.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

// Overload taking Value operands.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  // ...
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  // The result type must have a strided layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }

  return success();
}
// In ReinterpretCastOp::fold:
Value src = getSource();
auto getPrevSrc = [&]() -> Value {
  // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
  if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
    return prev.getSource();
  // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
  if (auto prev = src.getDefiningOp<CastOp>())
    return prev.getSource();
  // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if offsets are all 0.
  if (auto prev = src.getDefiningOp<SubViewOp>())
    if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
      return prev.getSource();
  return nullptr;
};

if (auto prevSrc = getPrevSrc()) {
  getSourceMutable().assign(prevSrc);
  return getResult();
}
// In ReinterpretCastOp::getConstifiedMixedSizes/Strides:
LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
(void)status;
assert(succeeded(status) && "could not get strides from type");

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check whether the reinterpret_cast reconstructs a memref with exactly
    // the same properties as the extract_strided_metadata's source.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;
      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;
      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {
      // Not a no-op, but the cast can still be rebased directly on the
      // metadata's source, since both point to the same memory.
      op.getSourceMutable().assign(extractStridedMetadata.getSource());
      // ...
    }

    // The reinterpret_cast is a no-op: replace it with the metadata's source,
    // inserting a cast if the types differ.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getResult().getType(),
                                          extractStridedMetadata.getSource());
    return success();
  }
};

// In ReinterpretCastOp::getCanonicalizationPatterns:
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

// In ExpandShapeOp::reifyResultShapes:
reifiedResultShapes = {
    getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
// Helper function for verifying the shapes of ExpandShapeOp and
// CollapseShapeOp results and operands.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // Reshapes may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: the reassociation groups must cover all expanded dims.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}

// In ExpandShapeOp::getReassociationExprs and the CollapseShapeOp
// counterpart:
return convertReassociationIndicesToExprs(getContext(),
                                          getReassociationIndices());
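// Illustrative reassociation (not part of the original listing): collapsing
// memref<2x3x4xf32> with reassociation [[0, 1], [2]] yields memref<6x4xf32>;
// group [0, 1] has static size 2*3 = 6, which must equal the collapsed dim.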
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // Iterate over the groups from the back: each expanded dim's stride is the
  // running product of the result sizes to its right within the group.
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // A contiguous source yields a contiguous result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be contiguous. Compute the layout map.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}
// In ExpandShapeOp::build (with an explicit output shape):
auto [staticOutputShape, dynamicOutputShape] =
    decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, llvm::cast<MemRefType>(resultType), src,
      getReassociationIndicesAttribute(builder, reassociation),
      dynamicOutputShape, staticOutputShape);

// Overload that infers the output shape:
SmallVector<OpFoldResult> inputShape =
    getMixedSizes(builder, result.location, src);
MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
    builder, result.location, memrefResultTy, reassociation, inputShape);
// Failure here usually means multiple dynamic dims in one group.
assert(succeeded(outputShape) && "unable to infer output shape");
build(builder, result, memrefResultTy, src, reassociation, *outputShape);

// Overloads that compute the expanded result type:
auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
    ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
assert(succeeded(resultType) && "could not compute layout");
build(builder, result, *resultType, src, reassociation);

// ... and the same with an explicit output shape:
auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
    ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
assert(succeeded(resultType) && "could not compute layout");
build(builder, result, *resultType, src, reassociation, outputShape);
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the provided output shapes agree with the result type.
  // ...
  if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
    return emitOpError("invalid output shape provided at pos ") << pos;
  }
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a group is the stride of its last entry, skipping
  // trailing size-1 dimensions.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1)
      resultStrides.push_back(srcStrides[ref.back()]);
    else
      // Dynamically-sized dims may turn out to have size 1 at runtime, so the
      // stride cannot be decided statically.
      resultStrides.push_back(ShapedType::kDynamic);
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      // With strict = true (folding), only accept static strides.
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();
      // Size-1 dimensions are skipped; their strides are meaningless.
      if (srcShape[idx - 1] == 1)
        continue;
      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;
  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    auto groupSize = SaturatedInteger::wrap(1);
    for (int64_t srcDim : group)
      groupSize =
          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asInteger());
  }

  if (srcType.getLayout().isIdentity()) {
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be fully contiguous: compute the layout map. Dims that are
  // collapsed into a single dim are assumed contiguous.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

// In CollapseShapeOp::build:
auto srcType = llvm::cast<MemRefType>(src.getType());
MemRefType resultType =
    CollapseShapeOp::computeCollapsedType(srcType, reassociation);
// ...
build(b, result, resultType, src, attrs);
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // A contiguous source yields a contiguous result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
// In CollapseShapeOpMemRefCastFolder::matchAndRewrite:
auto cast = op.getOperand().getDefiningOp<CastOp>();
if (!cast)
  return failure();
// ...
Type newResultType = CollapseShapeOp::computeCollapsedType(
    llvm::cast<MemRefType>(cast.getOperand().getType()),
    op.getReassociationIndices());

if (newResultType == op.getResultType()) {
  rewriter.modifyOpInPlace(
      op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
  Value newOp = rewriter.create<CollapseShapeOp>(
      op->getLoc(), cast.getSource(), op.getReassociationIndices());
  // ...
}

// In CollapseShapeOp::getCanonicalizationPatterns (among other patterns):
results.add<ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                      memref::DimOp, MemRefType>,
            CollapseShapeOpMemRefCastFolder>(context);
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
// In StoreOp::verify:
return emitOpError("store index operand count not equal to memref rank");

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  // store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  // Compute target offset:
  //   sourceOffset + sum_i(staticOffset_i * sourceStrides_i).
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  // Compute target strides: sourceStrides_i * staticStrides_i.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}

MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {
  // Decompose the mixed values into static and dynamic parts, then delegate.
  // ...
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
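// Worked example (not part of the original listing): for a source
// memref<8x16xf32> (strides [16, 1], offset 0) with offsets [2, 4],
// sizes [4, 4], and strides [1, 2]:
//   targetOffset  = 0 + 2*16 + 4*1 = 36
//   targetStrides = [16*1, 1*2]    = [16, 2]
// giving memref<4x4xf32, strided<[16, 2], offset: 36>>.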
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()));
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}

MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  // Decompose the mixed values into static and dynamic parts, then delegate.
  // ...
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
// Build with mixed offsets/sizes/strides; infer the result type if none is
// given.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // ...
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        /*...*/);
}

// Overload without an explicit result type:
//   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

// Overloads taking static integers wrap them into I64 attributes:
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
    llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
      return b.getI64IntegerAttr(v);
    }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
    llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
      return b.getI64IntegerAttr(v);
    }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
    llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
      return b.getI64IntegerAttr(v);
    }));
build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

// ... same with an explicit result type:
build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
      attrs);

// Value-operand overloads delegate similarly:
//   build(b, result, resultType, source, offsetValues, sizeValues,
//         strideValues);
//   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
Value SubViewOp::getViewSource() { return getSource(); }

// Return true if t1 and t2 have equal offsets (both dynamic or of same static
// value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}
// Return true if t1 and t2 have equal strides (both dynamic or of same static
// value), comparing t1's non-dropped dims against t2.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes), but got "
           << op.getType();
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError(
               "expected result and source memory spaces to match, but got ")
           << op.getType();
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout), but "
              "got "
           << op.getType();
  }
  llvm_unreachable("unexpected subview verification result");
}
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  // ...

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming no rank reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify all properties of a shaped type: rank, element type, and dimension
  // sizes, accounting for potential rank reductions.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // The strides must be compatible, modulo dropped dims.
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      expectedType, subViewType, getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out of bounds.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
// In mlir::getOrCreateRanges: materialize each offset/size/stride as a Value.
std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
SmallVector<Range, 8> res;
unsigned rank = ranks[0];
res.reserve(rank);
for (unsigned idx = 0; idx < rank; ++idx) {
  Value offset =
      op.isDynamicOffset(idx)
          ? op.getDynamicOffset(idx)
          : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
  Value size =
      op.isDynamicSize(idx)
          ? op.getDynamicSize(idx)
          : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
  Value stride =
      op.isDynamicStride(idx)
          ? op.getDynamicStride(idx)
          : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
  res.emplace_back(Range{offset, size, stride});
}
return res;
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
// In mlir::memref::createCanonicalRankReducingSubViewOp:
auto memrefType = llvm::cast<MemRefType>(memref.getType());
unsigned rank = memrefType.getRank();
// ...
MemRefType targetType = SubViewOp::inferRankReducedResultType(
    targetShape, memrefType, offsets, sizes, strides);
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                         sizes, strides);

// In SubViewOp::rankReduceIfNeeded:
auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
assert(sourceMemrefType && "not a ranked memref type");
auto sourceShape = sourceMemrefType.getShape();
if (sourceShape.equals(desiredShape))
  return value;
auto maybeRankReductionMask =
    computeRankReductionMask(sourceShape, desiredShape);
if (!maybeRankReductionMask)
  return failure();
// In isTrivialSubViewOp: a subview is a no-op if it keeps the rank, all
// offsets are 0, all strides are 1, and all sizes match the source shape.
if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
  return false;

auto mixedOffsets = subViewOp.getMixedOffsets();
auto mixedSizes = subViewOp.getMixedSizes();
auto mixedStrides = subViewOp.getMixedStrides();

// Offsets must all be constant 0 (via llvm::any_of over mixedOffsets):
//   return !intValue || intValue.value() != 0;
// Strides must all be constant 1:
//   return !intValue || intValue.value() != 1;
// All sizes must be static and equal to the source shape:
if (!intValue || *intValue != sourceShape[size.index()])
  return false;
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    // ...
    MemRefType resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    // ...
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    // ...
  }
};

// TrivialSubViewOpFolder:
LogicalResult matchAndRewrite(SubViewOp subViewOp,
                              PatternRewriter &rewriter) const override {
  if (!isTrivialSubViewOp(subViewOp))
    return failure();
  if (subViewOp.getSourceType() == subViewOp.getType()) {
    rewriter.replaceOp(subViewOp, subViewOp.getSource());
    return success();
  }
  rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                      subViewOp.getSource());
  return success();
}
// In SubViewReturnTypeCanonicalizer::operator():
MemRefType resTy = SubViewOp::inferResultType(
    op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
if (!resTy)
  return {};
MemRefType nonReducedType = resTy;

// Directly return the non-rank-reduced type if there are no dropped dims.
llvm::SmallBitVector droppedDims = op.getDroppedDims();
if (droppedDims.none())
  return nonReducedType;

// Take the strides and offset from the non-rank-reduced type.
auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

// Drop dims from shape and strides.
SmallVector<int64_t> targetShape, targetStrides;
for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
  if (droppedDims.test(i))
    continue;
  targetStrides.push_back(nonReducedStrides[i]);
  targetShape.push_back(nonReducedType.getDimSize(i));
}

return MemRefType::get(targetShape, nonReducedType.getElementType(),
                       StridedLayoutAttr::get(nonReducedType.getContext(),
                                              offset, targetStrides),
                       nonReducedType.getMemorySpace());

// In SubViewOp::getCanonicalizationPatterns (among other patterns):
results.add<SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)) where the second subview has the same sizes,
  // zero offsets, and unit strides (i.e., it is a no-op).
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
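// Illustrative fold (not part of the original listing): a full-extent
// subview with matching types folds to its source:
//   %sv = memref.subview %m[0, 0] [8, 16] [1, 1]
//       : memref<8x16xf32> to memref<8x16xf32>   // folds to %m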
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
  // ...
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute the result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  // ...
  p << " : " << getIn().getType() << " to " << getType();
}

// In TransposeOp::parse:
MemRefType srcType, dstType;
// ...
result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                    /*parsed permutation map attribute*/);
LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
                                 .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // An identity permutation with identical input and result types folds away.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive transposes by composing their permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
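// Illustrative fold (not part of the original listing): two transposes
// compose into one by composing their permutation maps:
//   %t0 = memref.transpose %m (d0, d1) -> (d1, d0)
//       : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %t1 = memref.transpose %t0 (d0, d1) -> (d1, d0)
//       : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32>
// %t1 composes to the identity permutation of %m and then folds away.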
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}
Value ViewOp::getViewSource() { return getSource(); }

OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();

  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
    return getViewSource();
  return {};
}
  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get the offset from the old memref view type.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Construct the new shape.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy the operand.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create the new memref type with constant-folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create the new ViewOp and cast back to the original type.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // ...
  }
  // ViewOpMemrefCastFolder: view(cast(x)) -> view(x).
  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    // ...
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
3580 "expects the number of subscripts to be equal to memref rank");
3581 switch (getKind()) {
3582 case arith::AtomicRMWKind::addf:
3583 case arith::AtomicRMWKind::maximumf:
3584 case arith::AtomicRMWKind::minimumf:
3585 case arith::AtomicRMWKind::mulf:
3586 if (!llvm::isa<FloatType>(getValue().
getType()))
3587 return emitOpError() <<
"with kind '"
3588 << arith::stringifyAtomicRMWKind(getKind())
3589 <<
"' expects a floating-point type";
3591 case arith::AtomicRMWKind::addi:
3592 case arith::AtomicRMWKind::maxs:
3593 case arith::AtomicRMWKind::maxu:
3594 case arith::AtomicRMWKind::mins:
3595 case arith::AtomicRMWKind::minu:
3596 case arith::AtomicRMWKind::muli:
3597 case arith::AtomicRMWKind::ori:
3598 case arith::AtomicRMWKind::andi:
3599 if (!llvm::isa<IntegerType>(getValue().
getType()))
3600 return emitOpError() <<
"with kind '"
3601 << arith::stringifyAtomicRMWKind(getKind())
3602 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"