#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
namespace saturated_arith {
/// Wrapper around int64_t arithmetic that saturates to "dynamic": any
/// operation involving a dynamic operand yields a dynamic result.
struct Wrapper {
  static Wrapper stride(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  static Wrapper offset(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  static Wrapper size(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
  bool operator==(Wrapper other) {
    return (saturated && other.saturated) ||
           (!saturated && !other.saturated && v == other.v);
  }
  bool operator!=(Wrapper other) { return !(*this == other); }
  Wrapper operator+(Wrapper other) {
    if (saturated || other.saturated)
      return Wrapper{true, 0};
    return Wrapper{false, other.v + v};
  }
  Wrapper operator*(Wrapper other) {
    if (saturated || other.saturated)
      return Wrapper{true, 0};
    return Wrapper{false, other.v * v};
  }
  bool saturated;
  int64_t v;
};
} // namespace saturated_arith
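// Editorial example (not part of the upstream file): a dynamic operand
// poisons the result, otherwise plain arithmetic applies.
//
//   using saturated_arith::Wrapper;
//   (Wrapper::stride(4) * Wrapper::size(8)).asStride();   // == 32
//   (Wrapper::stride(4) * Wrapper::size(ShapedType::kDynamic))
//       .asStride();                                      // == ShapedType::kDynamic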
// In MemRefDialect::materializeConstant: delegate to the arith dialect.
return arith::ConstantOp::materialize(builder, value, type, loc);
// In foldMemRefCast: fold the source of a memref.cast into the using op.
auto cast = operand.get().getDefiningOp<CastOp>();
if (cast && operand.get() != inner &&
    !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
  operand.set(cast.getOperand());
  folded = true;
}
// In getTensorTypeFromMemRefType: map a (un)ranked memref type to the
// corresponding tensor type.
if (auto memref = llvm::dyn_cast<MemRefType>(type))
  return RankedTensorType::get(memref.getShape(), memref.getElementType());
if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
  return UnrankedTensorType::get(memref.getElementType());
OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc,
                                  Value value, int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}
// In getMixedSizes: collect the mixed size of every dimension.
auto memrefType = llvm::cast<MemRefType>(value.getType());
for (int64_t i = 0; i < memrefType.getRank(); ++i)
  result.push_back(getMixedSize(builder, loc, value, i));
// In constifyIndexValues: replace values known to be constant (from the
// memref type) by their index-attribute form.
int64_t constValue = it.value();
if (!isDynamic(constValue))
  values[it.index()] = builder.getIndexAttr(constValue);
// ...
ofr = builder.getIndexAttr(
    llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
// ...
std::optional<int64_t> maybeConstant =
    getConstantIntValue(ofr.get<Value>());

// In getConstantOffset and getConstantStrides: bail out when the layout has
// no static stride/offset information.
if (failed(hasStaticInformation))
  return SmallVector<int64_t>();
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}
namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create the new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}
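// Editorial sketch of what these patterns do to the IR (SSA names are
// illustrative):
//
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8) : memref<?xf32>        // SimplifyAllocConst
// becomes
//   %0 = memref.alloc() : memref<8xf32>
//   %1 = memref.cast %0 : memref<8xf32> to memref<?xf32>
//
// and an alloc whose only uses are stores into it and deallocs is erased
// entirely by SimplifyDeadAlloc.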
LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source memref and the result memref should be in the same memory
  // space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source memref and the result memref should have the same element
  // type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;
  return success();
}

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}
void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

// In AllocaScopeOp::parse, after the body region has been parsed:
AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                result.location);
void AllocaScopeOp::getSuccessorRegions(/* ... */) {
  // Control flows from the parent into the body region and from the body back
  // to the parent's results.
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}
// In AllocaScopeInliner: inline an alloca_scope when doing so cannot extend
// any allocation lifetime.
bool hasPotentialAlloca =
    op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
        if (alloc == op)
          return WalkResult::advance();
        if (isOpItselfPotentialAutomaticAllocation(alloc))
          return WalkResult::interrupt();
        if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
          return WalkResult::skip();
        return WalkResult::advance();
      }).wasInterrupted();
if (hasPotentialAlloca) {
  // Only inline when the parent is itself an allocation scope and `op` is the
  // last non-terminator op in it.
  // ...
}

// In AllocaScopeHoister: hoist guaranteed stack allocations out of regions
// that do not themselves open an allocation scope.
if (!lastParentWithoutScope ||
    lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
  return failure();
// ... walk up to the nearest enclosing allocation scope:
lastParentWithoutScope = lastParentWithoutScope->getParentOp();
if (!lastParentWithoutScope ||
    !lastNonTerminatorInRegion(lastParentWithoutScope))
  return failure();
// ...

Region *containingRegion = nullptr;
for (auto &r : lastParentWithoutScope->getRegions()) {
  if (r.isAncestor(op->getParentRegion())) {
    assert(containingRegion == nullptr &&
           "only one region can contain the op");
    containingRegion = &r;
  }
}
assert(containingRegion && "op must be contained in a region");

// Collect allocations whose operands are all defined outside the hoist
// destination.
op->walk([&](Operation *alloc) {
  if (!isGuaranteedAutomaticAllocation(alloc))
    return WalkResult::skip();
  if (llvm::any_of(alloc->getOperands(), [&](Value v) {
        return containingRegion->isAncestor(v.getParentRegion());
      }))
    return WalkResult::skip();
  toHoist.push_back(alloc);
  return WalkResult::advance();
});
// ...
rewriter.setInsertionPoint(lastParentWithoutScope);
for (auto *op : toHoist) {
  auto *cloned = rewriter.clone(*op);
  rewriter.replaceOp(op, cloned->getResults());
}
LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  return true;
}
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same, or if either one is dynamic.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  }

  // Ranked-to-unranked and unranked-to-ranked casts: element types and memory
  // spaces must match; unranked-to-unranked casting is unsupported.
  if (!aT && !uaT)
    return false;
  if (!bT && !ubT)
    return false;
  if (uaT && ubT)
    return false;

  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}
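// Editorial examples of casts accepted by areCastCompatible:
//
//   memref.cast %a : memref<8x16xf32> to memref<?x?xf32>   // erase sizes
//   memref.cast %b : memref<?xf32> to memref<4xf32>        // refine a size
//   memref.cast %c : memref<4xf32> to memref<*xf32>        // to unranked
//
// Casts that change element type or memory space, and unranked-to-unranked
// casts, are rejected.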
// In FoldCopyOfCast: if the source/target of a CopyOp is a CastOp that does
// not modify shape or element type, the cast can be skipped.
bool modified = false;

// Check source.
if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  auto toType = llvm::dyn_cast<MemRefType>(copyOp.getSource().getType());

  if (fromType && toType) {
    if (fromType.getShape() == toType.getShape() &&
        fromType.getElementType() == toType.getElementType()) {
      copyOp.getSourceMutable().assign(castOp.getSource());
      modified = true;
    }
  }
}

// Check target (mirrors the source case).
if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
  auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  auto toType = llvm::dyn_cast<MemRefType>(copyOp.getTarget().getType());

  if (fromType && toType) {
    if (fromType.getShape() == toType.getShape() &&
        fromType.getElementType() == toType.getElementType()) {
      copyOp.getTargetMutable().assign(castOp.getSource());
      modified = true;
    }
  }
}

// In FoldSelfCopy: a copy of a value onto itself is a no-op.
if (copyOp.getSource() != copyOp.getTarget())
  return failure();
// ...

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

// In CopyOp::fold: copy(memref.cast(x)) -> copy(x) when the cast can fold
// into the consumer.
operand.set(castOp.getOperand());
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

// In DimOp::getSpeculatability: a dim op is only speculatable when the index
// is a constant in bounds of a ranked source.
auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
if (!rankedSourceType)
  return Speculation::NotSpeculatable;

/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// DenseMap<int64_t>.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
/// Given the `originalType` and a `reducedType` whose shape is assumed to be
/// a subset of `originalType` with some `1` entries erased, return the bit
/// vector of dimensions of `originalType` that are dropped to obtain
/// `reducedType`. This accounts for cases where there are multiple unit-dims
/// but only a subset of them are dropped.
static std::optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the
  // number of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(getStridesAndOffset(originalType, originalStrides,
                                 originalOffset)) ||
      failed(getStridesAndOffset(reducedType, candidateStrides,
                                 candidateOffset)))
    return std::nullopt;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. Match stride occurrence counts between the two types.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep it.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced type cannot have a stride that
      // was not in the original type.
      return std::nullopt;
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return std::nullopt;
  return unusedDims;
}
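// Editorial example: for originalType = memref<1x4x1x8xf32> reduced to
// memref<4x8xf32> with sizes [1, 4, 1, 8], the returned bit vector has bits
// 0 and 2 set: both unit dimensions are dropped. The stride occurrence
// counting above only matters when several unit dims share a stride and just
// a subset of them is dropped.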
llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  std::optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out-of-bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument of an alloc, alloca, view or subview.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
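// Editorial sketch of the rewrite performed by DimOfMemRefReshape:
//
//   %r = memref.reshape %m(%shape)
//       : (memref<*xf32>, memref<3xindex>) -> memref<?x?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?x?xf32>
// becomes
//   %d = memref.load %shape[%c1] : memref<3xindex>
//
// with an arith.index_cast inserted when the shape memref does not hold
// index-typed elements.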
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// In DmaStartOp::parse:
bool isStrided = strideInfo.size() == 2;
if (!strideInfo.empty() && !isStrided) {
  return parser.emitError(parser.getNameLoc(),
                          "expected two stride related operands");
}
// ...
if (types.size() != 3)
  return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

/// The number and type of the results are inferred from the rank of the
/// source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}
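// Editorial example: for a rank-2 source the op produces 2 + 2*2 results,
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m
//       : memref<4x?xf32> -> memref<f32>, index, index, index, index, index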
void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result packs (`x:3` syntax), only the first value of the pack
  // can carry a pretty name.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
/// Helper function to perform the replacement of all constant uses of
/// `values` by a materialized constant extracted from `maybeConstants`.
/// Returns true if some replacement happened.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if the value is unchanged or has no uses.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(maybeConstant.template is<Attribute>() &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
                 .getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  constifyIndexValues(values, getSource().getType(), getContext(),
                      getConstantStrides, ShapedType::isDynamic);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getOffset());
  constifyIndexValues(values, getSource().getType(), getContext(),
                      getConstantOffset, ShapedType::isDynamic);
  return values[0];
}
// In GenericAtomicRMWOp::build: the result type is the memref element type,
// and the body region gets one block argument of that type.
if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
  Type elementType = memrefType.getElementType();
  result.addTypes(elementType);
  // ...
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  // The body must be side-effect free; otherwise emit:
  //   "body of 'memref.generic_atomic_rmw' should contain "
  //   "only operations with no side effects"
  // ...
  return success();
}
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}
// In printGlobalMemrefOpTypeAndInitialValue:
if (!op.isExternal()) {
  p << " = ";
  if (op.isUninitialized())
    p << "uninitialized";
  else
    p.printAttributeWithoutType(initialValue);
}

// In parseGlobalMemrefOpTypeAndInitialValue:
auto memrefType = llvm::dyn_cast<MemRefType>(type);
if (!memrefType || !memrefType.hasStaticShape())
  return parser.emitError(parser.getNameLoc())
         << "type should be static shaped memref, but got " << type;
typeAttr = TypeAttr::get(type);
// ... parse the optional `= uninitialized | <elements>` initializer.
if (!llvm::isa<ElementsAttr>(initialValue))
  return parser.emitError(parser.getNameLoc())
         << "initial value should be a unit or elements attribute";
LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      Type initType = elementsAttr.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  return success();
}
ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}
LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type matches the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
  if (!global)
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
  return success();
}
LogicalResult LoadOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("incorrect number of indices for load");
  return success();
}
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT)
    return uaT.getElementType() == ubT.getElementType();
  return false;
}
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v, t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}
void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}
// In PrefetchOp::parse:
IntegerAttr localityHint;
MemRefType type;
StringRef readOrWrite, cacheType;
// ... parse the operand, indices, keywords and attributes.
if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
  return parser.emitError(parser.getNameLoc(),
                          "rw specifier has to be 'read' or 'write'");
result.addAttribute(
    PrefetchOp::getIsWriteAttrStrName(),
    parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

if (!cacheType.equals("data") && !cacheType.equals("instr"))
  return parser.emitError(parser.getNameLoc(),
                          "cache type has to be 'data' or 'instr'");
result.addAttribute(
    PrefetchOp::getIsDataCacheAttrStrName(),
    parser.getBuilder().getBoolAttr(cacheType.equals("data")));
LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");
  return success();
}
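// Editorial example of the printed form:
//
//   memref.prefetch %m[%i], read, locality<3>, data : memref<128xf32>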
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
}
void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

/// Build a ReinterpretCastOp from mixed static/dynamic offset, sizes and
/// strides.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << idx;
  }

  // Match offset and strides in static_offset and static_strides attributes.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (!ShapedType::isDynamic(resultOffset) &&
      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << expectedOffset << " instead of " << resultOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (!ShapedType::isDynamic(resultStride) &&
        !ShapedType::isDynamic(expectedStride) &&
        resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << idx;
  }

  return success();
}
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if the subview
    // offsets are all 0.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }

  // reinterpret_cast(x) w/o offset/shape/stride changes -> x.
  if (!ShapedType::isDynamicShape(getType().getShape()) &&
      src.getType() == getType() && getStaticOffsets().front() == 0) {
    return src;
  }

  return nullptr;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getMixedSizes();
  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
                      ShapedType::isDynamic);
  return values;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getMixedStrides();
  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
                      ShapedType::isDynamic);
  return values;
}

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
                      ShapedType::isDynamic);
  return values[0];
}
namespace {
/// Fold a reinterpret_cast that merely rebuilds the metadata produced by an
/// extract_strided_metadata of the same source into a plain cast (or a direct
/// replacement when the types already match).
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // First, check that the strides are the same.
    SmallVector<OpFoldResult> extractStridesOfr =
        extractStridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> reinterpretStridesOfr =
        op.getConstifiedMixedStrides();
    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
      return failure();

    unsigned rank = op.getType().getRank();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
        return failure();
    }

    // Second, check the sizes.
    assert(extractStridedMetadata.getSizes().size() ==
               op.getMixedSizes().size() &&
           "Strides and sizes rank must match");
    SmallVector<OpFoldResult> extractSizesOfr =
        extractStridedMetadata.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> reinterpretSizesOfr =
        op.getConstifiedMixedSizes();
    for (unsigned i = 0; i < rank; ++i) {
      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
        return failure();
    }

    // Finally, check the offset.
    assert(op.getMixedOffsets().size() == 1 &&
           "reinterpret_cast with more than one offset should have been "
           "rejected by the verifier");
    OpFoldResult extractOffsetOfr =
        extractStridedMetadata.getConstifiedMixedOffset();
    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
    if (extractOffsetOfr != reinterpretOffsetOfr)
      return failure();

    // The roundtrip is a noop, modulo a possible static/dynamic type change.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());
    return success();
  }
};
} // namespace

void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

/// Helper function for verifying the shape of ExpandShapeOp and
/// CollapseShapeOp result and operand.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize
               << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: number of dimensions among all reassociation groups must
    // match the result memref rank.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
/// Compute the layout map after expanding a given source MemRef type with the
/// specified reassociation indices.
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // 1-1 mapping between srcStrides and reassociation packs: each source
  // stride is the stride of the last dimension of its pack, and earlier
  // dimensions multiply up the sizes to its right. Example:
  //   srcStrides     = [10000, 1, 100],
  //   reassociation  = [[0], [1], [2, 3, 4]],
  //   resultShape    = [2, 5, 4, 3, 2]
  //   -> resultStrides = [10000, 1, 600, 200, 100]
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      using saturated_arith::Wrapper;
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
                               Wrapper::size(resultShape[shapeIndex--]))
                                  .asStride();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset,
                                resultStrides);
}
FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be contiguous. Compute the layout map.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() >= resultType.getRank())
    return emitOpError("expected rank expansion, but found source rank ")
           << srcType.getRank() << " >= result rank " << resultType.getRank();

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/false)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  return success();
}
/// Compute the layout map after collapsing a given source MemRef type with
/// the specified reassociation indices. All collapsed dims in a group must be
/// contiguous; when that cannot be decided statically the collapse is assumed
/// valid unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of its last
  // entry, skipping trailing size-1 dims whose strides are meaningless.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to have size 1 at runtime, so the
      // result stride cannot be statically determined and must be dynamic.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    using saturated_arith::Wrapper;
    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * Wrapper::size(srcShape[idx]);

      // Source and result stride must agree statically; a dynamic stride is
      // only tolerated in non-strict mode.
      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();
      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset,
                                resultStrides);
}
bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}
MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    using saturated_arith::Wrapper;
    auto groupSize = Wrapper::size(1);
    for (int64_t srcDim : group)
      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asSize());
  }

  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be fully contiguous. Compute the layout map. Dimensions
  // that are collapsed into a single dim are assumed to be contiguous.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  build(b, result, resultType, src, attrs);
}
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() <= resultType.getRank())
    return emitOpError("expected rank reduction, but found source rank ")
           << srcType.getRank() << " <= result rank " << resultType.getRank();

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
// In CollapseShapeOpMemRefCastFolder: fold a memref.cast feeding a
// collapse_shape when the collapsed type can be recomputed from the cast's
// operand.
Type newResultType = CollapseShapeOp::computeCollapsedType(
    llvm::cast<MemRefType>(cast.getOperand().getType()),
    op.getReassociationIndices());

if (newResultType == op.getResultType()) {
  rewriter.updateRootInPlace(
      op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
  Value newOp = rewriter.create<CollapseShapeOp>(
      op->getLoc(), cast.getSource(), op.getReassociationIndices());
  rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
LogicalResult StoreOp::verify() {
  if (getNumOperands() != 2 + getMemRefType().getRank())
    return emitOpError("store index operand count not equal to memref rank");
  return success();
}
void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> staticOffsets,
                                ArrayRef<int64_t> staticSizes,
                                ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
    using saturated_arith::Wrapper;
    targetOffset =
        (Wrapper::offset(targetOffset) +
         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
            .asOffset();
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using saturated_arith::Wrapper;
    targetStrides.push_back(
        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
            .asStride());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}
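// Editorial worked example: for a memref<8x16xf32> source (strides [16, 1],
// offset 0), offsets [2, 4], sizes [4, 4] and strides [1, 2] give
//   targetOffset  = 0 + 2*16 + 4*1 = 36
//   targetStrides = [16*1, 1*2]    = [16, 2]
// i.e. the inferred type memref<4x4xf32, strided<[16, 2], offset: 36>>.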
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<int64_t> offsets,
                                           ArrayRef<int64_t> sizes,
                                           ArrayRef<int64_t> strides) {
  auto inferredType = llvm::cast<MemRefType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected ");
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes,
                                           ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
/// Build a SubViewOp with mixed static and dynamic entries and a custom
/// result type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
        sourceMemRefType, staticOffsets, staticSizes, staticStrides));
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

/// Build a SubViewOp with mixed entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// Build a SubViewOp with static entries; the remaining overloads (static
/// entries with an explicit result type, and Value-operand variants) convert
/// their arguments to OpFoldResults the same way and forward to the builders
/// above.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}
Value SubViewOp::getViewSource() { return getSource(); }
/// Return true if t1 and t2 have equal offsets (both dynamic or of same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}
/// Checks if `originalType` can be rank reduced to `candidateRankReducedType`.
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // Sizes cannot be matched if no mask could be computed.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::LayoutMismatch;

  if (originalType.getMemorySpace() !=
      candidateRankReducedType.getMemorySpace())
    return SliceVerificationResult::MemSpaceMismatch;

  return SliceVerificationResult::Success;
}
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            OpTy op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}
/// Verifier for SubViewOp.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());

  auto result = isRankReducedMemRefType(llvm::cast<MemRefType>(expectedType),
                                        subViewType, getMixedSizes());
  return produceSubViewErrorMsg(result, *this, expectedType);
}
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
/// Return the list of Range (i.e. offset, size, stride). Each Range entry
/// contains either the dynamic value or a ConstantIndexOp constructed with
/// `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`, then reduce the rank
/// of the inferred result type when `currentResultType` has lower rank than
/// `currentSourceType`; the dropped dimensions are computed from
/// `currentSourceType`, whose strides align with `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides));
  std::optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  auto targetType =
      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
          targetShape, memrefType, offsets, sizes, strides));
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}
FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}
/// Helper method to check if a `subViewOp` is trivially a no-op: all offsets
/// are zero, all strides are 1, and the sizes match the source shape.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments: push the cast
/// past its consuming subview when `canFoldIntoConsumerOp` allows it.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the cast's source operand type to infer the result type, and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
/// Canonicalize subview ops that are no-ops.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace
/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));

    // Directly return the non-rank reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.empty())
      return nonReducedType;

    // Take the strides and offset from the non-rank reduced type.
    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  // Fold subview(subview(x)) where the inner subview is a no-op: all offsets
  // zero, all strides one, and identical sizes.
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(
        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(
        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultShapedType == sourceShapedType)
      return getViewSource();
  }

  return {};
}
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
  assert(originalStrides.size() == static_cast<unsigned>(rank));

  // Compute permuted sizes and strides.
  SmallVector<int64_t> sizes(rank, 0);
  SmallVector<int64_t> strides(rank, 1);
  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
    sizes[en.index()] = originalSizes[position];
    strides[en.index()] = originalStrides[position];
  }

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}
// transpose $permutation $in : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto dstType = llvm::cast<MemRefType>(getType());
  auto transposedType = inferTransposeResultType(srcType, getPermutation());
  if (dstType != transposedType)
    return emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}
LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}
Value ViewOp::getViewSource() { return getSource(); }
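// Editorial example: a view carving an f32 buffer out of a byte buffer with a
// dynamic byte shift and one dynamic size:
//
//   %v = memref.view %buf[%shift][%n] : memref<2048xi8> to memref<?x4xf32>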
namespace {
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    if (!allocOperand.getDefiningOp<AllocOp>())
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};
} // namespace
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
3402 "expects the number of subscripts to be equal to memref rank");
3403 switch (getKind()) {
3404 case arith::AtomicRMWKind::addf:
3405 case arith::AtomicRMWKind::maximumf:
3406 case arith::AtomicRMWKind::minimumf:
3407 case arith::AtomicRMWKind::mulf:
3408 if (!llvm::isa<FloatType>(getValue().getType()))
3409 return emitOpError() <<
"with kind '"
3410 << arith::stringifyAtomicRMWKind(getKind())
3411 <<
"' expects a floating-point type";
3413 case arith::AtomicRMWKind::addi:
3414 case arith::AtomicRMWKind::maxs:
3415 case arith::AtomicRMWKind::maxu:
3416 case arith::AtomicRMWKind::mins:
3417 case arith::AtomicRMWKind::minu:
3418 case arith::AtomicRMWKind::muli:
3419 case arith::AtomicRMWKind::ori:
3420 case arith::AtomicRMWKind::andi:
3421 if (!llvm::isa<IntegerType>(getValue().getType()))
3422 return emitOpError() <<
"with kind '"
3423 << arith::stringifyAtomicRMWKind(getKind())
3424 <<
"' expects an integer type";
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
static bool hasSideEffects(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(std::vector< PermutationTy > permutation)
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< int64_t > getConstantOffset(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the offset and conforms to the function signatur...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, MemRefType memRefTy, MLIRContext *ctxt, llvm::function_ref< SmallVector< int64_t >(MemRefType)> getAttributes, llvm::function_ref< bool(int64_t)> isDynamic)
Helper function that infers the constant values from a list of values, a memRefTy,...
static SliceVerificationResult isRankReducedMemRefType(MemRefType originalType, MemRefType candidateRankReducedType, ArrayRef< OpFoldResult > sizes)
Checks if original Type type can be rank reduced to reduced type.
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap tp memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType)
static SmallVector< int64_t > getConstantStrides(MemRefType memrefType)
Wrapper around getStridesAndOffset that returns only the strides and conforms to the function signatu...
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static SmallVector< int64_t > getConstantSizes(MemRefType memRefTy)
Wrapper around getShape that conforms to the function signature expected for getAttributes in constif...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static std::optional< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static int64_t getNumElements(ShapedType type)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A dimensional identifier appearing in an affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a pointer to a specific Location storage.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setShape(ArrayRef< int64_t > newShape)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it, populating 'results' with the results after folding.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that define a new scope for automatic allocations, i.e., allocations that are freed when control is transferred back from the operation's region.
This trait indicates that the memory effects of an operation include the effects of operations nested within its regions.
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait&lt;OperandsAreSignlessIntegerLike&gt;().
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one).
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator over the underlying Values.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with ||.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
BlockListType & getBlocks()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes to reduce the rank of memref to that of targetShape.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
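A minimal sketch of how such a helper is typically written against the arith dialect:
mlir::Value constantIndex(mlir::OpBuilder &builder, mlir::Location loc, int64_t i) {
  // arith.constant with index type; createOrFold would additionally reuse
  // an existing constant when folding hooks allow it.
  return builder.create<mlir::arith::ConstantIndexOp>(loc, i);
}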
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
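A hedged usage sketch of the two OpFoldResult predicates above; 'isUnitStride' is a hypothetical helper, not from the source:
static bool isUnitStride(mlir::OpFoldResult stride) {
  // The two checks below are equivalent; both are shown for illustration.
  if (mlir::isConstantIntValue(stride, 1))
    return true;
  std::optional<int64_t> c = mlir::getConstantIntValue(stride);
  return c && *c == 1; // Dynamic values yield std::nullopt.
}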
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size and stride). Each Range entry contains either the dynamic value or a ConstantIndexOp constructed with b at location loc.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
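A hedged usage sketch; 'inspectLayout' is illustrative, not from the source. For a canonical memref&lt;8x16xf32&gt; this yields strides [16, 1] and offset 0, and it fails on layouts that are not strided:
static mlir::LogicalResult inspectLayout(mlir::MemRefType memrefType) {
  llvm::SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(mlir::getStridesAndOffset(memrefType, strides, offset)))
    return mlir::failure(); // Layout not expressible in strided form.
  // Statically unknown entries come back as ShapedType::kDynamic.
  return mlir::success();
}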
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from being chosen when not desirable.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs)
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
AffineExpr operator+(int64_t val, AffineExpr expr)
AffineExpr operator*(int64_t val, AffineExpr expr)
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped.
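A worked example of the mask computation (hypothetical call site):
// Dropping the unit dimension at index 1: [4, 1, 5] -> [4, 5].
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    mlir::computeRankReductionMask({4, 1, 5}, {4, 5});
// On success 'mask' contains exactly {1}; incompatible shapes yield nullopt.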
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Idiomatic saturated operations on offsets, sizes and strides.
Move allocations into an allocation scope, if it is legal to move them (e.g. their operands are available at the location they would be moved to).
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocation.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
This class represents an efficient way to signal success or failure.
The following effect indicates that the operation allocates from some resource.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
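For reference, a sketch of the triple this describes; it mirrors the shape of mlir::Range, but treat the exact field list as an assumption here:
struct Range {
  mlir::OpFoldResult offset; // IntegerAttr when static, Value when dynamic.
  mlir::OpFoldResult size;
  mlir::OpFoldResult stride;
};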