25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
37 return arith::ConstantOp::materialize(builder, value, type, loc);
50 auto cast = operand.get().getDefiningOp<CastOp>();
51 if (cast && operand.get() != inner &&
52 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
53 operand.set(cast.getOperand());
63 if (
auto memref = llvm::dyn_cast<MemRefType>(type))
64 return RankedTensorType::get(
memref.getShape(),
memref.getElementType());
65 if (
auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
66 return UnrankedTensorType::get(
memref.getElementType());
72 auto memrefType = llvm::cast<MemRefType>(value.
getType());
73 if (memrefType.isDynamicDim(dim))
74 return builder.
createOrFold<memref::DimOp>(loc, value, dim);
81 auto memrefType = llvm::cast<MemRefType>(value.
getType());
83 for (
int64_t i = 0; i < memrefType.getRank(); ++i)
100 assert(constValues.size() == values.size() &&
101 "incorrect number of const values");
102 for (
auto [i, cstVal] : llvm::enumerate(constValues)) {
104 if (ShapedType::isStatic(cstVal)) {
118static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
120 MemorySpaceCastOpInterface castOp =
121 MemorySpaceCastOpInterface::getIfPromotableCast(src);
129 FailureOr<PtrLikeTypeInterface> srcTy = resultTy.
clonePtrWith(
130 castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
134 FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.
clonePtrWith(
135 castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
140 if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
143 return std::make_tuple(castOp, *tgtTy, *srcTy);
148template <
typename ConcreteOpTy>
149static FailureOr<std::optional<SmallVector<Value>>>
159 llvm::append_range(operands, op->getOperands());
163 auto newOp = ConcreteOpTy::create(
164 builder, op.getLoc(),
TypeRange(resTy), operands, op.getProperties(),
165 llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
168 MemorySpaceCastOpInterface
result = castOp.cloneMemorySpaceCastOp(
171 return std::optional<SmallVector<Value>>(
179void AllocOp::getAsmResultNames(
181 setNameFn(getResult(),
"alloc");
184void AllocaOp::getAsmResultNames(
186 setNameFn(getResult(),
"alloca");
189template <
typename AllocLikeOp>
191 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
192 "applies to only alloc or alloca");
193 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
195 return op.emitOpError(
"result must be a memref");
200 unsigned numSymbols = 0;
201 if (!memRefType.getLayout().isIdentity())
202 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
203 if (op.getSymbolOperands().size() != numSymbols)
204 return op.emitOpError(
"symbol operand count does not equal memref symbol "
206 << numSymbols <<
", got " << op.getSymbolOperands().size();
213LogicalResult AllocaOp::verify() {
217 "requires an ancestor op with AutomaticAllocationScope trait");
224template <
typename AllocLikeOp>
226 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
228 LogicalResult matchAndRewrite(AllocLikeOp alloc,
229 PatternRewriter &rewriter)
const override {
232 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
234 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
236 return constSizeArg.isNonNegative();
240 auto memrefType = alloc.getType();
244 SmallVector<int64_t, 4> newShapeConstants;
245 newShapeConstants.reserve(memrefType.getRank());
246 SmallVector<Value, 4> dynamicSizes;
248 unsigned dynamicDimPos = 0;
249 for (
unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
250 int64_t dimSize = memrefType.getDimSize(dim);
252 if (ShapedType::isStatic(dimSize)) {
253 newShapeConstants.push_back(dimSize);
256 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
259 constSizeArg.isNonNegative()) {
261 newShapeConstants.push_back(constSizeArg.getZExtValue());
264 newShapeConstants.push_back(ShapedType::kDynamic);
265 dynamicSizes.push_back(dynamicSize);
271 MemRefType newMemRefType =
272 MemRefType::Builder(memrefType).setShape(newShapeConstants);
273 assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
276 auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
277 dynamicSizes, alloc.getSymbolOperands(),
278 alloc.getAlignmentAttr());
288 using OpRewritePattern<T>::OpRewritePattern;
290 LogicalResult matchAndRewrite(T alloc,
291 PatternRewriter &rewriter)
const override {
292 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
293 if (auto storeOp = dyn_cast<StoreOp>(op))
294 return storeOp.getValue() == alloc;
295 return !isa<DeallocOp>(op);
299 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
310 results.
add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
315 results.
add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
323LogicalResult ReallocOp::verify() {
324 auto sourceType = llvm::cast<MemRefType>(getOperand(0).
getType());
325 MemRefType resultType =
getType();
328 if (!sourceType.getLayout().isIdentity())
329 return emitError(
"unsupported layout for source memref type ")
333 if (!resultType.getLayout().isIdentity())
334 return emitError(
"unsupported layout for result memref type ")
338 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
339 return emitError(
"different memory spaces specified for source memref "
341 << sourceType <<
" and result memref type " << resultType;
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError(
"missing dimension operand for result type ")
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError(
"unnecessary dimension operand for result type ")
361 results.
add<SimplifyDeadAlloc<ReallocOp>>(context);
369 bool printBlockTerminators =
false;
372 if (!getResults().empty()) {
373 p <<
" -> (" << getResultTypes() <<
")";
374 printBlockTerminators =
true;
379 printBlockTerminators);
385 result.regions.reserve(1);
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.
getBuilder(),
405void AllocaScopeOp::getSuccessorRegions(
422 MemoryEffectOpInterface
interface = dyn_cast<MemoryEffectOpInterface>(op);
427 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
428 if (isa<SideEffects::AutomaticAllocationScopeResource>(
429 effect->getResource()))
445 MemoryEffectOpInterface
interface = dyn_cast<MemoryEffectOpInterface>(op);
450 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
451 if (isa<SideEffects::AutomaticAllocationScopeResource>(
452 effect->getResource()))
476 bool hasPotentialAlloca =
489 if (hasPotentialAlloca) {
522 if (!lastParentWithoutScope ||
535 lastParentWithoutScope = lastParentWithoutScope->
getParentOp();
536 if (!lastParentWithoutScope ||
543 Region *containingRegion =
nullptr;
544 for (
auto &r : lastParentWithoutScope->
getRegions()) {
545 if (r.isAncestor(op->getParentRegion())) {
546 assert(containingRegion ==
nullptr &&
547 "only one region can contain the op");
548 containingRegion = &r;
551 assert(containingRegion &&
"op must be contained in a region");
561 return containingRegion->isAncestor(v.getParentRegion());
564 toHoist.push_back(alloc);
571 for (
auto *op : toHoist) {
572 auto *cloned = rewriter.
clone(*op);
573 rewriter.
replaceOp(op, cloned->getResults());
588LogicalResult AssumeAlignmentOp::verify() {
589 if (!llvm::isPowerOf2_32(getAlignment()))
590 return emitOpError(
"alignment must be power of 2");
594void AssumeAlignmentOp::getAsmResultNames(
596 setNameFn(getResult(),
"assume_align");
599OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
600 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
603 if (source.getAlignment() != getAlignment())
608FailureOr<std::optional<SmallVector<Value>>>
609AssumeAlignmentOp::bubbleDownCasts(
OpBuilder &builder) {
613FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(
OpBuilder &builder,
616 assert(resultIndex == 0 &&
"AssumeAlignmentOp has a single result");
617 return getMixedSize(builder, getLoc(), getMemref(), dim);
624LogicalResult DistinctObjectsOp::verify() {
625 if (getOperandTypes() != getResultTypes())
626 return emitOpError(
"operand types and result types must match");
628 if (getOperandTypes().empty())
629 return emitOpError(
"expected at least one operand");
634LogicalResult DistinctObjectsOp::inferReturnTypes(
639 llvm::copy(operands.
getTypes(), std::back_inserter(inferredReturnTypes));
648 setNameFn(getResult(),
"cast");
688bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
689 MemRefType sourceType =
690 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
691 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
694 if (!sourceType || !resultType)
698 if (sourceType.getElementType() != resultType.getElementType())
702 if (sourceType.getRank() != resultType.getRank())
706 int64_t sourceOffset, resultOffset;
708 if (
failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
709 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
713 for (
auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
714 auto ss = std::get<0>(it), st = std::get<1>(it);
716 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
721 if (sourceOffset != resultOffset)
722 if (ShapedType::isDynamic(sourceOffset) &&
723 ShapedType::isStatic(resultOffset))
727 for (
auto it : llvm::zip(sourceStrides, resultStrides)) {
728 auto ss = std::get<0>(it), st = std::get<1>(it);
730 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
738 if (inputs.size() != 1 || outputs.size() != 1)
740 Type a = inputs.front(),
b = outputs.front();
741 auto aT = llvm::dyn_cast<MemRefType>(a);
742 auto bT = llvm::dyn_cast<MemRefType>(
b);
744 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
745 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(
b);
748 if (aT.getElementType() != bT.getElementType())
750 if (aT.getLayout() != bT.getLayout()) {
753 if (
failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
754 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
755 aStrides.size() != bStrides.size())
764 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(
b) || a ==
b);
766 if (!checkCompatible(aOffset, bOffset))
769 if (aT.getDimSize(
index) == 1 || bT.getDimSize(
index) == 1)
771 if (!checkCompatible(aStride, bStrides[
index]))
775 if (aT.getMemorySpace() != bT.getMemorySpace())
779 if (aT.getRank() != bT.getRank())
782 for (
unsigned i = 0, e = aT.getRank(); i != e; ++i) {
783 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
784 if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
798 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
799 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
800 if (aEltType != bEltType)
803 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
804 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
805 return aMemSpace == bMemSpace;
815FailureOr<std::optional<SmallVector<Value>>>
816CastOp::bubbleDownCasts(
OpBuilder &builder) {
828 using OpRewritePattern<CopyOp>::OpRewritePattern;
830 LogicalResult matchAndRewrite(CopyOp copyOp,
831 PatternRewriter &rewriter)
const override {
832 if (copyOp.getSource() != copyOp.getTarget())
841 using OpRewritePattern<CopyOp>::OpRewritePattern;
843 static bool isEmptyMemRef(BaseMemRefType type) {
847 LogicalResult matchAndRewrite(CopyOp copyOp,
848 PatternRewriter &rewriter)
const override {
849 if (isEmptyMemRef(copyOp.getSource().getType()) ||
850 isEmptyMemRef(copyOp.getTarget().getType())) {
862 results.
add<FoldEmptyCopy, FoldSelfCopy>(context);
869 for (
OpOperand &operand : op->getOpOperands()) {
871 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
872 operand.set(castOp.getOperand());
879LogicalResult CopyOp::fold(FoldAdaptor adaptor,
880 SmallVectorImpl<OpFoldResult> &results) {
890LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
891 SmallVectorImpl<OpFoldResult> &results) {
900void DimOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
901 setNameFn(getResult(),
"dim");
904void DimOp::build(OpBuilder &builder, OperationState &
result, Value source,
906 auto loc =
result.location;
908 build(builder,
result, source, indexValue);
911std::optional<int64_t> DimOp::getConstantIndex() {
920 auto rankedSourceType = dyn_cast<MemRefType>(getSource().
getType());
921 if (!rankedSourceType)
924 if (rankedSourceType.getRank() <= constantIndex)
930void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
932 setResultRange(getResult(),
941 std::map<int64_t, unsigned> numOccurences;
942 for (
auto val : vals)
943 numOccurences[val]++;
944 return numOccurences;
954static FailureOr<llvm::SmallBitVector>
956 MemRefType reducedType,
958 int64_t rankReduction = originalType.getRank() - reducedType.getRank();
959 if (rankReduction <= 0)
960 return llvm::SmallBitVector(originalType.getRank());
964 for (
const auto &it : llvm::enumerate(sizes)) {
966 sourceSizes[it.index()] = *cst;
968 sourceSizes[it.index()] = ShapedType::kDynamic;
972 llvm::SmallBitVector usedSourceDims(originalType.getRank());
974 for (
int64_t resultSize : resultSizes) {
975 bool matched =
false;
976 for (
int64_t j = startJ;
j < originalType.getRank(); ++
j) {
977 if (sourceSizes[
j] == resultSize) {
978 usedSourceDims.set(
j);
988 llvm::SmallBitVector unusedDims(originalType.getRank());
989 for (
int64_t i = 0; i < originalType.getRank(); ++i)
990 if (!usedSourceDims.test(i))
1003 MemRefType originalType, MemRefType reducedType,
1005 llvm::SmallBitVector unusedDims) {
1013 std::map<int64_t, unsigned> currUnaccountedStrides =
1015 std::map<int64_t, unsigned> candidateStridesNumOccurences =
1017 for (
size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
1018 if (!unusedDims.test(dim))
1020 int64_t originalStride = originalStrides[dim];
1021 if (currUnaccountedStrides[originalStride] >
1022 candidateStridesNumOccurences[originalStride]) {
1024 currUnaccountedStrides[originalStride]--;
1027 if (currUnaccountedStrides[originalStride] ==
1028 candidateStridesNumOccurences[originalStride]) {
1030 unusedDims.reset(dim);
1033 if (currUnaccountedStrides[originalStride] <
1034 candidateStridesNumOccurences[originalStride]) {
1040 if (
static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
1041 originalType.getRank())
1053static FailureOr<llvm::SmallBitVector>
1056 llvm::SmallBitVector unusedDims(originalType.getRank());
1057 if (originalType.getRank() == reducedType.getRank())
1060 for (
const auto &dim : llvm::enumerate(sizes))
1061 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
1062 if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
1063 unusedDims.set(dim.index());
1067 if (
static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
1068 originalType.getRank())
1072 int64_t originalOffset, candidateOffset;
1074 originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
1076 reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
1084 if (strides.size() <= 1)
1086 return llvm::any_of(strides.drop_back(),
1087 [](
int64_t s) { return !ShapedType::isDynamic(s); });
1089 if (hasNonTrivialStaticStride(originalStrides) ||
1090 hasNonTrivialStaticStride(candidateStrides)) {
1091 FailureOr<llvm::SmallBitVector> strideBased =
1094 candidateStrides, unusedDims);
1095 if (succeeded(strideBased))
1096 return *strideBased;
1102llvm::SmallBitVector SubViewOp::getDroppedDims() {
1103 MemRefType sourceType = getSourceType();
1104 MemRefType resultType =
getType();
1105 FailureOr<llvm::SmallBitVector> unusedDims =
1107 assert(succeeded(unusedDims) &&
"unable to find unused dims of subview");
1111OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1113 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1118 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().
getType());
1124 int64_t indexVal = index.getInt();
1125 if (indexVal < 0 || indexVal >= memrefType.getRank())
1129 if (!memrefType.isDynamicDim(index.getInt())) {
1131 return builder.
getIndexAttr(memrefType.getShape()[index.getInt()]);
1135 unsigned unsignedIndex = index.getValue().getZExtValue();
1138 Operation *definingOp = getSource().getDefiningOp();
1140 if (
auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1141 return *(alloc.getDynamicSizes().begin() +
1142 memrefType.getDynamicDimIndex(unsignedIndex));
1144 if (
auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1145 return *(alloca.getDynamicSizes().begin() +
1146 memrefType.getDynamicDimIndex(unsignedIndex));
1148 if (
auto view = dyn_cast_or_null<ViewOp>(definingOp))
1149 return *(view.getDynamicSizes().begin() +
1150 memrefType.getDynamicDimIndex(unsignedIndex));
1152 if (
auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1157 unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex);
1158 unsigned dynamicIdx = 0;
1159 for (OpFoldResult size : subview.getMixedSizes()) {
1160 if (llvm::isa<Attribute>(size))
1162 if (dynamicIdx == dynamicResultDimIdx)
1179struct DimOfMemRefReshape :
public OpRewritePattern<DimOp> {
1180 using OpRewritePattern<DimOp>::OpRewritePattern;
1182 LogicalResult matchAndRewrite(DimOp dim,
1183 PatternRewriter &rewriter)
const override {
1184 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1188 dim,
"Dim op is not defined by a reshape op.");
1199 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1200 if (
auto *definingOp = dim.getIndex().getDefiningOp()) {
1201 if (reshape->isBeforeInBlock(definingOp)) {
1204 "dim.getIndex is not defined before reshape in the same block.");
1209 else if (dim->getBlock() != reshape->getBlock() &&
1210 !dim.getIndex().getParentRegion()->isProperAncestor(
1211 reshape->getParentRegion())) {
1216 dim,
"dim.getIndex does not dominate reshape.");
1222 Location loc = dim.getLoc();
1224 LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1225 if (
load.getType() != dim.getType())
1226 load = arith::IndexCastOp::create(rewriter, loc, dim.getType(),
load);
1234void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1235 MLIRContext *context) {
1236 results.
add<DimOfMemRefReshape>(context);
1243void DmaStartOp::build(OpBuilder &builder, OperationState &
result,
1244 Value srcMemRef,
ValueRange srcIndices, Value destMemRef,
1246 Value tagMemRef,
ValueRange tagIndices, Value stride,
1247 Value elementsPerStride) {
1248 result.addOperands(srcMemRef);
1249 result.addOperands(srcIndices);
1250 result.addOperands(destMemRef);
1251 result.addOperands(destIndices);
1252 result.addOperands({numElements, tagMemRef});
1253 result.addOperands(tagIndices);
1255 result.addOperands({stride, elementsPerStride});
1258void DmaStartOp::print(OpAsmPrinter &p) {
1259 p <<
" " << getSrcMemRef() <<
'[' << getSrcIndices() <<
"], "
1260 << getDstMemRef() <<
'[' << getDstIndices() <<
"], " <<
getNumElements()
1261 <<
", " << getTagMemRef() <<
'[' << getTagIndices() <<
']';
1263 p <<
", " << getStride() <<
", " << getNumElementsPerStride();
1266 p <<
" : " << getSrcMemRef().getType() <<
", " << getDstMemRef().getType()
1267 <<
", " << getTagMemRef().getType();
1278ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &
result) {
1279 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1280 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1281 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1282 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1283 OpAsmParser::UnresolvedOperand numElementsInfo;
1284 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1285 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1286 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1288 SmallVector<Type, 3> types;
1308 bool isStrided = strideInfo.size() == 2;
1309 if (!strideInfo.empty() && !isStrided) {
1311 "expected two stride related operands");
1316 if (types.size() != 3)
1338LogicalResult DmaStartOp::verify() {
1339 unsigned numOperands = getNumOperands();
1343 if (numOperands < 4)
1344 return emitOpError(
"expected at least 4 operands");
1349 if (!llvm::isa<MemRefType>(getSrcMemRef().
getType()))
1350 return emitOpError(
"expected source to be of memref type");
1351 if (numOperands < getSrcMemRefRank() + 4)
1352 return emitOpError() <<
"expected at least " << getSrcMemRefRank() + 4
1354 if (!getSrcIndices().empty() &&
1355 !llvm::all_of(getSrcIndices().getTypes(),
1356 [](Type t) {
return t.
isIndex(); }))
1357 return emitOpError(
"expected source indices to be of index type");
1360 if (!llvm::isa<MemRefType>(getDstMemRef().
getType()))
1361 return emitOpError(
"expected destination to be of memref type");
1362 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1363 if (numOperands < numExpectedOperands)
1364 return emitOpError() <<
"expected at least " << numExpectedOperands
1366 if (!getDstIndices().empty() &&
1367 !llvm::all_of(getDstIndices().getTypes(),
1368 [](Type t) {
return t.
isIndex(); }))
1369 return emitOpError(
"expected destination indices to be of index type");
1373 return emitOpError(
"expected num elements to be of index type");
1376 if (!llvm::isa<MemRefType>(getTagMemRef().
getType()))
1377 return emitOpError(
"expected tag to be of memref type");
1378 numExpectedOperands += getTagMemRefRank();
1379 if (numOperands < numExpectedOperands)
1380 return emitOpError() <<
"expected at least " << numExpectedOperands
1382 if (!getTagIndices().empty() &&
1383 !llvm::all_of(getTagIndices().getTypes(),
1384 [](Type t) {
return t.
isIndex(); }))
1385 return emitOpError(
"expected tag indices to be of index type");
1389 if (numOperands != numExpectedOperands &&
1390 numOperands != numExpectedOperands + 2)
1391 return emitOpError(
"incorrect number of operands");
1395 if (!getStride().
getType().isIndex() ||
1396 !getNumElementsPerStride().
getType().isIndex())
1398 "expected stride and num elements per stride to be of type index");
1404LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1405 SmallVectorImpl<OpFoldResult> &results) {
1410void DmaStartOp::setMemrefsAndIndices(RewriterBase &rewriter, Value newSrc,
1414 SmallVector<Value> newOperands;
1415 newOperands.push_back(newSrc);
1416 llvm::append_range(newOperands, newSrcIndices);
1417 newOperands.push_back(newDst);
1418 llvm::append_range(newOperands, newDstIndices);
1420 newOperands.push_back(getTagMemRef());
1421 llvm::append_range(newOperands, getTagIndices());
1423 newOperands.push_back(getStride());
1424 newOperands.push_back(getNumElementsPerStride());
1427 rewriter.
modifyOpInPlace(*
this, [&]() { (*this)->setOperands(newOperands); });
1434LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1435 SmallVectorImpl<OpFoldResult> &results) {
1440LogicalResult DmaWaitOp::verify() {
1442 unsigned numTagIndices = getTagIndices().size();
1443 unsigned tagMemRefRank = getTagMemRefRank();
1444 if (numTagIndices != tagMemRefRank)
1445 return emitOpError() <<
"expected tagIndices to have the same number of "
1446 "elements as the tagMemRef rank, expected "
1447 << tagMemRefRank <<
", but got " << numTagIndices;
1455void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1457 setNameFn(getResult(),
"intptr");
1466LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1467 MLIRContext *context, std::optional<Location> location,
1468 ExtractStridedMetadataOp::Adaptor adaptor,
1469 SmallVectorImpl<Type> &inferredReturnTypes) {
1470 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1474 unsigned sourceRank = sourceType.getRank();
1475 IndexType indexType = IndexType::get(context);
1477 MemRefType::get({}, sourceType.getElementType(),
1478 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1480 inferredReturnTypes.push_back(memrefType);
1482 inferredReturnTypes.push_back(indexType);
1484 for (
unsigned i = 0; i < sourceRank * 2; ++i)
1485 inferredReturnTypes.push_back(indexType);
1489void ExtractStridedMetadataOp::getAsmResultNames(
1491 setNameFn(getBaseBuffer(),
"base_buffer");
1492 setNameFn(getOffset(),
"offset");
1495 if (!getSizes().empty()) {
1496 setNameFn(getSizes().front(),
"sizes");
1497 setNameFn(getStrides().front(),
"strides");
1504template <
typename Container>
1508 assert(values.size() == maybeConstants.size() &&
1509 " expected values and maybeConstants of the same size");
1510 bool atLeastOneReplacement =
false;
1511 for (
auto [maybeConstant,
result] : llvm::zip(maybeConstants, values)) {
1516 assert(isa<Attribute>(maybeConstant) &&
1517 "The constified value should be either unchanged (i.e., == result) "
1521 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1526 atLeastOneReplacement =
true;
1529 return atLeastOneReplacement;
1533ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1534 SmallVectorImpl<OpFoldResult> &results) {
1535 OpBuilder builder(*
this);
1539 getConstifiedMixedOffset());
1541 getConstifiedMixedSizes());
1543 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1546 if (
auto prev = getSource().getDefiningOp<CastOp>())
1547 if (isa<MemRefType>(prev.getSource().getType())) {
1548 getSourceMutable().assign(prev.getSource());
1549 atLeastOneReplacement =
true;
1552 return success(atLeastOneReplacement);
1555SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1561SmallVector<OpFoldResult>
1562ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1564 SmallVector<int64_t> staticValues;
1566 LogicalResult status =
1567 getSource().getType().getStridesAndOffset(staticValues, unused);
1569 assert(succeeded(status) &&
"could not get strides from type");
1574OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1576 SmallVector<OpFoldResult> values(1, offsetOfr);
1577 SmallVector<int64_t> staticValues, unused;
1579 LogicalResult status =
1580 getSource().getType().getStridesAndOffset(unused, offset);
1582 assert(succeeded(status) &&
"could not get offset from type");
1583 staticValues.push_back(offset);
1592void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &
result,
1594 OpBuilder::InsertionGuard g(builder);
1595 result.addOperands(memref);
1598 if (
auto memrefType = llvm::dyn_cast<MemRefType>(memref.
getType())) {
1599 Type elementType = memrefType.getElementType();
1600 result.addTypes(elementType);
1602 Region *bodyRegion =
result.addRegion();
1608LogicalResult GenericAtomicRMWOp::verify() {
1609 auto &body = getRegion();
1610 if (body.getNumArguments() != 1)
1611 return emitOpError(
"expected single number of entry block arguments");
1613 if (getResult().
getType() != body.getArgument(0).getType())
1614 return emitOpError(
"expected block argument of the same type result type");
1617 body.walk([&](Operation *nestedOp) {
1621 "body of 'memref.generic_atomic_rmw' should contain "
1622 "only operations with no side effects");
1629ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1630 OperationState &
result) {
1631 OpAsmParser::UnresolvedOperand memref;
1633 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1643 Region *body =
result.addRegion();
1651void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1652 p <<
' ' << getMemref() <<
"[" <<
getIndices()
1653 <<
"] : " << getMemref().
getType() <<
' ';
1662std::optional<SmallVector<Value>> GenericAtomicRMWOp::updateMemrefAndIndices(
1663 RewriterBase &rewriter, Value newMemref,
ValueRange newIndices) {
1665 getMemrefMutable().assign(newMemref);
1666 getIndicesMutable().assign(newIndices);
1668 return std::nullopt;
1675LogicalResult AtomicYieldOp::verify() {
1676 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1677 Type resultType = getResult().getType();
1678 if (parentType != resultType)
1679 return emitOpError() <<
"types mismatch between yield op: " << resultType
1680 <<
" and its parent: " << parentType;
1692 if (!op.isExternal()) {
1694 if (op.isUninitialized())
1695 p <<
"uninitialized";
1708 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1709 if (!memrefType || !memrefType.hasStaticShape())
1711 <<
"type should be static shaped memref, but got " << type;
1712 typeAttr = TypeAttr::get(type);
1718 initialValue = UnitAttr::get(parser.
getContext());
1725 if (!llvm::isa<ElementsAttr>(initialValue))
1727 <<
"initial value should be a unit or elements attribute";
1731LogicalResult GlobalOp::verify() {
1732 auto memrefType = llvm::dyn_cast<MemRefType>(
getType());
1733 if (!memrefType || !memrefType.hasStaticShape())
1734 return emitOpError(
"type should be static shaped memref, but got ")
1739 if (getInitialValue().has_value()) {
1740 Attribute initValue = getInitialValue().value();
1741 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1742 return emitOpError(
"initial value should be a unit or elements "
1743 "attribute, but got ")
1748 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1750 auto initElementType =
1751 cast<TensorType>(elementsAttr.getType()).getElementType();
1752 auto memrefElementType = memrefType.getElementType();
1754 if (initElementType != memrefElementType)
1755 return emitOpError(
"initial value element expected to be of type ")
1756 << memrefElementType <<
", but was of type " << initElementType;
1761 auto initShape = elementsAttr.getShapedType().getShape();
1762 auto memrefShape = memrefType.getShape();
1763 if (initShape != memrefShape)
1764 return emitOpError(
"initial value shape expected to be ")
1765 << memrefShape <<
" but was " << initShape;
1773ElementsAttr GlobalOp::getConstantInitValue() {
1774 auto initVal = getInitialValue();
1775 if (getConstant() && initVal.has_value())
1776 return llvm::cast<ElementsAttr>(initVal.value());
1785GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1792 << getName() <<
"' does not reference a valid global memref";
1794 Type resultType = getResult().getType();
1795 if (global.getType() != resultType)
1797 << resultType <<
" does not match type " << global.getType()
1798 <<
" of the global memref @" << getName();
1806OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1812 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
1818 getGlobalOp, getGlobalOp.getNameAttr());
1823 dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
1827 return splatAttr.getSplatValue<Attribute>();
1832std::optional<SmallVector<Value>>
1833LoadOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
1836 getMemrefMutable().assign(newMemref);
1837 getIndicesMutable().assign(newIndices);
1839 return std::nullopt;
1842FailureOr<std::optional<SmallVector<Value>>>
1843LoadOp::bubbleDownCasts(OpBuilder &builder) {
// Suggest "%memspacecast" as the printed SSA name for the result.
1852void MemorySpaceCastOp::getAsmResultNames(
1854  setNameFn(getResult(),
"memspacecast");
// areCastCompatible (fragment): exactly one input and one output; either
// both ranked memrefs with identical element type, layout and shape (only
// the memory space may differ), or both unranked with identical element
// type. NOTE(review): the early returns between the checks are missing
// from this extract.
1858  if (inputs.size() != 1 || outputs.size() != 1)
1860  Type a = inputs.front(),
b = outputs.front();
1861  auto aT = llvm::dyn_cast<MemRefType>(a);
1862  auto bT = llvm::dyn_cast<MemRefType>(
b);
1864  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1865  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(
b);
1868    if (aT.getElementType() != bT.getElementType())
1870    if (aT.getLayout() != bT.getLayout())
1872    if (aT.getShape() != bT.getShape())
1877    return uaT.getElementType() == ubT.getElementType();
// fold: memspacecast(memspacecast(x)) -> memspacecast(x) — collapse a chain
// by repointing this op's source at the parent's source.
1882OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1885  if (
auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1886    getSourceMutable().assign(parentCast.getSource());
// A cast is valid when the target is a memref type and differs from the
// source only in memory space (cloning src's space onto tgt gives src back).
1900bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1901                                               PtrLikeTypeInterface src) {
1902  return isa<BaseMemRefType>(tgt) &&
1903         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
// Create a new memspacecast with the same source but a new target type;
// asserts the requested cast is actually valid.
1906MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1907    OpBuilder &
b, PtrLikeTypeInterface tgt,
1909  assert(isValidMemorySpaceCast(tgt, src.getType()) &&
"invalid arguments");
1910  return MemorySpaceCastOp::create(
b, getLoc(), tgt, src);
// Promotable only when the destination is in the default (null) memory
// space.
1914bool MemorySpaceCastOp::isSourcePromotable() {
1915  return getDest().getType().getMemorySpace() ==
nullptr;
// Custom printer: "memref.prefetch %m[...], read|write, locality<N>,
// data|instr" followed by the remaining attributes (the three components
// printed inline are elided from the attribute dict).
// NOTE(review): the index-list printing and the printOptionalAttrDict call
// are partially missing from this extract.
1922void PrefetchOp::print(OpAsmPrinter &p) {
1923  p <<
" " << getMemref() <<
'[';
1925  p <<
']' <<
", " << (getIsWrite() ?
"write" :
"read");
1926  p <<
", locality<" << getLocalityHint();
1927  p <<
">, " << (getIsDataCache() ?
"data" :
"instr");
1929                          (*this)->getAttrs(),
1930                          {
"localityHint",
"isWrite",
"isDataCache"});
// Custom parser: mirror of print(). Validates the rw specifier and cache
// type keywords and converts them into the isWrite / isDataCache boolean
// attributes. NOTE(review): the operand/keyword parsing calls between the
// declarations and the checks are missing from this extract.
1934ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &
result) {
1935  OpAsmParser::UnresolvedOperand memrefInfo;
1936  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1937  IntegerAttr localityHint;
1939  StringRef readOrWrite, cacheType;
// Only the exact keywords "read" / "write" are accepted.
1956  if (readOrWrite !=
"read" && readOrWrite !=
"write")
1958                            "rw specifier has to be 'read' or 'write'");
1959  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
// Only the exact keywords "data" / "instr" are accepted.
1962  if (cacheType !=
"data" && cacheType !=
"instr")
1964                            "cache type has to be 'data' or 'instr'");
1966  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
// verify() / fold(): bodies not visible in this extract.
1972LogicalResult PrefetchOp::verify() {
1979LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1980                               SmallVectorImpl<OpFoldResult> &results) {
// Cast-bubbling helper: repoint memref and indices in place, no new results.
1987std::optional<SmallVector<Value>>
1988PrefetchOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
1991  getMemrefMutable().assign(newMemref);
1992  getIndicesMutable().assign(newIndices);
1994  return std::nullopt;
// RankOp::fold: if the operand type is a ranked ShapedType the rank is a
// compile-time constant — fold to an index-typed IntegerAttr holding it.
// Returns a null attribute (no fold) for unranked/non-shaped operands.
2001OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
2003  auto type = getOperand().getType();
2004  auto shapedType = llvm::dyn_cast<ShapedType>(type);
2005  if (shapedType && shapedType.hasRank())
2006    return IntegerAttr::get(IndexType::get(
getContext()), shapedType.getRank());
// Null attribute signals "could not fold".
2007  return IntegerAttr();
// Suggest "%reinterpret_cast" as the printed SSA name for the result.
2014void ReinterpretCastOp::getAsmResultNames(
2016  setNameFn(getResult(),
"reinterpret_cast");
// build #1: OpFoldResult offsets/sizes/strides with an explicit result
// type. Splits each mixed list into static int64 components and dynamic
// Value components, then forwards to the generated builder.
// NOTE(review): the dispatchIndexOpFoldResult(s) calls that perform the
// split are missing from this extract.
2022void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
2023                              MemRefType resultType, Value source,
2024                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
2025                              ArrayRef<OpFoldResult> strides,
2026                              ArrayRef<NamedAttribute> attrs) {
2027  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2028  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2032  result.addAttributes(attrs);
2033  build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
2034        dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
2035        b.getDenseI64ArrayAttr(staticSizes),
2036        b.getDenseI64ArrayAttr(staticStrides));
// build #2: result type inferred — constructs a strided MemRefType from
// the (split) static offset/sizes/strides, keeping the source's element
// type and memory space, then delegates to build #1.
2039void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
2040                              Value source, OpFoldResult offset,
2041                              ArrayRef<OpFoldResult> sizes,
2042                              ArrayRef<OpFoldResult> strides,
2043                              ArrayRef<NamedAttribute> attrs) {
2044  auto sourceType = cast<BaseMemRefType>(source.
getType());
2045  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2046  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2050  auto stridedLayout = StridedLayoutAttr::get(
2051      b.getContext(), staticOffsets.front(), staticStrides);
2052  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
2053                                    stridedLayout, sourceType.getMemorySpace());
2054  build(
b,
result, resultType, source, offset, sizes, strides, attrs);
// build #3: all-static int64 offset/sizes/strides — wrap each scalar in an
// I64IntegerAttr OpFoldResult and delegate to build #1.
2057void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
2058                              MemRefType resultType, Value source,
2059                              int64_t offset, ArrayRef<int64_t> sizes,
2060                              ArrayRef<int64_t> strides,
2061                              ArrayRef<NamedAttribute> attrs) {
2062  SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
2063      sizes, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v); });
2064  SmallVector<OpFoldResult> strideValues =
2065      llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
2066        return b.getI64IntegerAttr(v);
2068  build(
b,
result, resultType, source,
b.getI64IntegerAttr(offset), sizeValues,
2069        strideValues, attrs);
// build #4: all-dynamic Value offset/sizes/strides — wrap each Value in an
// OpFoldResult and delegate to build #1.
2072void ReinterpretCastOp::build(OpBuilder &
b, OperationState &
result,
2073                              MemRefType resultType, Value source, Value offset,
2075                              ArrayRef<NamedAttribute> attrs) {
2076  SmallVector<OpFoldResult> sizeValues =
2077      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
2078  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2079      strides, [](Value v) -> OpFoldResult {
return v; });
2080  build(
b,
result, resultType, source, offset, sizeValues, strideValues, attrs);
// ReinterpretCastOp::verify: the result memref must (1) share the source's
// memory space, (2) agree with the declared static sizes per dimension,
// (3) have a strided layout, and (4) agree with the declared static offset
// and strides. Static-vs-static mismatches are errors; dynamic entries on
// either side are accepted.
2085LogicalResult ReinterpretCastOp::verify() {
2087  auto srcType = llvm::cast<BaseMemRefType>(getSource().
getType());
2088  auto resultType = llvm::cast<MemRefType>(
getType());
2089  if (srcType.getMemorySpace() != resultType.getMemorySpace())
2090    return emitError(
"different memory spaces specified for source type ")
2091           << srcType <<
" and result memref type " << resultType;
// Per-dim: a static result size must equal the corresponding static size
// operand/attribute.
2097  for (
auto [idx, resultSize, expectedSize] :
2098       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
2099    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
2100      return emitError(
"expected result type with size = ")
2101             << (ShapedType::isDynamic(expectedSize)
2102                     ? std::string(
"dynamic")
2103                     : std::to_string(expectedSize))
2104             <<
" instead of " << resultSize <<
" in dim = " << idx;
// The result type must expose a strided layout (offset + strides).
2110  int64_t resultOffset;
2111  SmallVector<int64_t, 4> resultStrides;
2112  if (
failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
2113    return emitError(
"expected result type to have strided layout but found ")
// Match the layout's offset against the single declared static offset.
2117  int64_t expectedOffset = getStaticOffsets().front();
2118  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
2119    return emitError(
"expected result type with offset = ")
2120           << (ShapedType::isDynamic(expectedOffset)
2121                   ? std::string(
"dynamic")
2122                   : std::to_string(expectedOffset))
2123           <<
" instead of " << resultOffset;
// Match each layout stride against the declared static strides.
2126  for (
auto [idx, resultStride, expectedStride] :
2127       llvm::enumerate(resultStrides, getStaticStrides())) {
2128    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
2129      return emitError(
"expected result type with stride = ")
2130             << (ShapedType::isDynamic(expectedStride)
2131                     ? std::string(
"dynamic")
2132                     : std::to_string(expectedStride))
2133             <<
" instead of " << resultStride <<
" in dim = " << idx;
// fold: look through the source's defining op — the lambda returns the
// producer's own source when the reinterpret_cast is redundant over it,
// and the op is updated in place to point at that earlier value.
// NOTE(review): the three candidate-producer checks guarding the three
// `return prev.getSource()` lines are missing from this extract — the
// conditions cannot be documented from here.
2139OpFoldResult ReinterpretCastOp::fold(FoldAdaptor ) {
2140  Value src = getSource();
2141  auto getPrevSrc = [&]() -> Value {
2144      return prev.getSource();
2148        return prev.getSource();
2154        return prev.getSource();
2159  if (
auto prevSrc = getPrevSrc()) {
2160    getSourceMutable().assign(prevSrc);
// Constified accessors: return the mixed offset/sizes/strides with any
// entry that the *result type* pins to a static value replaced by that
// constant. Used by canonicalization to compare against other ops.
2173SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2179SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2180  SmallVector<OpFoldResult> values = getMixedStrides();
2181  SmallVector<int64_t> staticValues;
2183  LogicalResult status =
getType().getStridesAndOffset(staticValues, unused);
// A verified reinterpret_cast always has a strided result type, so this
// cannot fail (see verify()).
2185  assert(succeeded(status) &&
"could not get strides from type");
2190OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2191  SmallVector<OpFoldResult> values = getMixedOffsets();
2192  assert(values.size() == 1 &&
2193         "reinterpret_cast must have one and only one offset");
2194  SmallVector<int64_t> staticValues, unused;
2196  LogicalResult status =
getType().getStridesAndOffset(unused, offset);
2198  assert(succeeded(status) &&
"could not get offset from type");
2199  staticValues.push_back(offset);
// Pattern: reinterpret_cast fed by extract_strided_metadata. If the cast
// re-applies exactly the metadata's own strides/sizes/offset (a no-op),
// replace it with the metadata's source (via a cast when the types differ);
// otherwise just forward the source operand past the metadata op.
2247struct ReinterpretCastOpExtractStridedMetadataFolder
2248    :
public OpRewritePattern<ReinterpretCastOp> {
2250  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2252  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2253                                PatternRewriter &rewriter)
const override {
2254    auto extractStridedMetadata =
2255        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2256    if (!extractStridedMetadata)
// The cast is a no-op iff its constified strides, sizes and offset all
// equal the metadata op's constified values.
2261    auto isReinterpretCastNoop = [&]() ->
bool {
2263      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2264                       op.getConstifiedMixedStrides()))
2268      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2269                       op.getConstifiedMixedSizes()))
2273      assert(op.getMixedOffsets().size() == 1 &&
2274             "reinterpret_cast with more than one offset should have been "
2275             "rejected by the verifier");
2276      return extractStridedMetadata.getConstifiedMixedOffset() ==
2277             op.getConstifiedMixedOffset();
// Not a no-op: still profitable to skip the metadata op for the source.
2280    if (!isReinterpretCastNoop()) {
2297      op.getSourceMutable().assign(extractStridedMetadata.getSource());
// No-op case: replace directly when types match, otherwise through a cast.
2307    Type srcTy = extractStridedMetadata.getSource().getType();
2308    if (srcTy == op.getResult().getType())
2309      rewriter.
replaceOp(op, extractStridedMetadata.getSource());
2312                                        extractStridedMetadata.getSource());
// Pattern: replace dynamic offset/size/stride operands whose values are
// pinned to constants by the result type with the constant attributes.
// Only fires when this actually increases the number of static entries.
2318struct ReinterpretCastOpConstantFolder
2319    :
public OpRewritePattern<ReinterpretCastOp> {
2321  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2323  LogicalResult matchAndRewrite(ReinterpretCastOp op,
2324                                PatternRewriter &rewriter)
const override {
// Count how many of the current mixed values are already attributes.
2325    unsigned srcStaticCount = llvm::count_if(
2326        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2327                                   op.getMixedStrides()),
2328        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2330    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2331    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2332    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
// NOTE(review): the conditions that selectively keep the original offset
// and sizes (lines between the extracted ones) are missing here.
2339      offsets[0] = op.getMixedOffsets()[0];
2344    for (
auto it : llvm::zip(op.getMixedSizes(), sizes)) {
2345      auto &srcSizeOfr = std::get<0>(it);
2346      auto &sizeOfr = std::get<1>(it);
2349        sizeOfr = srcSizeOfr;
// No progress (same number of static entries) -> do not rewrite.
2356    if (srcStaticCount ==
2357        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2358                       [](OpFoldResult ofr) {
return isa<Attribute>(ofr); }))
2361    auto newReinterpretCast = ReinterpretCastOp::create(
2362        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
// Register both patterns above.
2370void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2371                                                    MLIRContext *context) {
2372  results.
add<ReinterpretCastOpExtractStridedMetadataFolder,
2373              ReinterpretCastOpConstantFolder>(context);
// Cast-bubbling entry point; body not visible in this extract.
2376FailureOr<std::optional<SmallVector<Value>>>
2377ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
// Suggested SSA names for the reshape ops' results.
2385void CollapseShapeOp::getAsmResultNames(
2387  setNameFn(getResult(),
"collapse_shape");
2390void ExpandShapeOp::getAsmResultNames(
2392  setNameFn(getResult(),
"expand_shape");
// The expanded result shape is exactly the op's mixed static/dynamic
// output_shape operands — no computation needed.
2395LogicalResult ExpandShapeOp::reifyResultShapes(
2397  reifiedResultShapes = {
2398      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
// Shared verifier for expand/collapse reassociation (fragment — the
// function signature with op, collapsedShape, expandedShape, reassociation
// parameters sits in a gap just before this extract). Checks:
//  * one reassociation group per collapsed dimension;
//  * group indices are contiguous, in-order and in-bounds;
//  * at most one dynamic expanded dim per group (unless explicitly allowed);
//  * a collapsed dim is dynamic iff its group contains a dynamic dim;
//  * for fully static groups, the product of expanded sizes equals the
//    collapsed size;
//  * rank-0 special case: only unit dims may be expanded/collapsed away;
//  * all expanded dims are covered by some group.
2411                      bool allowMultipleDynamicDimsPerGroup) {
2413  if (collapsedShape.size() != reassociation.size())
2414    return op->
emitOpError(
"invalid number of reassociation groups: found ")
2415           << reassociation.size() <<
", expected " << collapsedShape.size();
// nextDim tracks the next expanded dim each group must start at, enforcing
// contiguous, complete coverage.
2420  for (
const auto &it : llvm::enumerate(reassociation)) {
2422    int64_t collapsedDim = it.index();
2424    bool foundDynamic =
false;
2425    for (
int64_t expandedDim : group) {
2426      if (expandedDim != nextDim++)
2427        return op->
emitOpError(
"reassociation indices must be contiguous");
2429      if (expandedDim >=
static_cast<int64_t>(expandedShape.size()))
2431               << expandedDim <<
" is out of bounds";
2434      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2435        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2437              "at most one dimension in a reassociation group may be dynamic");
2438        foundDynamic =
true;
// Dynamicity must agree between the collapsed dim and its group.
2443    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2446             <<
") must be dynamic if and only if reassociation group is "
// Static group: sizes must multiply out to the collapsed size.
2451    if (!foundDynamic) {
2453      for (
int64_t expandedDim : group)
2454        groupSize *= expandedShape[expandedDim];
2455      if (groupSize != collapsedShape[collapsedDim])
2457               << collapsedShape[collapsedDim]
2458               <<
") must equal reassociation group size (" << groupSize <<
")";
// Rank-0: every expanded dim must be a unit dim (check body in a gap).
2462  if (collapsedShape.empty()) {
2464    for (
int64_t d : expandedShape)
2467            "rank 0 memrefs can only be extended/collapsed with/from ones");
2468  }
else if (nextDim !=
static_cast<int64_t>(expandedShape.size())) {
2472           << expandedShape.size()
2473           <<
") inconsistent with number of reassociation indices (" << nextDim
// Reassociation accessors for both reshape ops: convert the stored index
// groups into affine maps / expression lists. Bodies are mostly elided in
// this extract; each presumably delegates to the shared reassociation
// utilities taking getReassociationIndices() — TODO confirm.
2480SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2484SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2486      getReassociationIndices());
2489SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2493SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2495      getReassociationIndices());
// Static helper (signature partially in a gap): compute the strided layout
// of the expanded memref from the source's strides and the reassociation.
// Walks groups back-to-front, pushing the running stride for each expanded
// dim and scaling it by the dim size for the next-inner position.
2500static FailureOr<StridedLayoutAttr>
2505  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2507  assert(srcStrides.size() == reassociation.size() &&
"invalid reassociation");
2522  reverseResultStrides.reserve(resultShape.size());
2523  unsigned shapeIndex = resultShape.size() - 1;
2524  for (
auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2526    int64_t currentStrideToExpand = std::get<1>(it);
2527    for (
unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2528      reverseResultStrides.push_back(currentStrideToExpand);
2529      currentStrideToExpand =
// Built in reverse; flip and pad trailing (outer) entries with unit stride.
2535  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2536  resultStrides.resize(resultShape.size(), 1);
2537  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
// Compute the expanded MemRefType: identity-layout sources keep an
// identity layout; otherwise derive a strided layout via the helper above
// (failure means the expansion is not representable).
2540FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2541    MemRefType srcType, ArrayRef<int64_t> resultShape,
2542    ArrayRef<ReassociationIndices> reassociation) {
2543  if (srcType.getLayout().isIdentity()) {
// Null layout attribute == identity layout for MemRefType::get.
2546    MemRefLayoutAttrInterface layout;
2547    return MemRefType::get(resultShape, srcType.getElementType(), layout,
2548                           srcType.getMemorySpace());
2552  FailureOr<StridedLayoutAttr> computedLayout =
2554  if (
failed(computedLayout))
2556  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2557                         srcType.getMemorySpace());
// Infer the mixed output shape of the expansion from the input shape and
// reassociation (delegates to a shared utility in the elided lines).
2560FailureOr<SmallVector<OpFoldResult>>
2561ExpandShapeOp::inferOutputShape(OpBuilder &
b, Location loc,
2562                                MemRefType expandedType,
2563                                ArrayRef<ReassociationIndices> reassociation,
2564                                ArrayRef<OpFoldResult> inputShape) {
2565  std::optional<SmallVector<OpFoldResult>> outputShape =
2570  return *outputShape;
// build #1: explicit result type + mixed output shape. Splits the mixed
// output shape into static/dynamic parts and calls the generated builder.
2573void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
2574                          Type resultType, Value src,
2575                          ArrayRef<ReassociationIndices> reassociation,
2576                          ArrayRef<OpFoldResult> outputShape) {
2577  auto [staticOutputShape, dynamicOutputShape] =
2579  build(builder,
result, llvm::cast<MemRefType>(resultType), src,
2581        dynamicOutputShape, staticOutputShape);
// build #2: explicit result type, output shape inferred from the source's
// (mixed) shape. Asserts inference succeeds, then delegates to build #1.
2584void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
2585                          Type resultType, Value src,
2586                          ArrayRef<ReassociationIndices> reassociation) {
2587  SmallVector<OpFoldResult> inputShape =
2589  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2590  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2591      builder,
result.location, memrefResultTy, reassociation, inputShape);
// Builder contract: callers must pass shapes for which inference works.
2594  assert(succeeded(outputShape) &&
"unable to infer output shape");
2595  build(builder,
result, memrefResultTy, src, reassociation, *outputShape);
// build #3: result *shape* only — full result type (incl. layout) computed
// via computeExpandedType, then delegates to build #2.
2598void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
2599                          ArrayRef<int64_t> resultShape, Value src,
2600                          ArrayRef<ReassociationIndices> reassociation) {
2602  auto srcType = llvm::cast<MemRefType>(src.
getType());
2603  FailureOr<MemRefType> resultType =
2604      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2607  assert(succeeded(resultType) &&
"could not compute layout");
2608  build(builder,
result, *resultType, src, reassociation);
// build #4: result shape + explicit mixed output shape; same type
// computation as build #3, then delegates to build #1.
2611void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
2612                          ArrayRef<int64_t> resultShape, Value src,
2613                          ArrayRef<ReassociationIndices> reassociation,
2614                          ArrayRef<OpFoldResult> outputShape) {
2616  auto srcType = llvm::cast<MemRefType>(src.
getType());
2617  FailureOr<MemRefType> resultType =
2618      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2621  assert(succeeded(resultType) &&
"could not compute layout");
2622  build(builder,
result, *resultType, src, reassociation, outputShape);
// ExpandShapeOp::verify: result rank must not be smaller than source rank;
// the reassociation must be valid for the (src, result) shape pair; the
// result type must equal the type computeExpandedType derives; and the
// static/dynamic output_shape bookkeeping must be internally consistent
// and agree with the result type's static dims.
2625LogicalResult ExpandShapeOp::verify() {
2626  MemRefType srcType = getSrcType();
2627  MemRefType resultType = getResultType();
// An "expansion" must not reduce rank.
2629  if (srcType.getRank() > resultType.getRank()) {
2630    auto r0 = srcType.getRank();
2631    auto r1 = resultType.getRank();
2633           << r0 <<
" and result rank " << r1 <<
". This is not an expansion ("
2634           << r0 <<
" > " << r1 <<
").";
// Reassociation check against the shared verifier (call partially elided).
2639                                    resultType.getShape(),
2640                                    getReassociationIndices(),
// The declared result type must match the computed expansion exactly.
2645  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2646      srcType, resultType.getShape(), getReassociationIndices());
2647  if (
failed(expectedResultType))
2651  if (*expectedResultType != resultType)
2652    return emitOpError(
"expected expanded type to be ")
2653           << *expectedResultType <<
" but found " << resultType;
// One static_output_shape entry per result dimension.
2655  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2656    return emitOpError(
"expected number of static shape bounds to be equal to "
2657                       "the output rank (")
2658           << resultType.getRank() <<
") but found "
2659           << getStaticOutputShape().size() <<
" inputs instead";
// One dynamic SSA output_shape value per kDynamic sentinel.
2661  if ((int64_t)getOutputShape().size() !=
2662      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2663    return emitOpError(
"mismatch in dynamic dims in output_shape and "
2664                       "static_output_shape: static_output_shape has ")
2665           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2666           <<
" dynamic dims while output_shape has " << getOutputShape().size()
// Each static result dim must match the provided static output shape.
2677  ArrayRef<int64_t> resShape = getResult().getType().getShape();
2678  for (
auto [pos, shape] : llvm::enumerate(resShape)) {
2679    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2680      return emitOpError(
"invalid output shape provided at pos ") << pos;
// ExpandShapeOpMemRefCastFolder::matchAndRewrite (fragment — pattern header
// is in a gap): fold a producing memref.cast into the expand_shape. The
// new output shape is recomputed from the cast's source: known sizes become
// static (index attrs), unknown ones stay dynamic. Bails out when folding
// would flip the dynamicity of any reassociation group, since that would
// change the result type in an unsupported way.
2693    auto cast = op.getSrc().getDefiningOp<CastOp>();
2697    if (!CastOp::canFoldIntoConsumerOp(cast))
// Rebuild the output shape entry-by-entry from constant knowledge.
2705    for (
auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
2707      if (!sizeOpt.has_value()) {
2708        newOutputShapeSizes.push_back(ShapedType::kDynamic);
2712      newOutputShapeSizes.push_back(sizeOpt.value());
2713      newOutputShape[dimIdx] = rewriter.
getIndexAttr(sizeOpt.value());
2716    Value castSource = cast.getSource();
2717    auto castSourceType = llvm::cast<MemRefType>(castSource.
getType());
2719        op.getReassociationIndices();
// Per-group dynamicity of the new output must match the cast source.
2720    for (
auto [idx, group] : llvm::enumerate(reassociationIndices)) {
2721      auto newOutputShapeSizesSlice =
2722          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
2723      bool newOutputDynamic =
2724          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
2725      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
2727            op,
"folding cast will result in changing dynamicity in "
2728            "reassociation group");
2731    FailureOr<MemRefType> newResultTypeOrFailure =
2732        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
2733                                           reassociationIndices);
2735    if (failed(newResultTypeOrFailure))
2737          op,
"could not compute new expanded type after folding cast");
// Same result type: update in place; otherwise build a fresh expand_shape.
2739    if (*newResultTypeOrFailure == op.getResultType()) {
2741          op, [&]() { op.getSrcMutable().assign(castSource); });
2743      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
2744                                          *newResultTypeOrFailure, castSource,
2745                                          reassociationIndices, newOutputShape);
// Register expand_shape canonicalizations.
2752void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2753                                                MLIRContext *context) {
2755      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2756      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2757      ExpandShapeOpMemRefCastFolder>(context);
// Cast-bubbling entry point; body not visible in this extract.
2760FailureOr<std::optional<SmallVector<Value>>>
2761ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
// Static helper (signature partially in a gap): compute the strided layout
// of the collapsed memref. For each group, the result stride is the stride
// of the innermost non-unit dim when that is statically known; otherwise
// dynamic. The trailing-dim walk then checks (with saturating arithmetic,
// per the .saturated accesses) that the group's dims are contiguous; in
// `strict` mode any saturated (unknown) stride is rejected outright.
2772static FailureOr<StridedLayoutAttr>
2775                          bool strict =
false) {
2778  auto srcShape = srcType.getShape();
2779  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2788  resultStrides.reserve(reassociation.size());
// Skip trailing unit dims of the group — they don't affect the stride.
2791    while (srcShape[ref.back()] == 1 && ref.size() > 1)
2792      ref = ref.drop_back();
2793    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2794      resultStrides.push_back(srcStrides[ref.back()]);
2800      resultStrides.push_back(ShapedType::kDynamic);
2805  unsigned resultStrideIndex = resultStrides.size() - 1;
2809    for (
int64_t idx : llvm::reverse(trailingReassocs)) {
// Unit dims are always contiguous — nothing to check for them.
2814      if (srcShape[idx - 1] == 1)
2826      if (strict && (stride.saturated || srcStride.saturated))
// Known strides that disagree mean the dims are not contiguous.
2829      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2833  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
// True when the collapse is statically known to be legal; identity layouts
// are always collapsible (non-identity check elided in this extract).
2836bool CollapseShapeOp::isGuaranteedCollapsible(
2837    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2839  if (srcType.getLayout().isIdentity())
// Compute the collapsed MemRefType: result shape is the per-group product
// of source dims; identity layouts stay identity, otherwise the strided
// layout comes from the helper above (asserted to succeed).
2846MemRefType CollapseShapeOp::computeCollapsedType(
2847    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2848  SmallVector<int64_t> resultShape;
2849  resultShape.reserve(reassociation.size());
2852    for (int64_t srcDim : group)
2855    resultShape.push_back(groupSize.asInteger());
2858  if (srcType.getLayout().isIdentity()) {
// Null layout attribute == identity layout for MemRefType::get.
2861    MemRefLayoutAttrInterface layout;
2862    return MemRefType::get(resultShape, srcType.getElementType(), layout,
2863                           srcType.getMemorySpace());
2869  FailureOr<StridedLayoutAttr> computedLayout =
2871  assert(succeeded(computedLayout) &&
2872         "invalid source layout map or collapsing non-contiguous dims");
2873  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2874                         srcType.getMemorySpace());
// build: result type fully inferred via computeCollapsedType, then the
// generated builder is invoked.
2877void CollapseShapeOp::build(OpBuilder &
b, OperationState &
result, Value src,
2878                            ArrayRef<ReassociationIndices> reassociation,
2879                            ArrayRef<NamedAttribute> attrs) {
2880  auto srcType = llvm::cast<MemRefType>(src.
getType());
2881  MemRefType resultType =
2882      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2885  build(
b,
result, resultType, src, attrs);
// verify: mirror of ExpandShapeOp::verify — source rank must not be
// smaller than result rank, the reassociation must validate, and the
// declared result type must equal the recomputed collapsed type (identity
// sources keep an identity layout; strided sources go through the layout
// helper, which diagnoses non-contiguous collapses).
2888LogicalResult CollapseShapeOp::verify() {
2889  MemRefType srcType = getSrcType();
2890  MemRefType resultType = getResultType();
2892  if (srcType.getRank() < resultType.getRank()) {
2893    auto r0 = srcType.getRank();
2894    auto r1 = resultType.getRank();
2896           << r0 <<
" and result rank " << r1 <<
". This is not a collapse ("
2897           << r0 <<
" < " << r1 <<
").";
// Shared reassociation verifier (call partially elided in this extract).
2902          srcType.getShape(), getReassociationIndices(),
2907  MemRefType expectedResultType;
2908  if (srcType.getLayout().isIdentity()) {
// Null layout attribute == identity layout for MemRefType::get.
2911    MemRefLayoutAttrInterface layout;
2912    expectedResultType =
2913        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2914                        srcType.getMemorySpace());
2919    FailureOr<StridedLayoutAttr> computedLayout =
2921    if (
failed(computedLayout))
2923          "invalid source layout map or collapsing non-contiguous dims");
2924    expectedResultType =
2925        MemRefType::get(resultType.getShape(), srcType.getElementType(),
2926                        *computedLayout, srcType.getMemorySpace());
2929  if (expectedResultType != resultType)
2930    return emitOpError(
"expected collapsed type to be ")
2931           << expectedResultType <<
" but found " << resultType;
// CollapseShapeOpMemRefCastFolder::matchAndRewrite (fragment — pattern
// header is in a gap): fold a producing memref.cast into the collapse by
// recomputing the collapsed type from the cast's operand type. Same type:
// update src in place; otherwise rebuild the collapse on the cast source.
2943    auto cast = op.getOperand().getDefiningOp<CastOp>();
2947    if (!CastOp::canFoldIntoConsumerOp(cast))
2950    Type newResultType = CollapseShapeOp::computeCollapsedType(
2951        llvm::cast<MemRefType>(cast.getOperand().getType()),
2952        op.getReassociationIndices());
2954    if (newResultType == op.getResultType()) {
2956          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2959          CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2960                                  op.getReassociationIndices());
// Register collapse_shape canonicalizations.
2967void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2968                                                  MLIRContext *context) {
2970      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2971      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2972                                memref::DimOp, MemRefType>,
2973      CollapseShapeOpMemRefCastFolder>(context);
// fold(): both ops delegate to the shared reshape fold utility with the
// adaptor operands (callee name elided in this extract).
2976OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2978                                                       adaptor.getOperands());
2981OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2983                                                         adaptor.getOperands());
// Cast-bubbling entry point; body not visible in this extract.
2986FailureOr<std::optional<SmallVector<Value>>>
2987CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
// Suggest "%reshape" as the printed SSA name for the result.
2995void ReshapeOp::getAsmResultNames(
2997  setNameFn(getResult(),
"reshape");
// ReshapeOp::verify: element types must match; a memref source must have
// an identity layout; a ranked memref result additionally requires an
// identity layout and a shape operand whose static length equals the
// result rank (dynamic-length shape operands are rejected for ranked
// results). `shapeSize` is computed in lines elided from this extract.
3000LogicalResult ReshapeOp::verify() {
3001  Type operandType = getSource().getType();
3002  Type resultType = getResult().getType();
3004  Type operandElementType =
3005      llvm::cast<ShapedType>(operandType).getElementType();
3006  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
3007  if (operandElementType != resultElementType)
3008    return emitOpError(
"element types of source and destination memref "
3009                       "types should be the same");
3011  if (
auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
3012    if (!operandMemRefType.getLayout().isIdentity())
3013      return emitOpError(
"source memref type should have identity affine map");
3017  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
3018  if (resultMemRefType) {
3019    if (!resultMemRefType.getLayout().isIdentity())
3020      return emitOpError(
"result memref type should have identity affine map");
3021    if (shapeSize == ShapedType::kDynamic)
3022      return emitOpError(
"cannot use shape operand with dynamic length to "
3023                         "reshape to statically-ranked memref type");
3024    if (shapeSize != resultMemRefType.getRank())
3026          "length of shape operand differs from the result's memref rank");
// Cast-bubbling entry point; body not visible in this extract.
3031FailureOr<std::optional<SmallVector<Value>>>
3032ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
// StoreOp::fold: body not visible in this extract.
3040LogicalResult StoreOp::fold(FoldAdaptor adaptor,
3041                            SmallVectorImpl<OpFoldResult> &results) {
// Cast-bubbling helper: repoint memref and indices in place, no new
// results (same shape as the LoadOp/PrefetchOp counterparts above).
3048std::optional<SmallVector<Value>>
3049StoreOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
3052  getMemrefMutable().assign(newMemref);
3053  getIndicesMutable().assign(newIndices);
3055  return std::nullopt;
// Cast-bubbling entry point; body not visible in this extract.
3058FailureOr<std::optional<SmallVector<Value>>>
3059StoreOp::bubbleDownCasts(OpBuilder &builder) {
// Suggest "%subview" as the printed SSA name for the result.
3068void SubViewOp::getAsmResultNames(
3070  setNameFn(getResult(),
"subview");
// inferResultType (all-static): the subview's offset is the source offset
// plus sum(offset_i * sourceStride_i); each target stride is
// sourceStride_i * staticStride_i; the shape is the requested sizes. The
// accumulation arithmetic between the loop headers is elided here — the
// visible code suggests saturating handling of dynamic entries; confirm.
3076MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3077                                      ArrayRef<int64_t> staticOffsets,
3078                                      ArrayRef<int64_t> staticSizes,
3079                                      ArrayRef<int64_t> staticStrides) {
3080  unsigned rank = sourceMemRefType.getRank();
3082  assert(staticOffsets.size() == rank &&
"staticOffsets length mismatch");
3083  assert(staticSizes.size() == rank &&
"staticSizes length mismatch");
3084  assert(staticStrides.size() == rank &&
"staticStrides length mismatch");
3087  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
3091  int64_t targetOffset = sourceOffset;
3092  for (
auto it : llvm::zip(staticOffsets, sourceStrides)) {
3093    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
3102  SmallVector<int64_t, 4> targetStrides;
3103  targetStrides.reserve(staticOffsets.size());
3104  for (
auto it : llvm::zip(sourceStrides, staticStrides)) {
3105    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
3112  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
3113                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
3114                                                targetOffset, targetStrides),
3115                         sourceMemRefType.getMemorySpace());
// inferResultType (mixed): split OpFoldResult lists into static/dynamic
// parts (split calls elided) and delegate to the static overload.
3118MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3119                                      ArrayRef<OpFoldResult> offsets,
3120                                      ArrayRef<OpFoldResult> sizes,
3121                                      ArrayRef<OpFoldResult> strides) {
3122  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3123  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3133  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3134                                    staticSizes, staticStrides);
// Rank-reducing variant: infer the full-rank type, then drop the
// dimensions identified as rank-reducing (dimsToProject) from both the
// shape and the layout's stride list, keeping the offset.
3137MemRefType SubViewOp::inferRankReducedResultType(
3138    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3139    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3140    ArrayRef<int64_t> strides) {
3141  MemRefType inferredType =
3142      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
3143  assert(inferredType.getRank() >=
static_cast<int64_t
>(resultShape.size()) &&
// Same rank: nothing to project out.
3145  if (inferredType.getRank() ==
static_cast<int64_t
>(resultShape.size()))
3146    return inferredType;
// Which dims to drop (computation elided; must exist for a valid request).
3149  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
3151  assert(dimsToProject.has_value() &&
"invalid rank reduction");
3154  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
3155  SmallVector<int64_t> rankReducedStrides;
3156  rankReducedStrides.reserve(resultShape.size());
3157  for (
auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
3158    if (!dimsToProject->contains(idx))
3159      rankReducedStrides.push_back(value);
3161  return MemRefType::get(resultShape, inferredType.getElementType(),
3162                         StridedLayoutAttr::get(inferredLayout.getContext(),
3163                                                inferredLayout.getOffset(),
3164                                                rankReducedStrides),
3165                         inferredType.getMemorySpace());
// Mixed-operand rank-reducing variant: split and delegate (tail of the
// delegating call is elided in this extract).
3168MemRefType SubViewOp::inferRankReducedResultType(
3169    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3170    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
3171    ArrayRef<OpFoldResult> strides) {
3172  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3173  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3177  return SubViewOp::inferRankReducedResultType(
3178      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3184void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3185 MemRefType resultType, Value source,
3186 ArrayRef<OpFoldResult> offsets,
3187 ArrayRef<OpFoldResult> sizes,
3188 ArrayRef<OpFoldResult> strides,
3189 ArrayRef<NamedAttribute> attrs) {
3190 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3191 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3195 auto sourceMemRefType = llvm::cast<MemRefType>(source.
getType());
3198 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3199 staticSizes, staticStrides);
3201 result.addAttributes(attrs);
3202 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
3203 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
3204 b.getDenseI64ArrayAttr(staticSizes),
3205 b.getDenseI64ArrayAttr(staticStrides));
3210void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3211 ArrayRef<OpFoldResult> offsets,
3212 ArrayRef<OpFoldResult> sizes,
3213 ArrayRef<OpFoldResult> strides,
3214 ArrayRef<NamedAttribute> attrs) {
3215 build(
b,
result, MemRefType(), source, offsets, sizes, strides, attrs);
3219void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3220 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3221 ArrayRef<int64_t> strides,
3222 ArrayRef<NamedAttribute> attrs) {
3223 SmallVector<OpFoldResult> offsetValues =
3224 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3225 return b.getI64IntegerAttr(v);
3227 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3228 sizes, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v); });
3229 SmallVector<OpFoldResult> strideValues =
3230 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3231 return b.getI64IntegerAttr(v);
3233 build(
b,
result, source, offsetValues, sizeValues, strideValues, attrs);
3238void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3239 MemRefType resultType, Value source,
3240 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3241 ArrayRef<int64_t> strides,
3242 ArrayRef<NamedAttribute> attrs) {
3243 SmallVector<OpFoldResult> offsetValues =
3244 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3245 return b.getI64IntegerAttr(v);
3247 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3248 sizes, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v); });
3249 SmallVector<OpFoldResult> strideValues =
3250 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3251 return b.getI64IntegerAttr(v);
3253 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues,
3259void SubViewOp::build(OpBuilder &
b, OperationState &
result,
3260 MemRefType resultType, Value source,
ValueRange offsets,
3262 ArrayRef<NamedAttribute> attrs) {
3263 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3264 offsets, [](Value v) -> OpFoldResult {
return v; });
3265 SmallVector<OpFoldResult> sizeValues =
3266 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
3267 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3268 strides, [](Value v) -> OpFoldResult {
return v; });
3269 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues);
3273void SubViewOp::build(OpBuilder &
b, OperationState &
result, Value source,
3275 ArrayRef<NamedAttribute> attrs) {
3276 build(
b,
result, MemRefType(), source, offsets, sizes, strides, attrs);
3280Value SubViewOp::getViewSource() {
return getSource(); }
3287 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3288 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3289 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3296 const llvm::SmallBitVector &droppedDims) {
3297 assert(
size_t(t1.getRank()) == droppedDims.size() &&
3298 "incorrect number of bits");
3299 assert(
size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3300 "incorrect number of dropped dims");
3303 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3304 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3305 if (failed(res1) || failed(res2))
3307 for (
int64_t i = 0,
j = 0, e = t1.getRank(); i < e; ++i) {
3310 if (t1Strides[i] != t2Strides[
j])
3318 SubViewOp op,
Type expectedType) {
3319 auto memrefType = llvm::cast<ShapedType>(expectedType);
3324 return op->emitError(
"expected result rank to be smaller or equal to ")
3325 <<
"the source rank, but got " << op.getType();
3327 return op->emitError(
"expected result type to be ")
3329 <<
" or a rank-reduced version. (mismatch of result sizes), but got "
3332 return op->emitError(
"expected result element type to be ")
3333 << memrefType.getElementType() <<
", but got " << op.getType();
3335 return op->emitError(
3336 "expected result and source memory spaces to match, but got ")
3339 return op->emitError(
"expected result type to be ")
3341 <<
" or a rank-reduced version. (mismatch of result layout), but "
3345 llvm_unreachable(
"unexpected subview verification result");
3349LogicalResult SubViewOp::verify() {
3350 MemRefType baseType = getSourceType();
3351 MemRefType subViewType =
getType();
3352 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3353 ArrayRef<int64_t> staticSizes = getStaticSizes();
3354 ArrayRef<int64_t> staticStrides = getStaticStrides();
3357 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3358 return emitError(
"different memory spaces specified for base memref "
3360 << baseType <<
" and subview memref type " << subViewType;
3363 if (!baseType.isStrided())
3364 return emitError(
"base type ") << baseType <<
" is not strided";
3368 MemRefType expectedType = SubViewOp::inferResultType(
3369 baseType, staticOffsets, staticSizes, staticStrides);
3374 expectedType, subViewType);
3379 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3381 *
this, expectedType);
3386 *
this, expectedType);
3396 *
this, expectedType);
3401 *
this, expectedType);
3405 SliceBoundsVerificationResult boundsResult =
3407 staticStrides,
true);
3409 return getOperation()->emitError(boundsResult.
errorMessage);
3415 return os <<
"range " << range.
offset <<
":" << range.
size <<
":"
3424 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3425 assert(ranks[0] == ranks[1] &&
"expected offset and sizes of equal ranks");
3426 assert(ranks[1] == ranks[2] &&
"expected sizes and strides of equal ranks");
3428 unsigned rank = ranks[0];
3430 for (
unsigned idx = 0; idx < rank; ++idx) {
3432 op.isDynamicOffset(idx)
3433 ? op.getDynamicOffset(idx)
3436 op.isDynamicSize(idx)
3437 ? op.getDynamicSize(idx)
3440 op.isDynamicStride(idx)
3441 ? op.getDynamicStride(idx)
3443 res.emplace_back(
Range{offset, size, stride});
3456 MemRefType currentResultType, MemRefType currentSourceType,
3459 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3460 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3462 currentSourceType, currentResultType, mixedSizes);
3463 if (failed(unusedDims))
3466 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3468 unsigned numDimsAfterReduction =
3469 nonRankReducedType.getRank() - unusedDims->count();
3470 shape.reserve(numDimsAfterReduction);
3471 strides.reserve(numDimsAfterReduction);
3472 for (
const auto &[idx, size, stride] :
3473 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3474 nonRankReducedType.getShape(), layout.getStrides())) {
3475 if (unusedDims->test(idx))
3477 shape.push_back(size);
3478 strides.push_back(stride);
3481 return MemRefType::get(
shape, nonRankReducedType.getElementType(),
3482 StridedLayoutAttr::get(sourceType.getContext(),
3483 layout.getOffset(), strides),
3484 nonRankReducedType.getMemorySpace());
3489 auto memrefType = llvm::cast<MemRefType>(
memref.getType());
3490 unsigned rank = memrefType.getRank();
3494 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3495 targetShape, memrefType, offsets, sizes, strides);
3496 return b.createOrFold<memref::SubViewOp>(loc, targetType,
memref, offsets,
3503 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.
getType());
3504 assert(sourceMemrefType &&
"not a ranked memref type");
3505 auto sourceShape = sourceMemrefType.getShape();
3506 if (sourceShape.equals(desiredShape))
3508 auto maybeRankReductionMask =
3510 if (!maybeRankReductionMask)
3520 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3523 auto mixedOffsets = subViewOp.getMixedOffsets();
3524 auto mixedSizes = subViewOp.getMixedSizes();
3525 auto mixedStrides = subViewOp.getMixedStrides();
3530 return !intValue || intValue.value() != 0;
3537 return !intValue || intValue.value() != 1;
3543 for (
const auto &size : llvm::enumerate(mixedSizes)) {
3545 if (!intValue || *intValue != sourceShape[size.index()])
3569class SubViewOpMemRefCastFolder final :
public OpRewritePattern<SubViewOp> {
3571 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3573 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3574 PatternRewriter &rewriter)
const override {
3577 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3578 return matchPattern(operand, matchConstantIndex());
3582 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3586 if (!CastOp::canFoldIntoConsumerOp(castOp))
3594 subViewOp.getType(), subViewOp.getSourceType(),
3595 llvm::cast<MemRefType>(castOp.getSource().getType()),
3596 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3597 subViewOp.getMixedStrides());
3601 Value newSubView = SubViewOp::create(
3602 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3603 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3604 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3605 subViewOp.getStaticStrides());
3614class TrivialSubViewOpFolder final :
public OpRewritePattern<SubViewOp> {
3616 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3618 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3619 PatternRewriter &rewriter)
const override {
3622 if (subViewOp.getSourceType() == subViewOp.getType()) {
3623 rewriter.
replaceOp(subViewOp, subViewOp.getSource());
3627 subViewOp.getSource());
3639 MemRefType resTy = SubViewOp::inferResultType(
3640 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3643 MemRefType nonReducedType = resTy;
3646 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3647 if (droppedDims.none())
3648 return nonReducedType;
3651 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3656 for (
int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3657 if (droppedDims.test(i))
3659 targetStrides.push_back(nonReducedStrides[i]);
3660 targetShape.push_back(nonReducedType.getDimSize(i));
3663 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3664 StridedLayoutAttr::get(nonReducedType.getContext(),
3665 offset, targetStrides),
3666 nonReducedType.getMemorySpace());
3677void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3678 MLIRContext *context) {
3680 .
add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3681 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3682 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3685OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3686 MemRefType sourceMemrefType = getSource().getType();
3687 MemRefType resultMemrefType = getResult().getType();
3689 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3691 if (resultMemrefType == sourceMemrefType &&
3692 resultMemrefType.hasStaticShape() &&
3693 (!resultLayout || resultLayout.hasStaticLayout())) {
3694 return getViewSource();
3700 if (
auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3701 auto srcSizes = srcSubview.getMixedSizes();
3703 auto offsets = getMixedOffsets();
3705 auto strides = getMixedStrides();
3706 bool allStridesOne = llvm::all_of(strides,
isOneInteger);
3707 bool allSizesSame = llvm::equal(sizes, srcSizes);
3708 if (allOffsetsZero && allStridesOne && allSizesSame &&
3709 resultMemrefType == sourceMemrefType)
3710 return getViewSource();
3716FailureOr<std::optional<SmallVector<Value>>>
3717SubViewOp::bubbleDownCasts(OpBuilder &builder) {
// NOTE(review): this region of the file is a garbled source extraction —
// several original lines are missing and statements are split mid-expression.
// Kept byte-identical below; restore from upstream MemRefOps.cpp before use.
// From the visible fragments: the function propagates strided-metadata
// integer ranges for a subview. It bails out early if any offset, size or
// stride operand range is uninitialized, reads the source operand's metadata
// range, then builds per-dimension offset/size/stride ranges (skipping
// dropped dims) and publishes them via setMetadata on the result.
3721void SubViewOp::inferStridedMetadataRanges(
3722 ArrayRef<StridedMetadataRange> ranges,
GetIntRangeFn getIntRange,
3724 auto isUninitialized =
3725 +[](IntegerValueRange range) {
return range.isUninitialized(); };
3728 SmallVector<IntegerValueRange> offsetOperands =
3730 if (llvm::any_of(offsetOperands, isUninitialized))
3733 SmallVector<IntegerValueRange> sizeOperands =
3735 if (llvm::any_of(sizeOperands, isUninitialized))
3738 SmallVector<IntegerValueRange> stridesOperands =
3740 if (llvm::any_of(stridesOperands, isUninitialized))
3743 StridedMetadataRange sourceRange =
3744 ranges[getSourceMutable().getOperandNumber()];
3748 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.
getStrides();
3754 ConstantIntRanges offset = sourceRange.
getOffsets()[0];
3755 SmallVector<ConstantIntRanges> strides, sizes;
3757 for (
size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3758 bool dropped = droppedDims.test(i);
3760 ConstantIntRanges off =
3771 sizes.push_back(sizeOperands[i].getValue());
3774 setMetadata(getResult(),
3776 SmallVector<ConstantIntRanges>({std::move(offset)}),
3777 std::move(sizes), std::move(strides)));
3784void TransposeOp::getAsmResultNames(
3786 setNameFn(getResult(),
"transpose");
3792 auto originalSizes = memRefType.getShape();
3793 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3794 assert(originalStrides.size() ==
static_cast<unsigned>(memRefType.getRank()));
3803 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3806void TransposeOp::build(OpBuilder &
b, OperationState &
result, Value in,
3807 AffineMapAttr permutation,
3808 ArrayRef<NamedAttribute> attrs) {
3809 auto permutationMap = permutation.getValue();
3810 assert(permutationMap);
3812 auto memRefType = llvm::cast<MemRefType>(in.
getType());
3816 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3817 build(
b,
result, resultType, in, attrs);
3821void TransposeOp::print(OpAsmPrinter &p) {
3822 p <<
" " << getIn() <<
" " << getPermutation();
3824 p <<
" : " << getIn().getType() <<
" to " <<
getType();
3827ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
3828 OpAsmParser::UnresolvedOperand in;
3829 AffineMap permutation;
3830 MemRefType srcType, dstType;
3839 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3840 AffineMapAttr::get(permutation));
3844LogicalResult TransposeOp::verify() {
3847 if (getPermutation().getNumDims() != getIn().
getType().getRank())
3848 return emitOpError(
"expected a permutation map of same rank as the input");
3850 auto srcType = llvm::cast<MemRefType>(getIn().
getType());
3851 auto resultType = llvm::cast<MemRefType>(
getType());
3853 .canonicalizeStridedLayout();
3855 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3858 <<
" is not equivalent to the canonical transposed input type "
3859 << canonicalResultType;
3863OpFoldResult TransposeOp::fold(FoldAdaptor) {
3866 if (getPermutation().isIdentity() &&
getType() == getIn().
getType())
3870 if (
auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3871 AffineMap composedPermutation =
3872 getPermutation().compose(otherTransposeOp.getPermutation());
3873 getInMutable().assign(otherTransposeOp.getIn());
3874 setPermutation(composedPermutation);
3880FailureOr<std::optional<SmallVector<Value>>>
3881TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3889void ViewOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
3890 setNameFn(getResult(),
"view");
3893LogicalResult ViewOp::verify() {
3894 auto baseType = llvm::cast<MemRefType>(getOperand(0).
getType());
3898 if (!baseType.getLayout().isIdentity())
3899 return emitError(
"unsupported map for base memref type ") << baseType;
3902 if (!viewType.getLayout().isIdentity())
3903 return emitError(
"unsupported map for result memref type ") << viewType;
3906 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3907 return emitError(
"different memory spaces specified for base memref "
3909 << baseType <<
" and view memref type " << viewType;
3918Value ViewOp::getViewSource() {
return getSource(); }
3920OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3921 MemRefType sourceMemrefType = getSource().getType();
3922 MemRefType resultMemrefType = getResult().getType();
3924 if (resultMemrefType == sourceMemrefType &&
3925 resultMemrefType.hasStaticShape() &&
isZeroInteger(getByteShift()))
3926 return getViewSource();
3931SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
3932 SmallVector<OpFoldResult>
result;
3936 if (ShapedType::isDynamic(dim)) {
3937 result.push_back(getSizes()[ctr++]);
3939 result.push_back(
b.getIndexAttr(dim));
3951 SmallVectorImpl<Value> &foldedDynamicSizes) {
3952 SmallVector<int64_t> staticShape(type.getShape());
3953 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
3954 "incorrect number of dynamic sizes");
3958 for (
auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
3959 if (ShapedType::isStatic(dimSize))
3962 Value dynamicSize = dynamicSizes[ctr++];
3965 if (cst.value() < 0) {
3966 foldedDynamicSizes.push_back(dynamicSize);
3969 staticShape[dim] = cst.value();
3971 foldedDynamicSizes.push_back(dynamicSize);
3975 return MemRefType::Builder(type).setShape(staticShape);
3989struct ViewOpShapeFolder :
public OpRewritePattern<ViewOp> {
3992 LogicalResult matchAndRewrite(ViewOp viewOp,
3993 PatternRewriter &rewriter)
const override {
3994 SmallVector<Value> foldedDynamicSizes;
3995 MemRefType resultType = viewOp.getType();
3997 resultType, viewOp.getSizes(), foldedDynamicSizes);
4000 if (foldedMemRefType == resultType)
4004 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
4005 viewOp.getSource(), viewOp.getByteShift(),
4006 foldedDynamicSizes);
4014struct ViewOpMemrefCastFolder :
public OpRewritePattern<ViewOp> {
4017 LogicalResult matchAndRewrite(ViewOp viewOp,
4018 PatternRewriter &rewriter)
const override {
4019 auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
4024 viewOp, viewOp.getType(), memrefCastOp.getSource(),
4025 viewOp.getByteShift(), viewOp.getSizes());
4031void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
4032 MLIRContext *context) {
4033 results.
add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
4036FailureOr<std::optional<SmallVector<Value>>>
4037ViewOp::bubbleDownCasts(OpBuilder &builder) {
4045LogicalResult AtomicRMWOp::verify() {
4046 switch (getKind()) {
4047 case arith::AtomicRMWKind::addf:
4048 case arith::AtomicRMWKind::maximumf:
4049 case arith::AtomicRMWKind::minimumf:
4050 case arith::AtomicRMWKind::mulf:
4051 if (!llvm::isa<FloatType>(getValue().
getType()))
4053 << arith::stringifyAtomicRMWKind(getKind())
4054 <<
"' expects a floating-point type";
4056 case arith::AtomicRMWKind::addi:
4057 case arith::AtomicRMWKind::maxs:
4058 case arith::AtomicRMWKind::maxu:
4059 case arith::AtomicRMWKind::mins:
4060 case arith::AtomicRMWKind::minu:
4061 case arith::AtomicRMWKind::muli:
4062 case arith::AtomicRMWKind::ori:
4063 case arith::AtomicRMWKind::xori:
4064 case arith::AtomicRMWKind::andi:
4065 if (!llvm::isa<IntegerType>(getValue().
getType()))
4067 << arith::stringifyAtomicRMWKind(getKind())
4068 <<
"' expects an integer type";
4076OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
4080 return OpFoldResult();
// NOTE(review): the function body is missing from this extraction (original
// lines after 4084 were dropped); only the signature survives. Restore the
// body from upstream MemRefOps.cpp — do not guess the cast-bubbling strategy
// here, since AtomicRMWOp's result is a value, not a memref, and the shared
// passthrough helper may not apply.
4083FailureOr<std::optional<SmallVector<Value>>>
4084AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
4091std::optional<SmallVector<Value>>
4092AtomicRMWOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
4095 getMemrefMutable().assign(newMemref);
4096 getIndicesMutable().assign(newIndices);
4098 return std::nullopt;
4105#define GET_OP_CLASSES
4106#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByStrides(MemRefType originalType, MemRefType reducedType, ArrayRef< int64_t > originalStrides, ArrayRef< int64_t > candidateStrides, llvm::SmallBitVector unusedDims)
Returns the set of source dimensions that are dropped in a rank reduction.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurrences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByPosition(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Returns the set of source dimensions that are dropped in a rank reduction.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
type_range getType() const
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.