#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

  return arith::ConstantOp::materialize(builder, value, type, loc);

  auto cast = operand.get().getDefiningOp<CastOp>();
  if (cast && operand.get() != inner &&
      !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
    operand.set(cast.getOperand());

  if (auto memref = llvm::dyn_cast<MemRefType>(type))
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))

  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  auto memrefType = llvm::cast<MemRefType>(value.getType());
  for (int64_t i = 0; i < memrefType.getRank(); ++i)

  int64_t constValue = it.value();
  if (!isDynamic(constValue))

      llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());

  std::optional<int64_t> maybeConstant =

  if (failed(hasStaticInformation))

  if (failed(hasStaticInformation))
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");

  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

        "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {

    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))

    auto memrefType = alloc.getType();

    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);

      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
          constSizeArg.isNonNegative()) {
        newShapeConstants.push_back(constSizeArg.getZExtValue());

        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);

    MemRefType newMemRefType =

    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
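    // Illustrative fold (not from the original file): with
    //   %c8 = arith.constant 8 : index
    // this pattern rewrites
    //   %0 = memref.alloc(%c8) : memref<?xf32>
    // into a statically shaped allocation plus a cast back to the old type:
    //   %1 = memref.alloc() : memref<8xf32>
    //   %0 = memref.cast %1 : memref<8xf32> to memref<?xf32>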
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {

    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))

  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
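// Illustrative example (assumed, not in the original excerpt): an allocation
// whose only users are stores into it and deallocs is dead, so
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
// is erased entirely by SimplifyDeadAlloc.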
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")

  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")

  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")

  results.add<SimplifyDeadAlloc<ReallocOp>>(context);

  bool printBlockTerminators = false;

  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;

                printBlockTerminators);

  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
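// Illustrative IR (not from the original file):
//   memref.alloca_scope {
//     %buf = memref.alloca() : memref<16xi8>
//   }
// The implicit memref.alloca_scope.return terminator is created by
// ensureTerminator above and elided when printing the zero-result form.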
void AllocaScopeOp::getSuccessorRegions(

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

    bool hasPotentialAlloca =

    if (hasPotentialAlloca) {

    if (!lastParentWithoutScope ||

      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
    assert(containingRegion && "op must be contained in a region");

          return containingRegion->isAncestor(v.getParentRegion());

        toHoist.push_back(alloc);

    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);

  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");

  setNameFn(getResult(), "cast");
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  if (!sourceType || !resultType)

  if (sourceType.getElementType() != resultType.getElementType())

  if (sourceType.getRank() != resultType.getRank())

  int64_t sourceOffset, resultOffset;

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))

  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))

  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))

  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
          aStrides.size() != bStrides.size())

      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))

    if (aT.getMemorySpace() != bT.getMemorySpace())

    if (aT.getRank() != bT.getRank())

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
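// Illustrative compatible cast (not from the original file): dropping static
// shape information is always allowed,
//   %1 = memref.cast %0 : memref<4x?xf32> to memref<?x?xf32>
// while the reverse direction is only checked at runtime; incompatible
// element types or memory spaces are rejected by the checks above.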
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getSourceMutable().assign(castOp.getSource());

    // Check target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          copyOp.getTargetMutable().assign(castOp.getSource());

    if (copyOp.getSource() != copyOp.getTarget())

  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
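// Illustrative fold (assumed example): a cast that only changes the layout,
//   %0 = memref.cast %arg : memref<4xf32, strided<[1]>> to memref<4xf32>
//   memref.copy %0, %dst : memref<4xf32> to memref<4xf32>
// keeps shape and element type, so FoldCopyOfCast rewires the copy to read
// %arg directly.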
      operand.set(castOp.getOperand());

  setNameFn(getResult(), "dim");

  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)

static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
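// Worked example (assumed): getNumOccurences({64, 64, 1}) returns
// {{1, 1}, {64, 2}}, i.e. each distinct stride mapped to its multiplicity.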
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())

    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())

  int64_t originalOffset, candidateOffset;

  std::map<int64_t, unsigned> currUnaccountedStrides =
  std::map<int64_t, unsigned> candidateStridesNumOccurences =

  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      currUnaccountedStrides[originalStride]--;

    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      unusedDims.reset(dim);

    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())

  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();

  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())

  if (!memrefType.isDynamicDim(index.getInt())) {
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);

  unsigned unsignedIndex = index.getValue().getZExtValue();

  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {

    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

          dim, "Dim op is not defined by a reshape op.");

    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
              "dim.getIndex is not defined before reshape in the same block.");
    } else if (dim->getBlock() != reshape->getBlock() &&
               !dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion())) {
          dim, "dim.getIndex does not dominate reshape.");

        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);

  results.add<DimOfMemRefReshape>(context);
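// Illustrative rewrite (not from the original file):
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
// becomes a load from the shape operand, %d = memref.load %shape[%c1],
// index_cast'ed if the shape memref holds a different integer type.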
                        Value elementsPerStride) {

  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';

    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
        "expected two stride related operands");

  if (types.size() != 3)
  unsigned numOperands = getNumOperands();

  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
    return emitOpError("expected source indices to be of index type");

  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
    return emitOpError("expected destination indices to be of index type");

    return emitOpError("expected num elements to be of index type");

  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
    return emitOpError("expected tag indices to be of index type");

  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  if (!getStride().getType().isIndex() ||
      !getNumElementsPerStride().getType().isIndex())
        "expected stride and num elements per stride to be of type index");
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,

  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());

  unsigned sourceRank = sourceType.getRank();

      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());

  inferredReturnTypes.push_back(memrefType);

  inferredReturnTypes.push_back(indexType);

  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");

  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         "expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {

    assert(maybeConstant.template is<Attribute>() &&
           "The constified value should be either unchanged (i.e., == result) "

    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())

    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {

      atLeastOneReplacement = true;

  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,

                                getConstifiedMixedOffset());
                                getConstifiedMixedSizes());
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();

  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

        "body of 'memref.generic_atomic_rmw' should contain "
        "only operations with no side effects");

  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';

  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;

  if (!op.isExternal()) {
    if (op.isUninitialized())
      p << "uninitialized";
  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
           << "type should be static shaped memref, but got " << type;

  if (!llvm::isa<ElementsAttr>(initialValue))
           << "initial value should be a unit or elements attribute";

  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")

  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")

    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      Type initType = elementsAttr.getType();
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
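// Illustrative IR (assumed example):
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
//   %0 = memref.get_global @cst : memref<2xf32>
// getConstantInitValue() returns the dense elements only when the global is
// marked constant and carries an initial value.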
    return emitOpError("'") << getName()
           << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();

    return emitOpError("incorrect number of indices for load, expected ")

void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}
  if (inputs.size() != 1 || outputs.size() != 1)
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
    if (aT.getLayout() != bT.getLayout())
    if (aT.getShape() != bT.getShape())

    return uaT.getElementType() == ubT.getElementType();

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {

  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
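// Illustrative fold (not from the original file): back-to-back casts collapse,
//   %1 = memref.memory_space_cast %0 : memref<4xf32, 1> to memref<4xf32>
//   %2 = memref.memory_space_cast %1 : memref<4xf32> to memref<4xf32, 2>
// folds so that %2 becomes a single cast of %0 to memref<4xf32, 2>.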
  p << " " << getMemref() << '[';
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
                          (*this)->getAttrs(),
                          {"localityHint", "isWrite", "isDataCache"});

  IntegerAttr localityHint;

  StringRef readOrWrite, cacheType;

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
                            "rw specifier has to be 'read' or 'write'");
                         PrefetchOp::getIsWriteAttrStrName(),

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
                            "cache type has to be 'data' or 'instr'");
                         PrefetchOp::getIsDataCacheAttrStrName(),

    return emitOpError("too few indices");

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())

  return IntegerAttr();
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

                              MemRefType resultType, Value source,

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

                              MemRefType resultType, Value source,

      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {

      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {

                strideValues, attrs);

                              MemRefType resultType, Value source, Value offset,

  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  for (auto [idx, resultSize, expectedSize] :
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << idx;

  int64_t resultOffset;
    return emitError("expected result type to have strided layout but found ")

  int64_t expectedOffset = getStaticOffsets().front();
  if (!ShapedType::isDynamic(resultOffset) &&
      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << expectedOffset << " instead of " << resultOffset;

  for (auto [idx, resultStride, expectedStride] :
    if (!ShapedType::isDynamic(resultStride) &&
        !ShapedType::isDynamic(expectedStride) &&
        resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << idx;
  Value src = getSource();

  auto getPrevSrc = [&]() -> Value {

      return prev.getSource();

      return prev.getSource();

      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.getSource();

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);

  if (!ShapedType::isDynamicShape(getType().getShape()) &&
      src.getType() == getType() && getStaticOffsets().front() == 0) {

                      ShapedType::isDynamic);

                      ShapedType::isDynamic);

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {

  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
                      ShapedType::isDynamic);
struct ReinterpretCastOpExtractStridedMetadataFolder

    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)

        extractStridedMetadata.getConstifiedMixedStrides();
        op.getConstifiedMixedStrides();
    if (extractStridesOfr.size() != reinterpretStridesOfr.size())

    unsigned rank = op.getType().getRank();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractStridesOfr[i] != reinterpretStridesOfr[i])

    assert(extractStridedMetadata.getSizes().size() ==
               op.getMixedSizes().size() &&
           "Strides and sizes rank must match");
        extractStridedMetadata.getConstifiedMixedSizes();
        op.getConstifiedMixedSizes();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractSizesOfr[i] != reinterpretSizesOfr[i])

    assert(op.getMixedOffsets().size() == 1 &&
           "reinterpret_cast with more than one offset should have been "
           "rejected by the verifier");
        extractStridedMetadata.getConstifiedMixedOffset();
    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
    if (extractOffsetOfr != reinterpretOffsetOfr)

    Type srcTy = extractStridedMetadata.getSource().getType();
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
          extractStridedMetadata.getSource());

  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}
                    bool allowMultipleDynamicDimsPerGroup) {

  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  int64_t nextDim = 0;

    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
               << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;

    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
             << ") must be dynamic if and only if reassociation group is "

    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";

  if (collapsedShape.empty()) {

    for (int64_t d : expandedShape)
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
                                getReassociationIndices());

                                getReassociationIndices());

  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =

  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);

  if (srcType.getLayout().isIdentity()) {

    MemRefLayoutAttrInterface layout;
        srcType.getMemorySpace());

  if (failed(computedLayout))
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
  auto srcType = llvm::cast<MemRefType>(src.getType());
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);

  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);

  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";

                                  resultType.getShape(),
                                  getReassociationIndices(),

      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;
                          bool strict = false) {

  auto srcShape = srcType.getShape();

  resultStrides.reserve(reassociation.size());

    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);

      resultStrides.push_back(ShapedType::kDynamic);

  unsigned resultStrideIndex = resultStrides.size() - 1;

    for (int64_t idx : llvm::reverse(trailingReassocs)) {

      if (strict && (stride.saturated || srcStride.saturated))

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)

bool CollapseShapeOp::isGuaranteedCollapsible(

  if (srcType.getLayout().isIdentity())

MemRefType CollapseShapeOp::computeCollapsedType(

  resultShape.reserve(reassociation.size());

    for (int64_t srcDim : group)

    resultShape.push_back(groupSize.asInteger());

  if (srcType.getLayout().isIdentity()) {

    MemRefLayoutAttrInterface layout;
        srcType.getMemorySpace());

        "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());

  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  build(b, result, resultType, src, attrs);
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";

                                  srcType.getShape(), getReassociationIndices(),

  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {

    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());

    if (failed(computedLayout))
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        *computedLayout, srcType.getMemorySpace());

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;
    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });

          op->getLoc(), cast.getSource(), op.getReassociationIndices());

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
          "length of shape operand differs from the result's memref rank");
    return emitOpError("store index operand count not equal to memref rank");

void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,

  unsigned rank = sourceMemRefType.getRank();

  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);

  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);

  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                             targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,

  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
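// Worked example (assumed): for a source memref<8x16xf32> (strides [16, 1],
// offset 0) with offsets [2, 4], sizes [4, 4], and strides [1, 2], the
// inferred result is offset 0 + 2 * 16 + 4 * 1 = 36 and strides
// [16 * 1, 1 * 2] = [16, 2], i.e.
//   memref<4x4xf32, strided<[16, 2], offset: 36>>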
    MemRefType sourceRankedTensorType,

  auto inferredType = llvm::cast<MemRefType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
  assert(dimsToProject.has_value() && "invalid rank reduction");

  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
                             inferredLayout.getOffset(),
                             rankReducedStrides),
                         inferredType.getMemorySpace());

    MemRefType sourceRankedTensorType,

  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,

                       MemRefType resultType, Value source,

  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());

    resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
        sourceMemRefType, staticOffsets, staticSizes, staticStrides));

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

                       MemRefType resultType, Value source,
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,

  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

Value SubViewOp::getViewSource() { return getSource(); }
  int64_t t1Offset, t2Offset;

static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;

  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {

    if (t1Strides[i] != t2Strides[j])
  auto memrefType = llvm::cast<ShapedType>(expectedType);

    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";

    return op->emitError("expected result type to be ")
           << " or a rank-reduced version. (mismatch of result sizes) ";

    return op->emitError("expected result element type to be ")
           << memrefType.getElementType();

    return op->emitError("expected result and source memory spaces to match.");

    return op->emitError("expected result type to be ")
           << " or a rank-reduced version. (mismatch of result layout) ";

  llvm_unreachable("unexpected subview verification result");

  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

    return emitError("base type ") << baseType << " is not strided";

  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));

                                            expectedType, subViewType);

  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
                                 *this, expectedType);

                                 *this, expectedType);

                                 *this, expectedType);

                                 *this, expectedType);
  return os << "range " << range.offset << ":" << range.size << ":"

  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");

  unsigned rank = ranks[0];

  for (unsigned idx = 0; idx < rank; ++idx) {
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
    res.emplace_back(Range{offset, size, stride});
    MemRefType currentResultType, MemRefType currentSourceType,

  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides));
      currentSourceType, currentResultType, mixedSizes);

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());

  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
                             layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
          targetShape, memrefType, offsets, sizes, strides));
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,

  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))

  auto maybeRankReductionMask =
  if (!maybeRankReductionMask)

  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

        return !intValue || intValue.value() != 0;

        return !intValue || intValue.value() != 1;

    if (!intValue || *intValue != sourceShape[size.index()])
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {

    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();

        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());

        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());

    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());

                                        subViewOp.getSource());

    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
                                            mixedSizes, mixedStrides);

    MemRefType nonReducedType = cast<MemRefType>(resTy);

    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));

                             offset, targetStrides),
                         nonReducedType.getMemorySpace());

      SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();

  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();

    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultShapedType == sourceShapedType)
      return getViewSource();
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

  auto originalSizes = memRefType.getShape();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

                        AffineMapAttr permutation,

  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);

  p << " " << getIn() << " " << getPermutation();
  p << " : " << getIn().getType() << " to " << getType();

  MemRefType srcType, dstType;

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());

    return emitOpError("result type ")
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;

  if (getPermutation().isIdentity() && getType() == getIn().getType())

  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
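// Illustrative fold (not from the original file): two back-to-back transposes
// compose into one,
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0) : ...
//   %2 = memref.transpose %1 (d0, d1) -> (d1, d0) : ...
// so %2 becomes a transpose of %0 with the composed (here identity)
// permutation, which the identity check above then folds away.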
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;
Value ViewOp::getViewSource() { return getSource(); }

    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());

    auto memrefType = viewOp.getType();

    assert(oldOffset == 0 && "Expected 0 offset");

    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);

      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        newShapeConstants.push_back(constantIndexOp.value());

        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);

    MemRefType newMemRefType =

    if (newMemRefType == memrefType)

    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);

    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();

    Value allocOperand = memrefCastOp.getOperand();

                                        viewOp.getByteShift(),

  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3444 "expects the number of subscripts to be equal to memref rank");
3445 switch (getKind()) {
3446 case arith::AtomicRMWKind::addf:
3447 case arith::AtomicRMWKind::maximumf:
3448 case arith::AtomicRMWKind::minimumf:
3449 case arith::AtomicRMWKind::mulf:
3450 if (!llvm::isa<FloatType>(getValue().getType()))
3451 return emitOpError() <<
"with kind '"
3452 << arith::stringifyAtomicRMWKind(getKind())
3453 <<
"' expects a floating-point type";
3455 case arith::AtomicRMWKind::addi:
3456 case arith::AtomicRMWKind::maxs:
3457 case arith::AtomicRMWKind::maxu:
3458 case arith::AtomicRMWKind::mins:
3459 case arith::AtomicRMWKind::minu:
3460 case arith::AtomicRMWKind::muli:
3461 case arith::AtomicRMWKind::ori:
3462 case arith::AtomicRMWKind::andi:
3463 if (!llvm::isa<IntegerType>(getValue().getType()))
3464 return emitOpError() <<
"with kind '"
3465 << arith::stringifyAtomicRMWKind(getKind())
3466 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(std::vector< PermutationTy > permutation)
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< int64_t > getConstantOffset(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the offset and conforms to the function signatur...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, MemRefType memRefTy, MLIRContext *ctxt, llvm::function_ref< SmallVector< int64_t >(MemRefType)> getAttributes, llvm::function_ref< bool(int64_t)> isDynamic)
Helper function that infers the constant values from a list of values, a memRefTy,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, Operation *op, Type expectedType)
static SmallVector< int64_t > getConstantStrides(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the strides and conforms to the function signatu...
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static SmallVector< int64_t > getConstantSizes(MemRefType memRefTy)
Wrapper around getShape that conforms to the function signature expected for getAttributes in constif...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static int64_t getNumElements(ShapedType type)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a ',' token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
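Taken together, these hooks are what a custom op parser is assembled from. A minimal sketch for a hypothetical op (the name and syntax are assumptions, not from this file):

  static ParseResult parseMyOp(OpAsmParser &parser, OperationState &result) {
    SmallVector<OpAsmParser::UnresolvedOperand> operands;
    Type type;
    // Grammar sketch: operand-list attr-dict? `:` type
    if (parser.parseOperandList(operands) ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColonType(type) ||
        parser.resolveOperands(operands, type, result.operands))
      return failure();
    result.addTypes(type);
    return success();
  }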
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
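A quick sketch of these accessors in use (ctx is an assumed MLIRContext*):

  Builder b(ctx);
  IntegerAttr size = b.getIndexAttr(42);            // index-typed attribute
  DenseI64ArrayAttr shape = b.getDenseI64ArrayAttr({4, 8});
  IntegerType i32 = b.getIntegerType(32);
  BoolAttr flag = b.getBoolAttr(true);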
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, and so on.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the provided map.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
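A sketch of how the builder and its RAII guard combine (anchor is an assumed Operation*); the guard restores the original insertion point when it leaves scope:

  OpBuilder builder(anchor->getContext());
  {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointAfter(anchor);
    // Ops created here are inserted immediately after `anchor`.
    Value zero = builder.create<arith::ConstantIndexOp>(anchor->getLoc(), 0);
    (void)zero;
  }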
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations, i.e., allocations that are freed when control is transferred back from this operation's region.
This trait indicates that the memory effects of an operation include the effects of operations nested within its regions.
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), in the given walk order (post-order by default).
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
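For example, a verifier-style check (a sketch; op is an assumed Operation*) that walks up to the nearest automatic-allocation scope:

  // getParentWithTrait returns nullptr when no such ancestor exists.
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  if (!scope)
    return op->emitOpError("expected an enclosing allocation scope");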
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator over the underlying Values.
Region * getParentRegion()
Returns the region to which this operation belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with ||.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
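These rewriter hooks are typically driven from an OpRewritePattern. A minimal sketch of a hypothetical pattern (not one defined in this file):

  struct EraseSelfCast : public OpRewritePattern<memref::CastOp> {
    using OpRewritePattern<memref::CastOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(memref::CastOp castOp,
                                  PatternRewriter &rewriter) const override {
      // Only fires when the cast maps a type to itself.
      if (castOp.getSource().getType() != castOp.getType())
        return rewriter.notifyMatchFailure(castOp, "not an identity cast");
      rewriter.replaceOp(castOp, castOp.getSource());
      return success();
    }
  };
  // Registered the same way as the canonicalization patterns above:
  // results.add<EraseSelfCast>(context);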
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, the given operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< Operation::operand_range > getIndices(Operation *op)
Get and set the indices that the given load/store operation is operating on.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of the given memref dimension as an OpFoldResult (an attribute for static sizes, a value for dynamic ones).
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
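A usage sketch (builder, loc, and memrefValue are assumed): static dimensions come back as IntegerAttrs, dynamic ones as memref.dim results:

  SmallVector<OpFoldResult> sizes =
      memref::getMixedSizes(builder, loc, memrefValue);
  for (OpFoldResult size : sizes)
    if (std::optional<int64_t> cst = getConstantIntValue(size))
      llvm::outs() << "static dim: " << *cst << "\n";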
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and sizes that reduce the rank of memref to that of targetShape.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
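A small sketch of the two cases an OpFoldResult can carry (b is an assumed Builder, idx an assumed SSA Value of index type):

  OpFoldResult asAttr = b.getIndexAttr(7);
  std::optional<int64_t> c1 = getConstantIntValue(asAttr);  // c1 == 7
  OpFoldResult asValue = idx;
  std::optional<int64_t> c2 = getConstantIntValue(asValue); // nullopt unless
                                                            // idx is a known constant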
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, and stride) entries; each entry holds either the dynamic value or a constant index built with b at loc.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
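For a statically shaped memref with identity layout the result is fully static; a sketch (memrefType assumed):

  SmallVector<int64_t> strides;
  int64_t offset;
  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
    // For memref<4x8xf32> with identity layout: strides == {8, 1}, offset == 0.
  }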
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the canonical contiguous strided layout.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped.
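A worked example with inline values:

  // {4, 1, 8, 1} reduced to {4, 8}: the unit dims at positions 1 and 3 are
  // the dropped ones, so the returned set is {1, 3}.
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask({4, 1, 8, 1}, {4, 8});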
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank-reduced to candidateReducedType by dropping some dimensions with static size 1.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Move allocations into an allocation scope, if it is legal to move them (e.g. their operands are available at the location they would be moved to).
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocation.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
This class represents an efficient way to signal success or failure.
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
static SaturatedInteger wrap(int64_t v)