25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
// NOTE(review): fragmentary excerpt — interior lines of the original file are
// elided here; only comments are added. Verify against the full source.
// Fragment: constants for this dialect are materialized as arith.constant ops.
37 return arith::ConstantOp::materialize(builder, value, type, loc);
// Fragment of a cast-folding helper: if an operand is produced by a memref
// CastOp (and is not the excluded `inner` value, and the cast source is not an
// unranked memref), rewrite the operand to use the cast's source directly.
50 auto cast = operand.get().getDefiningOp<CastOp>();
51 if (cast && operand.get() != inner &&
52 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
53 operand.set(cast.getOperand());
// Fragment: maps a memref type to the equivalent tensor type — a ranked
// memref becomes a RankedTensorType with the same shape and element type, an
// unranked memref becomes an UnrankedTensorType with the same element type.
// (Enclosing signature and fallback return are not visible in this excerpt.)
 63 if (
auto memref = llvm::dyn_cast<MemRefType>(type))
 64 return RankedTensorType::get(
memref.getShape(),
memref.getElementType());
 65 if (
auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
 66 return UnrankedTensorType::get(
memref.getElementType());
// Fragment of a mixed-size helper: for a dynamic dimension, the size is
// produced (or folded) as a memref.dim op on `value`.
 72 auto memrefType = llvm::cast<MemRefType>(value.
getType());
 73 if (memrefType.isDynamicDim(dim))
 74 return builder.
createOrFold<memref::DimOp>(loc, value, dim);
// Fragment of the all-dims variant: iterates every rank position of `value`'s
// memref type (presumably collecting per-dimension mixed sizes — elided here).
 81 auto memrefType = llvm::cast<MemRefType>(value.
getType());
 83 for (
int64_t i = 0; i < memrefType.getRank(); ++i)
 100 assert(constValues.size() == values.size() &&
 101 "incorrect number of const values");
// Fragment of a "constify" helper: walks constant values paired with SSA
// values; static entries are handled in the (elided) branch below.
 102 for (
auto [i, cstVal] : llvm::enumerate(constValues)) {
 104 if (ShapedType::isStatic(cstVal)) {
// Fragment: helper that, given a promotable MemorySpaceCastOpInterface
// producer, computes the source-space and target-space clones of the result
// pointer type. Returns failure paths are elided in this excerpt.
118static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
120 MemorySpaceCastOpInterface castOp =
121 MemorySpaceCastOpInterface::getIfPromotableCast(src);
// Result type cloned with the cast *source* memory space.
129 FailureOr<PtrLikeTypeInterface> srcTy = resultTy.
clonePtrWith(
130 castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
// Result type cloned with the cast *target* memory space.
134 FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.
clonePtrWith(
135 castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
// Bail out (elided) unless the tgt->src pair forms a valid memory-space cast.
140 if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
143 return std::make_tuple(castOp, *tgtTy, *srcTy);
// Fragment: generic "bubble a memory-space cast below an op" implementation —
// recreates the op with substituted operands/result type, then re-applies the
// cast on the new result via cloneMemorySpaceCastOp.
148template <
typename ConcreteOpTy>
149static FailureOr<std::optional<SmallVector<Value>>>
159 llvm::append_range(operands, op->getOperands());
163 auto newOp = ConcreteOpTy::create(
164 builder, op.getLoc(),
TypeRange(resTy), operands, op.getProperties(),
165 llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
168 MemorySpaceCastOpInterface
result = castOp.cloneMemorySpaceCastOp(
171 return std::optional<SmallVector<Value>>(
// Asm pretty-name hints: results of memref.alloc print as %alloc, results of
// memref.alloca print as %alloca.
179void AllocOp::getAsmResultNames(
181 setNameFn(getResult(),
"alloc");
184void AllocaOp::getAsmResultNames(
186 setNameFn(getResult(),
"alloca");
// Fragment: shared verifier for alloc-like ops (statically restricted to
// AllocOp/AllocaOp). Checks the result is a memref and that the number of
// symbol operands matches the layout map's symbol count.
189template <
typename AllocLikeOp>
191 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
192 "applies to only alloc or alloca");
193 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
195 return op.emitOpError(
"result must be a memref");
// An identity layout takes no symbols; otherwise the affine layout map's
// symbol count dictates how many symbol operands the op must carry.
200 unsigned numSymbols = 0;
201 if (!memRefType.getLayout().isIdentity())
202 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
203 if (op.getSymbolOperands().size() != numSymbols)
204 return op.emitOpError(
"symbol operand count does not equal memref symbol "
206 << numSymbols <<
", got " << op.getSymbolOperands().size();
// Fragment of AllocaOp::verify — additionally requires (per the message) an
// ancestor with the AutomaticAllocationScope trait; the check itself is elided.
213LogicalResult AllocaOp::verify() {
217 "requires an ancestor op with AutomaticAllocationScope trait");
// Fragment: canonicalization pattern that folds constant, non-negative dynamic
// sizes of an alloc-like op into the result memref's static shape, keeping
// only the genuinely dynamic sizes as operands.
224template <
typename AllocLikeOp>
226 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
228 LogicalResult matchAndRewrite(AllocLikeOp alloc,
229 PatternRewriter &rewriter)
const override {
// Bail (elided) when no dynamic size is a non-negative constant.
232 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
234 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
236 return constSizeArg.isNonNegative();
240 auto memrefType = alloc.getType();
// New shape: static dims copied through; constant dynamic sizes promoted to
// static; the rest stay kDynamic with their size operand retained.
244 SmallVector<int64_t, 4> newShapeConstants;
245 newShapeConstants.reserve(memrefType.getRank());
246 SmallVector<Value, 4> dynamicSizes;
248 unsigned dynamicDimPos = 0;
249 for (
unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
250 int64_t dimSize = memrefType.getDimSize(dim);
252 if (ShapedType::isStatic(dimSize)) {
253 newShapeConstants.push_back(dimSize);
256 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
259 constSizeArg.isNonNegative()) {
261 newShapeConstants.push_back(constSizeArg.getZExtValue());
264 newShapeConstants.push_back(ShapedType::kDynamic);
265 dynamicSizes.push_back(dynamicSize);
// Rebuild the memref type with the refined shape and re-create the op with
// the pruned dynamic-size operand list.
271 MemRefType newMemRefType =
272 MemRefType::Builder(memrefType).setShape(newShapeConstants);
273 assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
276 auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
277 dynamicSizes, alloc.getSymbolOperands(),
278 alloc.getAlignmentAttr());
// Fragment: dead-allocation elimination — an alloc whose only users are
// deallocs (or stores *into* it, but not stores *of* its value) is erasable
// along with those users.
288 using OpRewritePattern<T>::OpRewritePattern;
290 LogicalResult matchAndRewrite(T alloc,
291 PatternRewriter &rewriter)
const override {
// A store whose stored *value* is the alloc keeps it alive; any user other
// than a dealloc also keeps it alive.
292 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
293 if (auto storeOp = dyn_cast<StoreOp>(op))
294 return storeOp.getValue() == alloc;
295 return !isa<DeallocOp>(op);
// early_inc_range: users are erased while iterating (erasure elided here).
299 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
// Canonicalization registration for memref.alloc and memref.alloca.
310 results.
add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
315 results.
add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
// Fragment: memref.realloc verifier — source and result must have identity
// layouts and matching memory spaces, and exactly one dynamic-size operand is
// required iff the result type has a dynamic dimension.
323LogicalResult ReallocOp::verify() {
324 auto sourceType = llvm::cast<MemRefType>(getOperand(0).
getType());
325 MemRefType resultType =
getType();
328 if (!sourceType.getLayout().isIdentity())
329 return emitError(
"unsupported layout for source memref type ")
333 if (!resultType.getLayout().isIdentity())
334 return emitError(
"unsupported layout for result memref type ")
338 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
339 return emitError(
"different memory spaces specified for source memref "
341 << sourceType <<
" and result memref type " << resultType;
// Dynamic result shape <=> dynamic size operand, in both directions.
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError(
"missing dimension operand for result type ")
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError(
"unnecessary dimension operand for result type ")
// Canonicalization: a dead realloc is removable like a dead alloc.
361 results.
add<SimplifyDeadAlloc<ReallocOp>>(context);
// Fragment of AllocaScopeOp::print: terminators are printed only when the op
// has results (so the yielded values are visible in the output).
369 bool printBlockTerminators =
false;
372 if (!getResults().empty()) {
373 p <<
" -> (" << getResultTypes() <<
")";
374 printBlockTerminators =
true;
379 printBlockTerminators);
// Fragment of AllocaScopeOp::parse: single region; an implicit terminator is
// inserted if the parsed body lacks one.
385 result.regions.reserve(1);
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.
getBuilder(),
405void AllocaScopeOp::getSuccessorRegions(
// Fragment: two look-alike helpers that query an op's MemoryEffectOpInterface
// for Allocate effects on the AutomaticAllocationScopeResource (i.e. detect
// alloca-like allocations). Their differing outer logic is elided.
422 MemoryEffectOpInterface
interface = dyn_cast<MemoryEffectOpInterface>(op);
427 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
428 if (isa<SideEffects::AutomaticAllocationScopeResource>(
429 effect->getResource()))
445 MemoryEffectOpInterface
interface = dyn_cast<MemoryEffectOpInterface>(op);
450 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
451 if (isa<SideEffects::AutomaticAllocationScopeResource>(
452 effect->getResource()))
476 bool hasPotentialAlloca =
489 if (hasPotentialAlloca) {
// Fragment: walks up parent ops to find the outermost ancestor without its own
// allocation scope, then finds which of that ancestor's regions contains the
// current op — exactly one region may contain it.
522 if (!lastParentWithoutScope ||
535 lastParentWithoutScope = lastParentWithoutScope->
getParentOp();
536 if (!lastParentWithoutScope ||
543 Region *containingRegion =
nullptr;
544 for (
auto &r : lastParentWithoutScope->
getRegions()) {
545 if (r.isAncestor(op->getParentRegion())) {
546 assert(containingRegion ==
nullptr &&
547 "only one region can contain the op");
548 containingRegion = &r;
551 assert(containingRegion &&
"op must be contained in a region");
// A candidate is hoistable only if its operands are defined outside the
// containing region (check visible below); such allocs are collected.
561 return containingRegion->isAncestor(v.getParentRegion());
564 toHoist.push_back(alloc);
// Hoisting: clone each collected op (at an elided insertion point) and replace
// the original with the clone's results.
571 for (
auto *op : toHoist) {
572 auto *cloned = rewriter.
clone(*op);
573 rewriter.
replaceOp(op, cloned->getResults());
// memref.assume_alignment: the alignment attribute must be a power of two.
588LogicalResult AssumeAlignmentOp::verify() {
589 if (!llvm::isPowerOf2_32(getAlignment()))
590 return emitOpError(
"alignment must be power of 2");
// Result pretty-name: %assume_align.
594void AssumeAlignmentOp::getAsmResultNames(
596 setNameFn(getResult(),
"assume_align");
// Fold fragment: collapses chained assume_alignment ops when the producer has
// the same alignment (the actual replacement is elided here).
599OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
600 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
603 if (source.getAlignment() != getAlignment())
608FailureOr<std::optional<SmallVector<Value>>>
609AssumeAlignmentOp::bubbleDownCasts(
OpBuilder &builder) {
// Dim reification: the single result's size is that of the input memref.
613FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(
OpBuilder &builder,
616 assert(resultIndex == 0 &&
"AssumeAlignmentOp has a single result");
617 return getMixedSize(builder, getLoc(), getMemref(), dim);
// memref.distinct_objects: operand and result type lists must be identical,
// and at least one operand is required.
624LogicalResult DistinctObjectsOp::verify() {
625 if (getOperandTypes() != getResultTypes())
626 return emitOpError(
"operand types and result types must match");
628 if (getOperandTypes().empty())
629 return emitOpError(
"expected at least one operand");
// Result types are inferred by copying the operand types 1:1.
634LogicalResult DistinctObjectsOp::inferReturnTypes(
639 llvm::copy(operands.
getTypes(), std::back_inserter(inferredReturnTypes));
// Fragment of CastOp::getAsmResultNames: results print as %cast.
648 setNameFn(getResult(),
"cast");
// Fragment: decides whether a memref.cast may be folded into its consumer.
// Requires ranked source/result with same element type, same rank, computable
// strided layouts, and forbids "dynamic -> static" refinements in shape,
// offset, or strides (only the reverse direction is fine to fold away).
688bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
689 MemRefType sourceType =
690 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
691 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
694 if (!sourceType || !resultType)
698 if (sourceType.getElementType() != resultType.getElementType())
702 if (sourceType.getRank() != resultType.getRank())
706 int64_t sourceOffset, resultOffset;
708 if (
failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
709 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
// Shape: a dynamic source dim may not become static in the result.
713 for (
auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
714 auto ss = std::get<0>(it), st = std::get<1>(it);
716 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
// Offset: same rule for a dynamic source offset.
721 if (sourceOffset != resultOffset)
722 if (ShapedType::isDynamic(sourceOffset) &&
723 ShapedType::isStatic(resultOffset))
// Strides: same rule applied element-wise.
727 for (
auto it : llvm::zip(sourceStrides, resultStrides)) {
728 auto ss = std::get<0>(it), st = std::get<1>(it);
730 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
// Fragment of CastOp::areCastCompatible: exactly one input/output type; both
// sides classified as ranked (MemRefType) or unranked (UnrankedMemRefType).
738 if (inputs.size() != 1 || outputs.size() != 1)
740 Type a = inputs.front(),
b = outputs.front();
741 auto aT = llvm::dyn_cast<MemRefType>(a);
742 auto bT = llvm::dyn_cast<MemRefType>(
b);
744 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
745 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(
b);
// Ranked <-> ranked: element types must match; differing strided layouts are
// compatible when corresponding offsets/strides are equal or dynamic (unit
// dims are exempt from the stride check).
748 if (aT.getElementType() != bT.getElementType())
750 if (aT.getLayout() != bT.getLayout()) {
753 if (
failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
754 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
755 aStrides.size() != bStrides.size())
764 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(
b) || a ==
b);
766 if (!checkCompatible(aOffset, bOffset))
769 if (aT.getDimSize(
index) == 1 || bT.getDimSize(
index) == 1)
771 if (!checkCompatible(aStride, bStrides[
index]))
775 if (aT.getMemorySpace() != bT.getMemorySpace())
// Same rank required; static dims must agree pairwise (condition continues in
// elided lines).
779 if (aT.getRank() != bT.getRank())
782 for (
unsigned i = 0, e = aT.getRank(); i != e; ++i) {
783 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
784 if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
// Ranked <-> unranked: element type and memory space must match.
798 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
799 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
800 if (aEltType != bEltType)
803 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
804 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
805 return aMemSpace == bMemSpace;
815FailureOr<std::optional<SmallVector<Value>>>
816CastOp::bubbleDownCasts(
OpBuilder &builder) {
// Fragment: pattern erasing a copy whose source and target are the same value.
828 using OpRewritePattern<CopyOp>::OpRewritePattern;
830 LogicalResult matchAndRewrite(CopyOp copyOp,
831 PatternRewriter &rewriter)
const override {
832 if (copyOp.getSource() != copyOp.getTarget())
// Fragment: pattern erasing a copy when either side is a zero-element memref.
841 using OpRewritePattern<CopyOp>::OpRewritePattern;
843 static bool isEmptyMemRef(BaseMemRefType type) {
847 LogicalResult matchAndRewrite(CopyOp copyOp,
848 PatternRewriter &rewriter)
const override {
849 if (isEmptyMemRef(copyOp.getSource().getType()) ||
850 isEmptyMemRef(copyOp.getTarget().getType())) {
862 results.
add<FoldEmptyCopy, FoldSelfCopy>(context);
// Fragment: operand-level cast folding — any operand fed by a foldable
// memref.cast is rewired to the cast's source.
869 for (
OpOperand &operand : op->getOpOperands()) {
871 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
872 operand.set(castOp.getOperand());
879LogicalResult CopyOp::fold(FoldAdaptor adaptor,
880 SmallVectorImpl<OpFoldResult> &results) {
890LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
891 SmallVectorImpl<OpFoldResult> &results) {
// memref.dim results print as %dim.
900void DimOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
901 setNameFn(getResult(),
"dim");
// Build overload fragment: materializes the index (elided) and delegates to
// the Value-index build.
904void DimOp::build(OpBuilder &builder, OperationState &
result, Value source,
906 auto loc =
result.location;
908 build(builder,
result, source, indexValue);
911std::optional<int64_t> DimOp::getConstantIndex() {
// Verify fragment: a constant index (extraction elided) must be within the
// rank of a ranked source; unranked sources skip the bound check.
920 auto rankedSourceType = dyn_cast<MemRefType>(getSource().
getType());
921 if (!rankedSourceType)
924 if (rankedSourceType.getRank() <= constantIndex)
930void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
932 setResultRange(getResult(),
// Fragment of a static helper: histogram of how often each value occurs.
941 std::map<int64_t, unsigned> numOccurences;
942 for (
auto val : vals)
943 numOccurences[val]++;
944 return numOccurences;
// Fragment: size-matching pass of rank-reduction detection. No rank reduction
// => empty mask. Source sizes are concretized (constant or kDynamic), then
// each result size is greedily matched against remaining source dims; the
// unmatched source dims form the "dropped" mask.
954static FailureOr<llvm::SmallBitVector>
956 MemRefType reducedType,
958 int64_t rankReduction = originalType.getRank() - reducedType.getRank();
959 if (rankReduction <= 0)
960 return llvm::SmallBitVector(originalType.getRank());
964 for (
const auto &it : llvm::enumerate(sizes)) {
966 sourceSizes[it.index()] = *cst;
968 sourceSizes[it.index()] = ShapedType::kDynamic;
972 llvm::SmallBitVector usedSourceDims(originalType.getRank());
// Greedy left-to-right matching of result sizes to source sizes; `startJ`
// (maintained in elided lines) keeps the scan monotonic.
974 for (
int64_t resultSize : resultSizes) {
975 bool matched =
false;
976 for (
int64_t j = startJ;
j < originalType.getRank(); ++
j) {
977 if (sourceSizes[
j] == resultSize) {
978 usedSourceDims.set(
j);
// Invert the "used" set: every unmatched source dim is reported as unused.
988 llvm::SmallBitVector unusedDims(originalType.getRank());
989 for (
int64_t i = 0; i < originalType.getRank(); ++i)
990 if (!usedSourceDims.test(i))
// Fragment: stride-based refinement of a candidate unused-dims mask. Uses
// occurrence counts of each stride value to decide which candidate-unused dims
// must actually be kept (their stride is still "accounted for" by the reduced
// type). Ends by checking the refined mask size is consistent with the ranks.
1003 MemRefType originalType, MemRefType reducedType,
1005 llvm::SmallBitVector unusedDims) {
1013 std::map<int64_t, unsigned> currUnaccountedStrides =
1015 std::map<int64_t, unsigned> candidateStridesNumOccurences =
1017 for (
size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
1018 if (!unusedDims.test(dim))
1020 int64_t originalStride = originalStrides[dim];
// More occurrences remaining than the reduced type needs: this dim really is
// dropped; consume one occurrence.
1021 if (currUnaccountedStrides[originalStride] >
1022 candidateStridesNumOccurences[originalStride]) {
1024 currUnaccountedStrides[originalStride]--;
// Counts balanced: the reduced type still uses this stride, so the dim is
// not actually dropped — clear it from the mask.
1027 if (currUnaccountedStrides[originalStride] ==
1028 candidateStridesNumOccurences[originalStride]) {
1030 unusedDims.reset(dim);
// Fewer remaining than needed is inconsistent (failure path elided).
1033 if (currUnaccountedStrides[originalStride] <
1034 candidateStridesNumOccurences[originalStride]) {
1040 if (
static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
1041 originalType.getRank())
// Fragment: top-level rank-reduction mask computation. Equal ranks => nothing
// dropped. First guess: every statically-1 size attribute is a dropped dim; if
// the count already reconciles the ranks, use that. Otherwise fall back to the
// stride-based disambiguation when layouts have non-trivial static strides.
1053static FailureOr<llvm::SmallBitVector>
1056 llvm::SmallBitVector unusedDims(originalType.getRank());
1057 if (originalType.getRank() == reducedType.getRank())
1060 for (
const auto &dim : llvm::enumerate(sizes))
1061 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
1062 if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
1063 unusedDims.set(dim.index());
1067 if (
static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
1068 originalType.getRank())
1072 int64_t originalOffset, candidateOffset;
1074 originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
1076 reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
// Local predicate: any non-dynamic stride among all but the innermost one.
1084 if (strides.size() <= 1)
1086 return llvm::any_of(strides.drop_back(),
1087 [](
int64_t s) { return !ShapedType::isDynamic(s); });
1089 if (hasNonTrivialStaticStride(originalStrides) ||
1090 hasNonTrivialStaticStride(candidateStrides)) {
1091 FailureOr<llvm::SmallBitVector> strideBased =
1094 candidateStrides, unusedDims);
1095 if (succeeded(strideBased))
1096 return *strideBased;
// Fragment: the dims a rank-reducing subview drops, computed from source vs.
// result type (call elided); failure is considered a programmer error.
1102llvm::SmallBitVector SubViewOp::getDroppedDims() {
1103 MemRefType sourceType = getSourceType();
1104 MemRefType resultType =
getType();
1105 FailureOr<llvm::SmallBitVector> unusedDims =
1107 assert(succeeded(unusedDims) &&
"unable to find unused dims of subview");
// Fragment of DimOp::fold: with a constant in-bounds index, a static dim folds
// to an IndexAttr; a dynamic dim folds to the corresponding dynamic-size
// operand of the defining alloc/alloca/view, or is resolved through a subview.
1111OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1113 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1118 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().
getType());
// Out-of-bounds constant index: no fold (bail elided).
1124 int64_t indexVal = index.getInt();
1125 if (indexVal < 0 || indexVal >= memrefType.getRank())
1129 if (!memrefType.isDynamicDim(index.getInt())) {
1131 return builder.
getIndexAttr(memrefType.getShape()[index.getInt()]);
1135 unsigned unsignedIndex = index.getValue().getZExtValue();
1138 Operation *definingOp = getSource().getDefiningOp();
// Dynamic dim of an alloc/alloca/view: the matching dynamic-size operand IS
// the answer (getDynamicDimIndex maps dim -> dynamic-operand position).
1140 if (
auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1141 return *(alloc.getDynamicSizes().begin() +
1142 memrefType.getDynamicDimIndex(unsignedIndex));
1144 if (
auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1145 return *(alloca.getDynamicSizes().begin() +
1146 memrefType.getDynamicDimIndex(unsignedIndex));
1148 if (
auto view = dyn_cast_or_null<ViewOp>(definingOp))
1149 return *(view.getDynamicSizes().begin() +
1150 memrefType.getDynamicDimIndex(unsignedIndex));
// Subview case: scan its mixed sizes, counting dynamic entries until the
// requested dynamic dim position is reached (resolution elided).
1152 if (
auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1157 unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex);
1158 unsigned dynamicIdx = 0;
1159 for (OpFoldResult size : subview.getMixedSizes()) {
1160 if (llvm::isa<Attribute>(size))
1162 if (dynamicIdx == dynamicResultDimIdx)
// Fragment: rewrites dim(reshape(_, shape), idx) into a load from the shape
// memref, after dominance checks ensure the index is usable at the reshape.
1179struct DimOfMemRefReshape :
public OpRewritePattern<DimOp> {
1180 using OpRewritePattern<DimOp>::OpRewritePattern;
1182 LogicalResult matchAndRewrite(DimOp dim,
1183 PatternRewriter &rewriter)
const override {
1184 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1188 dim,
"Dim op is not defined by a reshape op.");
// Same block: the index's defining op must come before the reshape.
1199 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1200 if (
auto *definingOp = dim.getIndex().getDefiningOp()) {
1201 if (reshape->isBeforeInBlock(definingOp)) {
1204 "dim.getIndex is not defined before reshape in the same block.");
// Different blocks: the index's region must properly enclose the reshape's.
1209 else if (dim->getBlock() != reshape->getBlock() &&
1210 !dim.getIndex().getParentRegion()->isProperAncestor(
1211 reshape->getParentRegion())) {
1216 dim,
"dim.getIndex does not dominate reshape.");
// Load the size from the shape operand; cast to the dim result type if the
// shape's element type differs (e.g. iN vs index).
1222 Location loc = dim.getLoc();
1224 LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1225 if (
load.getType() != dim.getType())
1226 load = arith::IndexCastOp::create(rewriter, loc, dim.getType(),
load);
1234void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1235 MLIRContext *context) {
1236 results.
add<DimOfMemRefReshape>(context);
// Fragment of DmaStartOp::build — operand order is significant: src memref,
// src indices, dst memref, dst indices, num elements + tag memref, tag
// indices, then (optionally, guard elided) stride + elements-per-stride.
1243void DmaStartOp::build(OpBuilder &builder, OperationState &
result,
1244 Value srcMemRef,
ValueRange srcIndices, Value destMemRef,
1246 Value tagMemRef,
ValueRange tagIndices, Value stride,
1247 Value elementsPerStride) {
1248 result.addOperands(srcMemRef);
1249 result.addOperands(srcIndices);
1250 result.addOperands(destMemRef);
1251 result.addOperands(destIndices);
1252 result.addOperands({numElements, tagMemRef});
1253 result.addOperands(tagIndices);
1255 result.addOperands({stride, elementsPerStride});
// Custom printer: src[indices], dst[indices], numElements, tag[indices],
// optional stride pair, then the three memref types.
1258void DmaStartOp::print(OpAsmPrinter &p) {
1259 p <<
" " << getSrcMemRef() <<
'[' << getSrcIndices() <<
"], "
1260 << getDstMemRef() <<
'[' << getDstIndices() <<
"], " <<
getNumElements()
1261 <<
", " << getTagMemRef() <<
'[' << getTagIndices() <<
']';
1263 p <<
", " << getStride() <<
", " << getNumElementsPerStride();
1266 p <<
" : " << getSrcMemRef().getType() <<
", " << getDstMemRef().getType()
1267 <<
", " << getTagMemRef().getType();
// Custom parser (mirror of the printer); a stride, when present, must come
// with its elements-per-stride partner, and exactly 3 types are expected.
1278ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &
result) {
1279 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1280 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1281 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1282 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1283 OpAsmParser::UnresolvedOperand numElementsInfo;
1284 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1285 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1286 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1288 SmallVector<Type, 3> types;
1308 bool isStrided = strideInfo.size() == 2;
1309 if (!strideInfo.empty() && !isStrided) {
1311 "expected two stride related operands");
1316 if (types.size() != 3)
// Fragment of DmaStartOp::verify: structural operand checks done by hand
// because the operand list is variadic and position-encoded. Minimum operand
// count grows as src/dst/tag ranks are accounted for; an optional trailing
// stride pair is allowed; all index/stride operands must have index type.
1338LogicalResult DmaStartOp::verify() {
1339 unsigned numOperands = getNumOperands();
// Absolute minimum: src memref, dst memref, numElements, tag memref.
1343 if (numOperands < 4)
1344 return emitOpError(
"expected at least 4 operands");
1349 if (!llvm::isa<MemRefType>(getSrcMemRef().
getType()))
1350 return emitOpError(
"expected source to be of memref type");
1351 if (numOperands < getSrcMemRefRank() + 4)
1352 return emitOpError() <<
"expected at least " << getSrcMemRefRank() + 4
1354 if (!getSrcIndices().empty() &&
1355 !llvm::all_of(getSrcIndices().getTypes(),
1356 [](Type t) {
return t.
isIndex(); }))
1357 return emitOpError(
"expected source indices to be of index type");
1360 if (!llvm::isa<MemRefType>(getDstMemRef().
getType()))
1361 return emitOpError(
"expected destination to be of memref type");
1362 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1363 if (numOperands < numExpectedOperands)
1364 return emitOpError() <<
"expected at least " << numExpectedOperands
1366 if (!getDstIndices().empty() &&
1367 !llvm::all_of(getDstIndices().getTypes(),
1368 [](Type t) {
return t.
isIndex(); }))
1369 return emitOpError(
"expected destination indices to be of index type");
1373 return emitOpError(
"expected num elements to be of index type");
1376 if (!llvm::isa<MemRefType>(getTagMemRef().
getType()))
1377 return emitOpError(
"expected tag to be of memref type");
1378 numExpectedOperands += getTagMemRefRank();
1379 if (numOperands < numExpectedOperands)
1380 return emitOpError() <<
"expected at least " << numExpectedOperands
1382 if (!getTagIndices().empty() &&
1383 !llvm::all_of(getTagIndices().getTypes(),
1384 [](Type t) {
return t.
isIndex(); }))
1385 return emitOpError(
"expected tag indices to be of index type");
// Exactly the base count, or base + the optional {stride, elemsPerStride}.
1389 if (numOperands != numExpectedOperands &&
1390 numOperands != numExpectedOperands + 2)
1391 return emitOpError(
"incorrect number of operands");
1395 if (!getStride().
getType().isIndex() ||
1396 !getNumElementsPerStride().
getType().isIndex())
1398 "expected stride and num elements per stride to be of type index");
// Fold entry points for both DMA ops (bodies elided — presumably operand-level
// cast folding, as elsewhere in this file; verify against full source).
1404LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1405 SmallVectorImpl<OpFoldResult> &results) {
1414LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1415 SmallVectorImpl<OpFoldResult> &results) {
// memref.dma_wait: tag index count must equal the tag memref's rank.
1420LogicalResult DmaWaitOp::verify() {
1422 unsigned numTagIndices = getTagIndices().size();
1423 unsigned tagMemRefRank = getTagMemRefRank();
1424 if (numTagIndices != tagMemRefRank)
1425 return emitOpError() <<
"expected tagIndices to have the same number of "
1426 "elements as the tagMemRef rank, expected "
1427 << tagMemRefRank <<
", but got " << numTagIndices;
// extract_aligned_pointer_as_index results print as %intptr.
1435void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1437 setNameFn(getResult(),
"intptr");
// Result types of extract_strided_metadata: a rank-0 memref sharing the
// source's element type and memory space (the base buffer), one index for the
// offset, then rank indices for sizes and rank for strides (2*rank total).
1446LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1447 MLIRContext *context, std::optional<Location> location,
1448 ExtractStridedMetadataOp::Adaptor adaptor,
1449 SmallVectorImpl<Type> &inferredReturnTypes) {
1450 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1454 unsigned sourceRank = sourceType.getRank();
1455 IndexType indexType = IndexType::get(context);
1457 MemRefType::get({}, sourceType.getElementType(),
1458 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1460 inferredReturnTypes.push_back(memrefType);
1462 inferredReturnTypes.push_back(indexType);
1464 for (
unsigned i = 0; i < sourceRank * 2; ++i)
1465 inferredReturnTypes.push_back(indexType);
// Pretty names for the metadata results: %base_buffer, %offset, and (when the
// source is ranked) %sizes / %strides for the leading size/stride results.
1469void ExtractStridedMetadataOp::getAsmResultNames(
1471 setNameFn(getBaseBuffer(),
"base_buffer");
1472 setNameFn(getOffset(),
"offset");
1475 if (!getSizes().empty()) {
1476 setNameFn(getSizes().front(),
"sizes");
1477 setNameFn(getStrides().front(),
"strides");
// Fragment: replaces uses of each value with its constified counterpart where
// one exists; returns whether anything was replaced so callers can report a
// successful in-place fold.
1484template <
typename Container>
1488 assert(values.size() == maybeConstants.size() &&
1489 " expected values and maybeConstants of the same size");
1490 bool atLeastOneReplacement =
false;
1491 for (
auto [maybeConstant,
result] : llvm::zip(maybeConstants, values)) {
// At this point (skip logic elided) the entry must hold an Attribute, i.e. a
// genuinely constified value.
1496 assert(isa<Attribute>(maybeConstant) &&
1497 "The constified value should be either unchanged (i.e., == result) "
1501 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1506 atLeastOneReplacement =
true;
1509 return atLeastOneReplacement;
// Fold fragment: replaces offset/sizes/strides results with their constified
// equivalents, and additionally looks through a ranked memref.cast on the
// source by reassigning the source operand in place.
1513ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1514 SmallVectorImpl<OpFoldResult> &results) {
1515 OpBuilder builder(*
this);
1519 getConstifiedMixedOffset());
1521 getConstifiedMixedSizes());
1523 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1526 if (
auto prev = getSource().getDefiningOp<CastOp>())
1527 if (isa<MemRefType>(prev.getSource().getType())) {
1528 getSourceMutable().assign(prev.getSource());
1529 atLeastOneReplacement =
true;
1532 return success(atLeastOneReplacement);
1535SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
// Constified strides: static stride values come straight from the source
// memref type's strided layout.
1541SmallVector<OpFoldResult>
1542ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1544 SmallVector<int64_t> staticValues;
1546 LogicalResult status =
1547 getSource().getType().getStridesAndOffset(staticValues, unused);
1549 assert(succeeded(status) &&
"could not get strides from type");
// Constified offset: single-element variant of the same scheme, using the
// static offset from the source type's layout.
1554OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1556 SmallVector<OpFoldResult> values(1, offsetOfr);
1557 SmallVector<int64_t> staticValues, unused;
1559 LogicalResult status =
1560 getSource().getType().getStridesAndOffset(unused, offset);
1562 assert(succeeded(status) &&
"could not get offset from type");
1563 staticValues.push_back(offset);
// Build fragment: result type is the memref's element type; the op carries a
// single body region (entry block setup elided).
1572void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &
result,
1574 OpBuilder::InsertionGuard g(builder);
1575 result.addOperands(memref);
1578 if (
auto memrefType = llvm::dyn_cast<MemRefType>(memref.
getType())) {
1579 Type elementType = memrefType.getElementType();
1580 result.addTypes(elementType);
1582 Region *bodyRegion =
result.addRegion();
// Verifier: one block argument whose type matches the result; the body (per
// the diagnostic below) must be side-effect free — walk predicate elided.
1588LogicalResult GenericAtomicRMWOp::verify() {
1589 auto &body = getRegion();
1590 if (body.getNumArguments() != 1)
1591 return emitOpError(
"expected single number of entry block arguments");
1593 if (getResult().
getType() != body.getArgument(0).getType())
1594 return emitOpError(
"expected block argument of the same type result type");
1597 body.walk([&](Operation *nestedOp) {
1601 "body of 'memref.generic_atomic_rmw' should contain "
1602 "only operations with no side effects");
// Custom parser/printer: `memref[indices] : memreftype { body }` form.
1609ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1610 OperationState &
result) {
1611 OpAsmParser::UnresolvedOperand memref;
1613 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1623 Region *body =
result.addRegion();
1631void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1632 p <<
' ' << getMemref() <<
"[" <<
getIndices()
1633 <<
"] : " << getMemref().
getType() <<
' ';
// memref.atomic_yield: yielded type must equal the parent op's result type.
1642LogicalResult AtomicYieldOp::verify() {
1643 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1644 Type resultType = getResult().getType();
1645 if (parentType != resultType)
1646 return emitOpError() <<
"types mismatch between yield op: " << resultType
1647 <<
" and its parent: " << parentType;
// Print fragment: non-external globals print their initializer, or the
// keyword `uninitialized`.
1659 if (!op.isExternal()) {
1661 if (op.isUninitialized())
1662 p <<
"uninitialized";
// Parse fragment for type + initial value: the type must be a static-shaped
// memref; `uninitialized` maps to a UnitAttr, otherwise an ElementsAttr is
// required.
1675 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1676 if (!memrefType || !memrefType.hasStaticShape())
1678 <<
"type should be static shaped memref, but got " << type;
1679 typeAttr = TypeAttr::get(type);
1685 initialValue = UnitAttr::get(parser.
getContext());
1692 if (!llvm::isa<ElementsAttr>(initialValue))
1694 <<
"initial value should be a unit or elements attribute";
// Verifier: static-shaped memref type; initial value (when present) is unit
// or elements, and an elements initializer must match the memref's element
// type and shape exactly.
1698LogicalResult GlobalOp::verify() {
1699 auto memrefType = llvm::dyn_cast<MemRefType>(
getType());
1700 if (!memrefType || !memrefType.hasStaticShape())
1701 return emitOpError(
"type should be static shaped memref, but got ")
1706 if (getInitialValue().has_value()) {
1707 Attribute initValue = getInitialValue().value();
1708 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1709 return emitOpError(
"initial value should be a unit or elements "
1710 "attribute, but got ")
1715 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1717 auto initElementType =
1718 cast<TensorType>(elementsAttr.getType()).getElementType();
1719 auto memrefElementType = memrefType.getElementType();
1721 if (initElementType != memrefElementType)
1722 return emitOpError(
"initial value element expected to be of type ")
1723 << memrefElementType <<
", but was of type " << initElementType;
1728 auto initShape = elementsAttr.getShapedType().getShape();
1729 auto memrefShape = memrefType.getShape();
1730 if (initShape != memrefShape)
1731 return emitOpError(
"initial value shape expected to be ")
1732 << memrefShape <<
" but was " << initShape;
// Constant initializer accessor: only meaningful for `constant` globals with
// an initial value (fallthrough return elided).
1740ElementsAttr GlobalOp::getConstantInitValue() {
1741 auto initVal = getInitialValue();
1742 if (getConstant() && initVal.has_value())
1743 return llvm::cast<ElementsAttr>(initVal.value());
// Symbol check fragment: the referenced symbol must resolve to a memref
// global whose type equals this op's result type.
1752GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1759 << getName() <<
"' does not reference a valid global memref";
1761 Type resultType = getResult().getType();
1762 if (global.getType() != resultType)
1764 << resultType <<
" does not match type " << global.getType()
1765 <<
" of the global memref @" << getName();
// LoadOp::verify fragment: index count must match (check condition elided).
1773LogicalResult LoadOp::verify() {
1775 return emitOpError(
"incorrect number of indices for load, expected ")
// Fold fragment: a load from a constant global with a splat initializer folds
// to the splat value.
1781OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1787 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
1793 getGlobalOp, getGlobalOp.getNameAttr());
1798 dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
1802 return splatAttr.getSplatValue<Attribute>();
1805FailureOr<std::optional<SmallVector<Value>>>
1806LoadOp::bubbleDownCasts(OpBuilder &builder) {
// getAsmResultNames — suggests "memspacecast" as the SSA result name prefix.
1815void MemorySpaceCastOp::getAsmResultNames(
 1817 setNameFn(getResult(),
"memspacecast");
// areCastCompatible — a memory-space cast is legal between two memrefs that
// agree on element type, layout, and shape (ranked case), or on element type
// alone (unranked case); only the memory space may differ.
// NOTE(review): fragment — the ranked/unranked dispatch structure between the
// dyn_casts and the comparisons is elided.
 1821 if (inputs.size() != 1 || outputs.size() != 1)
 1823 Type a = inputs.front(),
b = outputs.front();
 1824 auto aT = llvm::dyn_cast<MemRefType>(a);
 1825 auto bT = llvm::dyn_cast<MemRefType>(
b);
 1827 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
 1828 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(
b);
 1831 if (aT.getElementType() != bT.getElementType())
 1833 if (aT.getLayout() != bT.getLayout())
 1835 if (aT.getShape() != bT.getShape())
 1840 return uaT.getElementType() == ubT.getElementType();
// fold — collapses a chain of two memory-space casts by re-pointing this op's
// source at the parent cast's source (fragment: the return is elided).
1845OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
 1848 if (
auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
 1849 getSourceMutable().assign(parentCast.getSource());
1863bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1864 PtrLikeTypeInterface src) {
1865 return isa<BaseMemRefType>(tgt) &&
1866 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
// cloneMemorySpaceCastOp — builds a new memref.memory_space_cast from `src` to
// target type `tgt`, asserting the pair forms a valid memory-space cast.
// NOTE(review): fragment — the `Value src` parameter line is elided here.
1869MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
 1870 OpBuilder &
b, PtrLikeTypeInterface tgt,
 1872 assert(isValidMemorySpaceCast(tgt, src.getType()) &&
"invalid arguments");
 1873 return MemorySpaceCastOp::create(
b, getLoc(), tgt, src);
// isSourcePromotable — the source may be promoted only when the destination
// lives in the default (null) memory space.
1877bool MemorySpaceCastOp::isSourcePromotable() {
 1878 return getDest().getType().getMemorySpace() ==
nullptr;
// PrefetchOp::print — custom assembly: memref, indices in brackets, then the
// read/write specifier, locality hint, and data/instr cache kind; the three
// backing attributes are elided from the generic attr-dict since they are
// printed inline.
1885void PrefetchOp::print(OpAsmPrinter &p) {
 1886 p <<
" " << getMemref() <<
'[';
 1888 p <<
']' <<
", " << (getIsWrite() ?
"write" :
"read");
 1889 p <<
", locality<" << getLocalityHint();
 1890 p <<
">, " << (getIsDataCache() ?
"data" :
"instr");
 1892 (*this)->getAttrs(),
 1893 {
"localityHint",
"isWrite",
"isDataCache"});
// PrefetchOp::parse — inverse of print: parses the memref, indices, and the
// keyword specifiers, validating that the rw specifier is exactly 'read' or
// 'write' and the cache type exactly 'data' or 'instr'.
// NOTE(review): fragment — the actual parser calls between the declarations
// and the validation are elided by extraction.
1897ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &
result) {
 1898 OpAsmParser::UnresolvedOperand memrefInfo;
 1899 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
 1900 IntegerAttr localityHint;
 1902 StringRef readOrWrite, cacheType;
 1919 if (readOrWrite !=
"read" && readOrWrite !=
"write")
 1921 "rw specifier has to be 'read' or 'write'");
 1922 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
 1925 if (cacheType !=
"data" && cacheType !=
"instr")
 1927 "cache type has to be 'data' or 'instr'");
 1929 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
// PrefetchOp::verify / fold — bodies elided by extraction; do not edit here.
1935LogicalResult PrefetchOp::verify() {
1942LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
 1943 SmallVectorImpl<OpFoldResult> &results) {
1952OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1954 auto type = getOperand().getType();
1955 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1956 if (shapedType && shapedType.hasRank())
1957 return IntegerAttr::get(IndexType::get(
getContext()), shapedType.getRank());
1958 return IntegerAttr();
// getAsmResultNames — suggests "reinterpret_cast" as the SSA name prefix.
1965void ReinterpretCastOp::getAsmResultNames(
 1967 setNameFn(getResult(),
"reinterpret_cast");
// build (explicit result type, OpFoldResult offset/sizes/strides): decomposes
// the mixed static/dynamic operands into value lists plus DenseI64ArrayAttrs
// and delegates to the generated builder.
// NOTE(review): fragment — the dispatchIndexOpFoldResults calls that populate
// the static/dynamic vectors are elided by extraction.
1973void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
 1974 MemRefType resultType, Value source,
 1975 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
 1976 ArrayRef<OpFoldResult> strides,
 1977 ArrayRef<NamedAttribute> attrs) {
 1978 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
 1979 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
 1983 result.addAttributes(attrs);
 1984 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
 1985 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
 1986 b.getDenseI64ArrayAttr(staticSizes),
 1987 b.getDenseI64ArrayAttr(staticStrides));
// build (result type inferred): derives a strided MemRefType from the decomposed
// offset/strides (StridedLayoutAttr) plus the source's element type and memory
// space, then forwards to the explicit-type builder above.
1990void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
 1991 Value source, OpFoldResult offset,
 1992 ArrayRef<OpFoldResult> sizes,
 1993 ArrayRef<OpFoldResult> strides,
 1994 ArrayRef<NamedAttribute> attrs) {
 1995 auto sourceType = cast<BaseMemRefType>(source.
getType());
 1996 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
 1997 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
 2001 auto stridedLayout = StridedLayoutAttr::get(
 2002 b.getContext(), staticOffsets.front(), staticStrides);
 2003 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
 2004 stridedLayout, sourceType.getMemorySpace());
 2005 build(
b,
result, resultType, source, offset, sizes, strides, attrs);
// build (all-static int64_t variant): wraps each integer in an I64IntegerAttr
// OpFoldResult and forwards to the mixed builder.
2008void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
 2009 MemRefType resultType, Value source,
 2010 int64_t offset, ArrayRef<int64_t> sizes,
 2011 ArrayRef<int64_t> strides,
 2012 ArrayRef<NamedAttribute> attrs) {
 2013 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
 2014 sizes, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v); });
 2015 SmallVector<OpFoldResult> strideValues =
 2016 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
 2017 return b.getI64IntegerAttr(v);
 2019 build(
b,
result, resultType, source,
b.getI64IntegerAttr(offset), sizeValues,
 2020 strideValues, attrs);
// build (all-dynamic Value variant): wraps each SSA value in an OpFoldResult
// and forwards to the mixed builder.
2023void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
 2024 MemRefType resultType, Value source, Value offset,
 2026 ArrayRef<NamedAttribute> attrs) {
 2027 SmallVector<OpFoldResult> sizeValues =
 2028 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
 2029 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
 2030 strides, [](Value v) -> OpFoldResult {
return v; });
 2031 build(
b,
result, resultType, source, offset, sizeValues, strideValues, attrs);
// ReinterpretCastOp::verify — checks, in order:
//  1. source and result share the same memory space;
//  2. each static result dim matches the corresponding static size operand;
//  3. the result type has a strided layout;
//  4. the static result offset matches the static offset operand;
//  5. each static result stride matches the corresponding stride operand.
// Dynamic values on either side are treated as compatible with anything.
2036LogicalResult ReinterpretCastOp::verify() {
 2038 auto srcType = llvm::cast<BaseMemRefType>(getSource().
getType());
 2039 auto resultType = llvm::cast<MemRefType>(
getType());
 2040 if (srcType.getMemorySpace() != resultType.getMemorySpace())
 2041 return emitError(
"different memory spaces specified for source type ")
 2042 << srcType <<
" and result memref type " << resultType;
// Check static sizes against the result shape, dim by dim.
 2048 for (
auto [idx, resultSize, expectedSize] :
 2049 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
 2050 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
 2051 return emitError(
"expected result type with size = ")
 2052 << (ShapedType::isDynamic(expectedSize)
 2053 ? std::string(
"dynamic")
 2054 : std::to_string(expectedSize))
 2055 <<
" instead of " << resultSize <<
" in dim = " << idx;
// The result must expose a strided layout to compare offset/strides against.
 2061 int64_t resultOffset;
 2062 SmallVector<int64_t, 4> resultStrides;
 2063 if (
failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
 2064 return emitError(
"expected result type to have strided layout but found ")
// Check the (single) static offset.
 2068 int64_t expectedOffset = getStaticOffsets().front();
 2069 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
 2070 return emitError(
"expected result type with offset = ")
 2071 << (ShapedType::isDynamic(expectedOffset)
 2072 ? std::string(
"dynamic")
 2073 : std::to_string(expectedOffset))
 2074 <<
" instead of " << resultOffset;
// Check static strides, dim by dim.
 2077 for (
auto [idx, resultStride, expectedStride] :
 2078 llvm::enumerate(resultStrides, getStaticStrides())) {
 2079 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
 2080 return emitError(
"expected result type with stride = ")
 2081 << (ShapedType::isDynamic(expectedStride)
 2082 ? std::string(
"dynamic")
 2083 : std::to_string(expectedStride))
 2084 <<
" instead of " << resultStride <<
" in dim = " << idx;
// ReinterpretCastOp::fold — if the source comes through another op whose own
// source can be substituted transparently (the three elided cases in
// getPrevSrc each return prev.getSource()), re-point this op at that earlier
// source and fold in place.
// NOTE(review): fragment — the conditions guarding each `return
// prev.getSource()` are elided by extraction; do not edit without them.
2090OpFoldResult ReinterpretCastOp::fold(FoldAdaptor ) {
 2091 Value src = getSource();
 2092 auto getPrevSrc = [&]() -> Value {
 2095 return prev.getSource();
 2099 return prev.getSource();
 2105 return prev.getSource();
 2110 if (
auto prevSrc = getPrevSrc()) {
 2111 getSourceMutable().assign(prevSrc);
// getConstifiedMixedSizes — body elided by extraction.
2124SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
// getConstifiedMixedStrides — starts from the mixed strides and replaces
// dynamic entries with the static strides recoverable from the result type
// (the constifyIndexValues step itself is elided).
2130SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
 2131 SmallVector<OpFoldResult> values = getMixedStrides();
 2132 SmallVector<int64_t> staticValues;
 2134 LogicalResult status =
getType().getStridesAndOffset(staticValues, unused);
 2136 assert(succeeded(status) &&
"could not get strides from type");
// getConstifiedMixedOffset — same idea for the single offset; a
// reinterpret_cast always carries exactly one offset.
2141OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
 2142 SmallVector<OpFoldResult> values = getMixedOffsets();
 2143 assert(values.size() == 1 &&
 2144 "reinterpret_cast must have one and only one offset");
 2145 SmallVector<int64_t> staticValues, unused;
 2147 LogicalResult status =
getType().getStridesAndOffset(unused, offset);
 2149 assert(succeeded(status) &&
"could not get offset from type");
 2150 staticValues.push_back(offset);
// Pattern: reinterpret_cast(extract_strided_metadata(x)) where the cast merely
// reconstructs x's own offset/sizes/strides is a no-op; replace it with x
// (via direct replacement when types match, otherwise — in the elided tail —
// presumably via a cast). When it is not a no-op, just forward the metadata
// op's source into the cast.
2198struct ReinterpretCastOpExtractStridedMetadataFolder
 2199 :
public OpRewritePattern<ReinterpretCastOp> {
 2201 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
 2203 LogicalResult matchAndRewrite(ReinterpretCastOp op,
 2204 PatternRewriter &rewriter)
const override {
 2205 auto extractStridedMetadata =
 2206 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
 2207 if (!extractStridedMetadata)
// The cast is a no-op iff its constified strides, sizes, and offset all
// equal those of the metadata op's source.
 2212 auto isReinterpretCastNoop = [&]() ->
bool {
 2214 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
 2215 op.getConstifiedMixedStrides()))
 2219 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
 2220 op.getConstifiedMixedSizes()))
 2224 assert(op.getMixedOffsets().size() == 1 &&
 2225 "reinterpret_cast with more than one offset should have been "
 2226 "rejected by the verifier");
 2227 return extractStridedMetadata.getConstifiedMixedOffset() ==
 2228 op.getConstifiedMixedOffset();
 2231 if (!isReinterpretCastNoop()) {
 2248 op.getSourceMutable().assign(extractStridedMetadata.getSource());
 2258 Type srcTy = extractStridedMetadata.getSource().getType();
 2259 if (srcTy == op.getResult().getType())
 2260 rewriter.
replaceOp(op, extractStridedMetadata.getSource());
 2263 extractStridedMetadata.getSource());
// Pattern: constify a reinterpret_cast's dynamic offset/sizes/strides where
// the result type proves their static values. Bails (elided return) when the
// constified lists contain no more Attribute entries than the originals —
// i.e. nothing became static — to avoid an infinite rewrite loop.
2269struct ReinterpretCastOpConstantFolder
 2270 :
public OpRewritePattern<ReinterpretCastOp> {
 2272 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
 2274 LogicalResult matchAndRewrite(ReinterpretCastOp op,
 2275 PatternRewriter &rewriter)
const override {
 2276 unsigned srcStaticCount = llvm::count_if(
 2277 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
 2278 op.getMixedStrides()),
 2279 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
 2281 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
 2282 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
 2283 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
// No new static values were discovered: report match failure (elided).
 2289 if (srcStaticCount ==
 2290 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
 2291 [](OpFoldResult ofr) {
return isa<Attribute>(ofr); }))
 2294 auto newReinterpretCast = ReinterpretCastOp::create(
 2295 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
// Registers the two folding patterns above.
2303void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 2304 MLIRContext *context) {
 2305 results.
add<ReinterpretCastOpExtractStridedMetadataFolder,
 2306 ReinterpretCastOpConstantFolder>(context);
// bubbleDownCasts — body elided by extraction.
2309FailureOr<std::optional<SmallVector<Value>>>
2310ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
// Suggested SSA name prefixes for the two reshape ops.
2318void CollapseShapeOp::getAsmResultNames(
 2320 setNameFn(getResult(),
"collapse_shape");
2323void ExpandShapeOp::getAsmResultNames(
 2325 setNameFn(getResult(),
"expand_shape");
// reifyResultShapes — the result shape is exactly the op's own mixed
// static/dynamic output-shape operands.
2328LogicalResult ExpandShapeOp::reifyResultShapes(
 2330 reifiedResultShapes = {
 2331 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
// Shared verifier for expand/collapse reassociation: each group must consist
// of contiguous expanded dims; at most one dynamic dim per group (unless
// allowMultipleDynamicDimsPerGroup); a collapsed dim is dynamic iff its group
// contains a dynamic dim; for fully static groups the product of expanded
// sizes must equal the collapsed size; rank-0 cases may only involve unit
// dims; and every expanded dim must be covered by some group.
2344 bool allowMultipleDynamicDimsPerGroup) {
 2346 if (collapsedShape.size() != reassociation.size())
 2347 return op->
emitOpError(
"invalid number of reassociation groups: found ")
 2348 << reassociation.size() <<
", expected " << collapsedShape.size();
 2353 for (
const auto &it : llvm::enumerate(reassociation)) {
 2355 int64_t collapsedDim = it.index();
 2357 bool foundDynamic =
false;
 2358 for (
int64_t expandedDim : group) {
 2359 if (expandedDim != nextDim++)
 2360 return op->
emitOpError(
"reassociation indices must be contiguous");
 2362 if (expandedDim >=
static_cast<int64_t>(expandedShape.size()))
 2364 << expandedDim <<
" is out of bounds";
 2367 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
 2368 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
 2370 "at most one dimension in a reassociation group may be dynamic");
 2371 foundDynamic =
true;
 2376 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
 2379 <<
") must be dynamic if and only if reassociation group is "
 2384 if (!foundDynamic) {
 2386 for (
int64_t expandedDim : group)
 2387 groupSize *= expandedShape[expandedDim];
 2388 if (groupSize != collapsedShape[collapsedDim])
 2390 << collapsedShape[collapsedDim]
 2391 <<
") must equal reassociation group size (" << groupSize <<
")";
 2395 if (collapsedShape.empty()) {
 2397 for (
int64_t d : expandedShape)
 2400 "rank 0 memrefs can only be extended/collapsed with/from ones");
 2401 }
else if (nextDim !=
static_cast<int64_t>(expandedShape.size())) {
 2405 << expandedShape.size()
 2406 <<
") inconsistent with number of reassociation indices (" << nextDim
// Reassociation accessors — delegate to shared helpers (bodies elided).
2413SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2417SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
 2419 getReassociationIndices());
2422SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2426SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 2428 getReassociationIndices());
// Computes the strided layout of the expanded memref: walking groups from the
// back, each expanded dim in a group starts at the source stride and is
// multiplied up by the expanded sizes (elided line 2463) as we move left.
// Leading unit-stride entries are appended so the stride list matches the
// result rank.
2433static FailureOr<StridedLayoutAttr>
 2438 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
 2440 assert(srcStrides.size() == reassociation.size() &&
"invalid reassociation");
 2455 reverseResultStrides.reserve(resultShape.size());
 2456 unsigned shapeIndex = resultShape.size() - 1;
 2457 for (
auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
 2459 int64_t currentStrideToExpand = std::get<1>(it);
 2460 for (
unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
 2461 reverseResultStrides.push_back(currentStrideToExpand);
 2462 currentStrideToExpand =
 2468 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
 2469 resultStrides.resize(resultShape.size(), 1);
 2470 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
// computeExpandedType — identity-layout sources keep an identity layout;
// otherwise the strided layout is computed by the helper above, failing when
// the source layout cannot be expanded.
2473FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
 2474 MemRefType srcType, ArrayRef<int64_t> resultShape,
 2475 ArrayRef<ReassociationIndices> reassociation) {
 2476 if (srcType.getLayout().isIdentity()) {
 2479 MemRefLayoutAttrInterface layout;
 2480 return MemRefType::get(resultShape, srcType.getElementType(), layout,
 2481 srcType.getMemorySpace());
 2485 FailureOr<StridedLayoutAttr> computedLayout =
 2487 if (
failed(computedLayout))
 2489 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
 2490 srcType.getMemorySpace());
// inferOutputShape — delegates to the shared reshape-shape inference helper
// (the call itself is elided) and unwraps the optional result.
2493FailureOr<SmallVector<OpFoldResult>>
2494ExpandShapeOp::inferOutputShape(OpBuilder &
b, Location loc,
 2495 MemRefType expandedType,
 2496 ArrayRef<ReassociationIndices> reassociation,
 2497 ArrayRef<OpFoldResult> inputShape) {
 2498 std::optional<SmallVector<OpFoldResult>> outputShape =
 2503 return *outputShape;
// build (explicit type + output shape): splits the mixed output shape into
// static/dynamic parts and forwards to the generated builder.
2506void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
 2507 Type resultType, Value src,
 2508 ArrayRef<ReassociationIndices> reassociation,
 2509 ArrayRef<OpFoldResult> outputShape) {
 2510 auto [staticOutputShape, dynamicOutputShape] =
 2512 build(builder,
result, llvm::cast<MemRefType>(resultType), src,
 2514 dynamicOutputShape, staticOutputShape);
// build (explicit type, shape inferred): derives the output shape from the
// source's shape via inferOutputShape; inference must succeed.
2517void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
 2518 Type resultType, Value src,
 2519 ArrayRef<ReassociationIndices> reassociation) {
 2520 SmallVector<OpFoldResult> inputShape =
 2522 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
 2523 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
 2524 builder,
result.location, memrefResultTy, reassociation, inputShape);
 2527 assert(succeeded(outputShape) &&
"unable to infer output shape");
 2528 build(builder,
result, memrefResultTy, src, reassociation, *outputShape);
// build (result shape as int64_t list, type computed): the result type comes
// from computeExpandedType and must be computable.
2531void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
 2532 ArrayRef<int64_t> resultShape, Value src,
 2533 ArrayRef<ReassociationIndices> reassociation) {
 2535 auto srcType = llvm::cast<MemRefType>(src.
getType());
 2536 FailureOr<MemRefType> resultType =
 2537 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
 2540 assert(succeeded(resultType) &&
"could not compute layout");
 2541 build(builder,
result, *resultType, src, reassociation);
// build (result shape + explicit output shape): as above, but the dynamic
// output-shape values are supplied by the caller.
2544void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
 2545 ArrayRef<int64_t> resultShape, Value src,
 2546 ArrayRef<ReassociationIndices> reassociation,
 2547 ArrayRef<OpFoldResult> outputShape) {
 2549 auto srcType = llvm::cast<MemRefType>(src.
getType());
 2550 FailureOr<MemRefType> resultType =
 2551 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
 2554 assert(succeeded(resultType) &&
"could not compute layout");
 2555 build(builder,
result, *resultType, src, reassociation, outputShape);
// ExpandShapeOp::verify — source rank must not exceed result rank; the
// reassociation must be valid (shared helper, call partially elided); the
// result type must equal the type recomputed from the source + reassociation;
// the static output shape must have one entry per result dim; the dynamic
// output-shape operand count must equal the number of dynamic entries; and
// every static output-shape entry must match the static result dim.
2558LogicalResult ExpandShapeOp::verify() {
 2559 MemRefType srcType = getSrcType();
 2560 MemRefType resultType = getResultType();
 2562 if (srcType.getRank() > resultType.getRank()) {
 2563 auto r0 = srcType.getRank();
 2564 auto r1 = resultType.getRank();
 2566 << r0 <<
" and result rank " << r1 <<
". This is not an expansion ("
 2567 << r0 <<
" > " << r1 <<
").";
 2572 resultType.getShape(),
 2573 getReassociationIndices(),
 2578 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
 2579 srcType, resultType.getShape(), getReassociationIndices());
 2580 if (
failed(expectedResultType))
 2584 if (*expectedResultType != resultType)
 2585 return emitOpError(
"expected expanded type to be ")
 2586 << *expectedResultType <<
" but found " << resultType;
 2588 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
 2589 return emitOpError(
"expected number of static shape bounds to be equal to "
 2590 "the output rank (")
 2591 << resultType.getRank() <<
") but found "
 2592 << getStaticOutputShape().size() <<
" inputs instead";
 2594 if ((int64_t)getOutputShape().size() !=
 2595 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
 2596 return emitOpError(
"mismatch in dynamic dims in output_shape and "
 2597 "static_output_shape: static_output_shape has ")
 2598 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
 2599 <<
" dynamic dims while output_shape has " << getOutputShape().size()
 2610 ArrayRef<int64_t> resShape = getResult().getType().getShape();
 2611 for (
auto [pos, shape] : llvm::enumerate(resShape)) {
 2612 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
 2613 return emitOpError(
"invalid output shape provided at pos ") << pos;
// Pattern body (struct header elided): folds expand_shape(cast(x)) -> 
// expand_shape(x) when the cast is foldable into its consumer and folding
// does not change the dynamicity of any reassociation group. Recomputes the
// output shape from the cast source's static dims, then either updates the op
// in place (same result type) or builds a new expand_shape (+ elided cast).
 2626 auto cast = op.getSrc().getDefiningOp<CastOp>();
 2630 if (!CastOp::canFoldIntoConsumerOp(cast))
// Rebuild the output shape, preferring the static sizes known on the source.
 2638 for (
auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
 2640 if (!sizeOpt.has_value()) {
 2641 newOutputShapeSizes.push_back(ShapedType::kDynamic);
 2645 newOutputShapeSizes.push_back(sizeOpt.value());
 2646 newOutputShape[dimIdx] = rewriter.
getIndexAttr(sizeOpt.value());
 2649 Value castSource = cast.getSource();
 2650 auto castSourceType = llvm::cast<MemRefType>(castSource.
getType());
 2652 op.getReassociationIndices();
// Reject folds that would flip a group between static and dynamic.
 2653 for (
auto [idx, group] : llvm::enumerate(reassociationIndices)) {
 2654 auto newOutputShapeSizesSlice =
 2655 ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
 2656 bool newOutputDynamic =
 2657 llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
 2658 if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
 2660 op,
"folding cast will result in changing dynamicity in "
 2661 "reassociation group");
 2664 FailureOr<MemRefType> newResultTypeOrFailure =
 2665 ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
 2666 reassociationIndices);
 2668 if (failed(newResultTypeOrFailure))
 2670 op,
"could not compute new expanded type after folding cast");
 2672 if (*newResultTypeOrFailure == op.getResultType()) {
 2674 op, [&]() { op.getSrcMutable().assign(castSource); });
 2676 Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
 2677 *newResultTypeOrFailure, castSource,
 2678 reassociationIndices, newOutputShape);
// Registers reshape-composition patterns plus the cast folder above.
2685void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 2686 MLIRContext *context) {
 2688 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
 2689 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
 2690 ExpandShapeOpMemRefCastFolder>(context);
// bubbleDownCasts — body elided by extraction.
2693FailureOr<std::optional<SmallVector<Value>>>
2694ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
// Computes the strided layout of a collapsed memref. Unit dims at a group's
// tail are skipped; a group's stride is the stride of its innermost non-unit
// dim when that dim is static (or the group is a single dim), and dynamic
// otherwise. The (largely elided) contiguity walk then checks each group's
// dims are contiguous — in `strict` mode saturated (dynamic) strides are
// rejected outright.
2705static FailureOr<StridedLayoutAttr>
 2708 bool strict =
false) {
 2711 auto srcShape = srcType.getShape();
 2712 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
 2721 resultStrides.reserve(reassociation.size());
 2724 while (srcShape[ref.back()] == 1 && ref.size() > 1)
 2725 ref = ref.drop_back();
 2726 if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
 2727 resultStrides.push_back(srcStrides[ref.back()]);
 2733 resultStrides.push_back(ShapedType::kDynamic);
 2738 unsigned resultStrideIndex = resultStrides.size() - 1;
 2742 for (
int64_t idx : llvm::reverse(trailingReassocs)) {
 2747 if (srcShape[idx - 1] == 1)
 2759 if (strict && (stride.saturated || srcStride.saturated))
 2762 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
 2766 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
// isGuaranteedCollapsible — identity layouts are always collapsible; the
// strict-mode layout computation (elided) decides the rest.
2769bool CollapseShapeOp::isGuaranteedCollapsible(
 2770 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
 2772 if (srcType.getLayout().isIdentity())
// computeCollapsedType — each result dim is the product of its group's sizes
// (SaturatedInteger arithmetic, partially elided); identity-layout sources
// keep identity layout, others get the computed strided layout, which must
// succeed here.
2779MemRefType CollapseShapeOp::computeCollapsedType(
 2780 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
 2781 SmallVector<int64_t> resultShape;
 2782 resultShape.reserve(reassociation.size());
 2785 for (int64_t srcDim : group)
 2788 resultShape.push_back(groupSize.asInteger());
 2791 if (srcType.getLayout().isIdentity()) {
 2794 MemRefLayoutAttrInterface layout;
 2795 return MemRefType::get(resultShape, srcType.getElementType(), layout,
 2796 srcType.getMemorySpace());
 2802 FailureOr<StridedLayoutAttr> computedLayout =
 2804 assert(succeeded(computedLayout) &&
 2805 "invalid source layout map or collapsing non-contiguous dims");
 2806 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
 2807 srcType.getMemorySpace());
// build — computes the collapsed result type from the source and forwards.
2810void CollapseShapeOp::build(OpBuilder &
b, OperationState &
result, Value src,
 2811 ArrayRef<ReassociationIndices> reassociation,
 2812 ArrayRef<NamedAttribute> attrs) {
 2813 auto srcType = llvm::cast<MemRefType>(src.
getType());
 2814 MemRefType resultType =
 2815 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
 2818 build(
b,
result, resultType, src, attrs);
// CollapseShapeOp::verify — source rank must not be below result rank; the
// reassociation must pass the shared shape verifier (call partially elided);
// and the result type must equal the type recomputed from the source: identity
// layout stays identity, otherwise the collapsed strided layout must be
// computable (non-contiguous collapses are rejected).
2821LogicalResult CollapseShapeOp::verify() {
 2822 MemRefType srcType = getSrcType();
 2823 MemRefType resultType = getResultType();
 2825 if (srcType.getRank() < resultType.getRank()) {
 2826 auto r0 = srcType.getRank();
 2827 auto r1 = resultType.getRank();
 2829 << r0 <<
" and result rank " << r1 <<
". This is not a collapse ("
 2830 << r0 <<
" < " << r1 <<
").";
 2835 srcType.getShape(), getReassociationIndices(),
 2840 MemRefType expectedResultType;
 2841 if (srcType.getLayout().isIdentity()) {
 2844 MemRefLayoutAttrInterface layout;
 2845 expectedResultType =
 2846 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
 2847 srcType.getMemorySpace());
 2852 FailureOr<StridedLayoutAttr> computedLayout =
 2854 if (
failed(computedLayout))
 2856 "invalid source layout map or collapsing non-contiguous dims");
 2857 expectedResultType =
 2858 MemRefType::get(resultType.getShape(), srcType.getElementType(),
 2859 *computedLayout, srcType.getMemorySpace());
 2862 if (expectedResultType != resultType)
 2863 return emitOpError(
"expected collapsed type to be ")
 2864 << expectedResultType <<
" but found " << resultType;
// Pattern body (struct header elided): folds collapse_shape(cast(x)) ->
// collapse_shape(x) when the cast folds into its consumer; updates in place
// when the recomputed type matches, otherwise builds a fresh collapse_shape.
 2876 auto cast = op.getOperand().getDefiningOp<CastOp>();
 2880 if (!CastOp::canFoldIntoConsumerOp(cast))
 2883 Type newResultType = CollapseShapeOp::computeCollapsedType(
 2884 llvm::cast<MemRefType>(cast.getOperand().getType()),
 2885 op.getReassociationIndices());
 2887 if (newResultType == op.getResultType()) {
 2889 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
 2892 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
 2893 op.getReassociationIndices());
// Registers reshape-composition patterns plus the cast folder above.
2900void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 2901 MLIRContext *context) {
 2903 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
 2904 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
 2905 memref::DimOp, MemRefType>,
 2906 CollapseShapeOpMemRefCastFolder>(context);
// fold — both reshape ops delegate to the shared reshape fold helper
// (the helper call heads are elided).
2909OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
 2911 adaptor.getOperands());
2914OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
 2916 adaptor.getOperands());
// bubbleDownCasts — body elided by extraction.
2919FailureOr<std::optional<SmallVector<Value>>>
2920CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
// getAsmResultNames — suggests "reshape" as the SSA name prefix.
2928void ReshapeOp::getAsmResultNames(
 2930 setNameFn(getResult(),
"reshape");
// ReshapeOp::verify — element types of source and result must match; a ranked
// memref source must have an identity layout; a ranked memref result must
// have an identity layout, must not be built from a dynamic-length shape
// operand, and its rank must equal the shape operand's length.
2933LogicalResult ReshapeOp::verify() {
 2934 Type operandType = getSource().getType();
 2935 Type resultType = getResult().getType();
 2937 Type operandElementType =
 2938 llvm::cast<ShapedType>(operandType).getElementType();
 2939 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
 2940 if (operandElementType != resultElementType)
 2941 return emitOpError(
"element types of source and destination memref "
 2942 "types should be the same");
 2944 if (
auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
 2945 if (!operandMemRefType.getLayout().isIdentity())
 2946 return emitOpError(
"source memref type should have identity affine map");
 2950 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
 2951 if (resultMemRefType) {
 2952 if (!resultMemRefType.getLayout().isIdentity())
 2953 return emitOpError(
"result memref type should have identity affine map");
 2954 if (shapeSize == ShapedType::kDynamic)
 2955 return emitOpError(
"cannot use shape operand with dynamic length to "
 2956 "reshape to statically-ranked memref type");
 2957 if (shapeSize != resultMemRefType.getRank())
 2959 "length of shape operand differs from the result's memref rank");
// bubbleDownCasts — body elided by extraction.
2964FailureOr<std::optional<SmallVector<Value>>>
2965ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
// StoreOp::verify — index operand count must equal the memref rank
// (comparison itself elided).
2973LogicalResult StoreOp::verify() {
 2975 return emitOpError(
"store index operand count not equal to memref rank");
// StoreOp::fold / bubbleDownCasts — bodies elided by extraction.
2980LogicalResult StoreOp::fold(FoldAdaptor adaptor,
 2981 SmallVectorImpl<OpFoldResult> &results) {
2986FailureOr<std::optional<SmallVector<Value>>>
2987StoreOp::bubbleDownCasts(OpBuilder &builder) {
// getAsmResultNames — suggests "subview" as the SSA name prefix.
2996void SubViewOp::getAsmResultNames(
 2998 setNameFn(getResult(),
"subview");
// inferResultType (static form): the subview's offset is the source offset
// plus each static offset times the corresponding source stride; each target
// stride is the source stride times the static stride (the accumulation
// statements are elided); result keeps element type and memory space and gets
// an explicit StridedLayoutAttr.
3004MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
 3005 ArrayRef<int64_t> staticOffsets,
 3006 ArrayRef<int64_t> staticSizes,
 3007 ArrayRef<int64_t> staticStrides) {
 3008 unsigned rank = sourceMemRefType.getRank();
 3010 assert(staticOffsets.size() == rank &&
"staticOffsets length mismatch");
 3011 assert(staticSizes.size() == rank &&
"staticSizes length mismatch");
 3012 assert(staticStrides.size() == rank &&
"staticStrides length mismatch");
 3015 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
 3019 int64_t targetOffset = sourceOffset;
 3020 for (
auto it : llvm::zip(staticOffsets, sourceStrides)) {
 3021 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
 3030 SmallVector<int64_t, 4> targetStrides;
 3031 targetStrides.reserve(staticOffsets.size());
 3032 for (
auto it : llvm::zip(sourceStrides, staticStrides)) {
 3033 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
 3040 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
 3041 StridedLayoutAttr::get(sourceMemRefType.getContext(),
 3042 targetOffset, targetStrides),
 3043 sourceMemRefType.getMemorySpace());
// inferResultType (OpFoldResult form): decomposes mixed values into
// static/dynamic parts (dispatch calls elided) and delegates to the static
// overload above.
3046MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
 3047 ArrayRef<OpFoldResult> offsets,
 3048 ArrayRef<OpFoldResult> sizes,
 3049 ArrayRef<OpFoldResult> strides) {
 3050 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
 3051 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
 3061 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
 3062 staticSizes, staticStrides);
// inferRankReducedResultType (static form): infers the full-rank subview type,
// then — if the caller asked for a smaller rank — drops the unit dims chosen
// by the (elided) computeRankReductionMask, filtering the corresponding
// strides out of the inferred strided layout.
3065MemRefType SubViewOp::inferRankReducedResultType(
 3066 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
 3067 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
 3068 ArrayRef<int64_t> strides) {
 3069 MemRefType inferredType =
 3070 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
 3071 assert(inferredType.getRank() >=
static_cast<int64_t
>(resultShape.size()) &&
 3073 if (inferredType.getRank() ==
static_cast<int64_t
>(resultShape.size()))
 3074 return inferredType;
 3077 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
 3079 assert(dimsToProject.has_value() &&
"invalid rank reduction");
 3082 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
 3083 SmallVector<int64_t> rankReducedStrides;
 3084 rankReducedStrides.reserve(resultShape.size());
 3085 for (
auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
 3086 if (!dimsToProject->contains(idx))
 3087 rankReducedStrides.push_back(value);
 3089 return MemRefType::get(resultShape, inferredType.getElementType(),
 3090 StridedLayoutAttr::get(inferredLayout.getContext(),
 3091 inferredLayout.getOffset(),
 3092 rankReducedStrides),
 3093 inferredType.getMemorySpace());
// inferRankReducedResultType (OpFoldResult form): decomposes mixed values
// (dispatch calls elided) and delegates to the static overload above.
3096MemRefType SubViewOp::inferRankReducedResultType(
 3097 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
 3098 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
 3099 ArrayRef<OpFoldResult> strides) {
 3100 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
 3101 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
 3105 return SubViewOp::inferRankReducedResultType(
 3106 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3112void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3113 MemRefType resultType, Value source,
3114 ArrayRef<OpFoldResult> offsets,
3115 ArrayRef<OpFoldResult> sizes,
3116 ArrayRef<OpFoldResult> strides,
3117 ArrayRef<NamedAttribute> attrs) {
3118 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3119 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3123 auto sourceMemRefType = llvm::cast<MemRefType>(source.
getType());
3126 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3127 staticSizes, staticStrides);
3129 result.addAttributes(attrs);
3130 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
3131 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
3132 b.getDenseI64ArrayAttr(staticSizes),
3133 b.getDenseI64ArrayAttr(staticStrides));
3138void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3139 ArrayRef<OpFoldResult> offsets,
3140 ArrayRef<OpFoldResult> sizes,
3141 ArrayRef<OpFoldResult> strides,
3142 ArrayRef<NamedAttribute> attrs) {
3143 build(
b,
result, MemRefType(), source, offsets, sizes, strides, attrs);
3147void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3148 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3149 ArrayRef<int64_t> strides,
3150 ArrayRef<NamedAttribute> attrs) {
3151 SmallVector<OpFoldResult> offsetValues =
3152 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3153 return b.getI64IntegerAttr(v);
3155 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3156 sizes, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v); });
3157 SmallVector<OpFoldResult> strideValues =
3158 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3159 return b.getI64IntegerAttr(v);
3161 build(
b,
result, source, offsetValues, sizeValues, strideValues, attrs);
3166void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3167 MemRefType resultType, Value source,
3168 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3169 ArrayRef<int64_t> strides,
3170 ArrayRef<NamedAttribute> attrs) {
3171 SmallVector<OpFoldResult> offsetValues =
3172 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3173 return b.getI64IntegerAttr(v);
3175 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3176 sizes, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v); });
3177 SmallVector<OpFoldResult> strideValues =
3178 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3179 return b.getI64IntegerAttr(v);
3181 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues,
3187void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3188 MemRefType resultType, Value source,
ValueRange offsets,
3190 ArrayRef<NamedAttribute> attrs) {
3191 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3192 offsets, [](Value v) -> OpFoldResult {
return v; });
3193 SmallVector<OpFoldResult> sizeValues =
3194 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
3195 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3196 strides, [](Value v) -> OpFoldResult {
return v; });
3197 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues);
3201void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3203 ArrayRef<NamedAttribute> attrs) {
3204 build(
b,
result, MemRefType(), source, offsets, sizes, strides, attrs);
3208Value SubViewOp::getViewSource() {
return getSource(); }
3215 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3216 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3217 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3224 const llvm::SmallBitVector &droppedDims) {
3225 assert(
size_t(t1.getRank()) == droppedDims.size() &&
3226 "incorrect number of bits");
3227 assert(
size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3228 "incorrect number of dropped dims");
3231 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3232 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3233 if (failed(res1) || failed(res2))
3235 for (
int64_t i = 0,
j = 0, e = t1.getRank(); i < e; ++i) {
3238 if (t1Strides[i] != t2Strides[
j])
3246 SubViewOp op,
Type expectedType) {
3247 auto memrefType = llvm::cast<ShapedType>(expectedType);
3252 return op->emitError(
"expected result rank to be smaller or equal to ")
3253 <<
"the source rank, but got " << op.getType();
3255 return op->emitError(
"expected result type to be ")
3257 <<
" or a rank-reduced version. (mismatch of result sizes), but got "
3260 return op->emitError(
"expected result element type to be ")
3261 << memrefType.getElementType() <<
", but got " << op.getType();
3263 return op->emitError(
3264 "expected result and source memory spaces to match, but got ")
3267 return op->emitError(
"expected result type to be ")
3269 <<
" or a rank-reduced version. (mismatch of result layout), but "
3273 llvm_unreachable(
"unexpected subview verification result");
3277LogicalResult SubViewOp::verify() {
3278 MemRefType baseType = getSourceType();
3279 MemRefType subViewType =
getType();
3280 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3281 ArrayRef<int64_t> staticSizes = getStaticSizes();
3282 ArrayRef<int64_t> staticStrides = getStaticStrides();
3285 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3286 return emitError(
"different memory spaces specified for base memref "
3288 << baseType <<
" and subview memref type " << subViewType;
3291 if (!baseType.isStrided())
3292 return emitError(
"base type ") << baseType <<
" is not strided";
3296 MemRefType expectedType = SubViewOp::inferResultType(
3297 baseType, staticOffsets, staticSizes, staticStrides);
3302 expectedType, subViewType);
3307 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3309 *
this, expectedType);
3314 *
this, expectedType);
3324 *
this, expectedType);
3329 *
this, expectedType);
3333 SliceBoundsVerificationResult boundsResult =
3335 staticStrides,
true);
3337 return getOperation()->emitError(boundsResult.
errorMessage);
3343 return os <<
"range " << range.
offset <<
":" << range.
size <<
":"
3352 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3353 assert(ranks[0] == ranks[1] &&
"expected offset and sizes of equal ranks");
3354 assert(ranks[1] == ranks[2] &&
"expected sizes and strides of equal ranks");
3356 unsigned rank = ranks[0];
3358 for (
unsigned idx = 0; idx < rank; ++idx) {
3360 op.isDynamicOffset(idx)
3361 ? op.getDynamicOffset(idx)
3364 op.isDynamicSize(idx)
3365 ? op.getDynamicSize(idx)
3368 op.isDynamicStride(idx)
3369 ? op.getDynamicStride(idx)
3371 res.emplace_back(
Range{offset, size, stride});
3384 MemRefType currentResultType, MemRefType currentSourceType,
3387 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3388 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3390 currentSourceType, currentResultType, mixedSizes);
3391 if (failed(unusedDims))
3394 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3396 unsigned numDimsAfterReduction =
3397 nonRankReducedType.getRank() - unusedDims->count();
3398 shape.reserve(numDimsAfterReduction);
3399 strides.reserve(numDimsAfterReduction);
3400 for (
const auto &[idx, size, stride] :
3401 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3402 nonRankReducedType.getShape(), layout.getStrides())) {
3403 if (unusedDims->test(idx))
3405 shape.push_back(size);
3406 strides.push_back(stride);
3409 return MemRefType::get(
shape, nonRankReducedType.getElementType(),
3410 StridedLayoutAttr::get(sourceType.getContext(),
3411 layout.getOffset(), strides),
3412 nonRankReducedType.getMemorySpace());
3417 auto memrefType = llvm::cast<MemRefType>(
memref.getType());
3418 unsigned rank = memrefType.getRank();
3422 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3423 targetShape, memrefType, offsets, sizes, strides);
3424 return b.createOrFold<memref::SubViewOp>(loc, targetType,
memref, offsets,
3431 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.
getType());
3432 assert(sourceMemrefType &&
"not a ranked memref type");
3433 auto sourceShape = sourceMemrefType.getShape();
3434 if (sourceShape.equals(desiredShape))
3436 auto maybeRankReductionMask =
3438 if (!maybeRankReductionMask)
3448 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3451 auto mixedOffsets = subViewOp.getMixedOffsets();
3452 auto mixedSizes = subViewOp.getMixedSizes();
3453 auto mixedStrides = subViewOp.getMixedStrides();
3458 return !intValue || intValue.value() != 0;
3465 return !intValue || intValue.value() != 1;
3471 for (
const auto &size : llvm::enumerate(mixedSizes)) {
3473 if (!intValue || *intValue != sourceShape[size.index()])
3497class SubViewOpMemRefCastFolder final :
public OpRewritePattern<SubViewOp> {
3499 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3501 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3502 PatternRewriter &rewriter)
const override {
3505 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3506 return matchPattern(operand, matchConstantIndex());
3510 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3514 if (!CastOp::canFoldIntoConsumerOp(castOp))
3522 subViewOp.getType(), subViewOp.getSourceType(),
3523 llvm::cast<MemRefType>(castOp.getSource().getType()),
3524 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3525 subViewOp.getMixedStrides());
3529 Value newSubView = SubViewOp::create(
3530 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3531 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3532 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3533 subViewOp.getStaticStrides());
3542class TrivialSubViewOpFolder final :
public OpRewritePattern<SubViewOp> {
3544 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3546 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3547 PatternRewriter &rewriter)
const override {
3550 if (subViewOp.getSourceType() == subViewOp.getType()) {
3551 rewriter.
replaceOp(subViewOp, subViewOp.getSource());
3555 subViewOp.getSource());
3567 MemRefType resTy = SubViewOp::inferResultType(
3568 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3571 MemRefType nonReducedType = resTy;
3574 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3575 if (droppedDims.none())
3576 return nonReducedType;
3579 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3584 for (
int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3585 if (droppedDims.test(i))
3587 targetStrides.push_back(nonReducedStrides[i]);
3588 targetShape.push_back(nonReducedType.getDimSize(i));
3591 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3592 StridedLayoutAttr::get(nonReducedType.getContext(),
3593 offset, targetStrides),
3594 nonReducedType.getMemorySpace());
3605void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3606 MLIRContext *context) {
3608 .
add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3609 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3610 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3613OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3614 MemRefType sourceMemrefType = getSource().getType();
3615 MemRefType resultMemrefType = getResult().getType();
3617 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3619 if (resultMemrefType == sourceMemrefType &&
3620 resultMemrefType.hasStaticShape() &&
3621 (!resultLayout || resultLayout.hasStaticLayout())) {
3622 return getViewSource();
3628 if (
auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3629 auto srcSizes = srcSubview.getMixedSizes();
3631 auto offsets = getMixedOffsets();
3633 auto strides = getMixedStrides();
3634 bool allStridesOne = llvm::all_of(strides,
isOneInteger);
3635 bool allSizesSame = llvm::equal(sizes, srcSizes);
3636 if (allOffsetsZero && allStridesOne && allSizesSame &&
3637 resultMemrefType == sourceMemrefType)
3638 return getViewSource();
3644FailureOr<std::optional<SmallVector<Value>>>
3645SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3649void SubViewOp::inferStridedMetadataRanges(
3650 ArrayRef<StridedMetadataRange> ranges,
GetIntRangeFn getIntRange,
3652 auto isUninitialized =
3653 +[](IntegerValueRange range) {
return range.isUninitialized(); };
3656 SmallVector<IntegerValueRange> offsetOperands =
3658 if (llvm::any_of(offsetOperands, isUninitialized))
3661 SmallVector<IntegerValueRange> sizeOperands =
3663 if (llvm::any_of(sizeOperands, isUninitialized))
3666 SmallVector<IntegerValueRange> stridesOperands =
3668 if (llvm::any_of(stridesOperands, isUninitialized))
3671 StridedMetadataRange sourceRange =
3672 ranges[getSourceMutable().getOperandNumber()];
3676 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.
getStrides();
3682 ConstantIntRanges offset = sourceRange.
getOffsets()[0];
3683 SmallVector<ConstantIntRanges> strides, sizes;
3685 for (
size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3686 bool dropped = droppedDims.test(i);
3688 ConstantIntRanges off =
3699 sizes.push_back(sizeOperands[i].getValue());
3702 setMetadata(getResult(),
3704 SmallVector<ConstantIntRanges>({std::move(offset)}),
3705 std::move(sizes), std::move(strides)));
3712void TransposeOp::getAsmResultNames(
3714 setNameFn(getResult(),
"transpose");
3720 auto originalSizes = memRefType.getShape();
3721 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3722 assert(originalStrides.size() ==
static_cast<unsigned>(memRefType.getRank()));
3731 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3734void TransposeOp::build(OpBuilder &
b, OperationState &
result, Value in,
3735 AffineMapAttr permutation,
3736 ArrayRef<NamedAttribute> attrs) {
3737 auto permutationMap = permutation.getValue();
3738 assert(permutationMap);
3740 auto memRefType = llvm::cast<MemRefType>(in.
getType());
3744 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3745 build(
b,
result, resultType, in, attrs);
3749void TransposeOp::print(OpAsmPrinter &p) {
3750 p <<
" " << getIn() <<
" " << getPermutation();
3752 p <<
" : " << getIn().getType() <<
" to " <<
getType();
3755ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
3756 OpAsmParser::UnresolvedOperand in;
3757 AffineMap permutation;
3758 MemRefType srcType, dstType;
3767 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3768 AffineMapAttr::get(permutation));
3772LogicalResult TransposeOp::verify() {
3775 if (getPermutation().getNumDims() != getIn().
getType().getRank())
3776 return emitOpError(
"expected a permutation map of same rank as the input");
3778 auto srcType = llvm::cast<MemRefType>(getIn().
getType());
3779 auto resultType = llvm::cast<MemRefType>(
getType());
3781 .canonicalizeStridedLayout();
3783 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3786 <<
" is not equivalent to the canonical transposed input type "
3787 << canonicalResultType;
3791OpFoldResult TransposeOp::fold(FoldAdaptor) {
3794 if (getPermutation().isIdentity() &&
getType() == getIn().
getType())
3798 if (
auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3799 AffineMap composedPermutation =
3800 getPermutation().compose(otherTransposeOp.getPermutation());
3801 getInMutable().assign(otherTransposeOp.getIn());
3802 setPermutation(composedPermutation);
3808FailureOr<std::optional<SmallVector<Value>>>
3809TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3817void ViewOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
3818 setNameFn(getResult(),
"view");
3821LogicalResult ViewOp::verify() {
3822 auto baseType = llvm::cast<MemRefType>(getOperand(0).
getType());
3826 if (!baseType.getLayout().isIdentity())
3827 return emitError(
"unsupported map for base memref type ") << baseType;
3830 if (!viewType.getLayout().isIdentity())
3831 return emitError(
"unsupported map for result memref type ") << viewType;
3834 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3835 return emitError(
"different memory spaces specified for base memref "
3837 << baseType <<
" and view memref type " << viewType;
3846Value ViewOp::getViewSource() {
return getSource(); }
3848OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3849 MemRefType sourceMemrefType = getSource().getType();
3850 MemRefType resultMemrefType = getResult().getType();
3852 if (resultMemrefType == sourceMemrefType &&
3853 resultMemrefType.hasStaticShape() &&
isZeroInteger(getByteShift()))
3854 return getViewSource();
3859SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
3860 SmallVector<OpFoldResult>
result;
3864 if (ShapedType::isDynamic(dim)) {
3865 result.push_back(getSizes()[ctr++]);
3867 result.push_back(
b.getIndexAttr(dim));
3879 SmallVectorImpl<Value> &foldedDynamicSizes) {
3880 SmallVector<int64_t> staticShape(type.getShape());
3881 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
3882 "incorrect number of dynamic sizes");
3886 for (
auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
3887 if (ShapedType::isStatic(dimSize))
3890 Value dynamicSize = dynamicSizes[ctr++];
3893 if (cst.value() < 0) {
3894 foldedDynamicSizes.push_back(dynamicSize);
3897 staticShape[dim] = cst.value();
3899 foldedDynamicSizes.push_back(dynamicSize);
3903 return MemRefType::Builder(type).setShape(staticShape);
3917struct ViewOpShapeFolder :
public OpRewritePattern<ViewOp> {
3920 LogicalResult matchAndRewrite(ViewOp viewOp,
3921 PatternRewriter &rewriter)
const override {
3922 SmallVector<Value> foldedDynamicSizes;
3923 MemRefType resultType = viewOp.getType();
3925 resultType, viewOp.getSizes(), foldedDynamicSizes);
3928 if (foldedMemRefType == resultType)
3932 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
3933 viewOp.getSource(), viewOp.getByteShift(),
3934 foldedDynamicSizes);
3942struct ViewOpMemrefCastFolder :
public OpRewritePattern<ViewOp> {
3945 LogicalResult matchAndRewrite(ViewOp viewOp,
3946 PatternRewriter &rewriter)
const override {
3947 auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
3952 viewOp, viewOp.getType(), memrefCastOp.getSource(),
3953 viewOp.getByteShift(), viewOp.getSizes());
3959void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3960 MLIRContext *context) {
3961 results.
add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3964FailureOr<std::optional<SmallVector<Value>>>
3965ViewOp::bubbleDownCasts(OpBuilder &builder) {
3973LogicalResult AtomicRMWOp::verify() {
3976 "expects the number of subscripts to be equal to memref rank");
3977 switch (getKind()) {
3978 case arith::AtomicRMWKind::addf:
3979 case arith::AtomicRMWKind::maximumf:
3980 case arith::AtomicRMWKind::minimumf:
3981 case arith::AtomicRMWKind::mulf:
3982 if (!llvm::isa<FloatType>(getValue().
getType()))
3984 << arith::stringifyAtomicRMWKind(getKind())
3985 <<
"' expects a floating-point type";
3987 case arith::AtomicRMWKind::addi:
3988 case arith::AtomicRMWKind::maxs:
3989 case arith::AtomicRMWKind::maxu:
3990 case arith::AtomicRMWKind::mins:
3991 case arith::AtomicRMWKind::minu:
3992 case arith::AtomicRMWKind::muli:
3993 case arith::AtomicRMWKind::ori:
3994 case arith::AtomicRMWKind::xori:
3995 case arith::AtomicRMWKind::andi:
3996 if (!llvm::isa<IntegerType>(getValue().
getType()))
3998 << arith::stringifyAtomicRMWKind(getKind())
3999 <<
"' expects an integer type";
4007OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
4011 return OpFoldResult();
4014FailureOr<std::optional<SmallVector<Value>>>
4015AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
4024#define GET_OP_CLASSES
4025#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByStrides(MemRefType originalType, MemRefType reducedType, ArrayRef< int64_t > originalStrides, ArrayRef< int64_t > candidateStrides, llvm::SmallBitVector unusedDims)
Returns the set of source dimensions that are dropped in a rank reduction.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByPosition(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Returns the set of source dimensions that are dropped in a rank reduction.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with sizes == targetShape.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, and stride) entries, each containing either the dynamic value or a ConstantIndexOp constructed at the given location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.