#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

  return arith::ConstantOp::materialize(builder, value, type, loc);

    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());

  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
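// Illustrative mapping (example types assumed, not from this file): a ranked
// memref<4x?xf32> becomes tensor<4x?xf32>, and an unranked memref<*xf32>
// becomes tensor<*xf32>; shape and element type carry over, only the buffer
// semantics change.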
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  auto memrefType = llvm::cast<MemRefType>(value.getType());
  for (int64_t i = 0; i < memrefType.getRank(); ++i)

  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    if (ShapedType::isStatic(cstVal)) {

static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>

  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);

  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))

  return std::make_tuple(castOp, *tgtTy, *srcTy);
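// Sketch of the intended use (assumed from the code above, not verbatim
// documentation): given
//   %p = memref.memory_space_cast %src : memref<?xf32, 1> to memref<?xf32>
// a consumer of %p can be cloned to operate on %src directly, with the cast
// re-created on the result, so memory-space casts "bubble down" through ops.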
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>

  llvm::append_range(operands, op->getOperands());

  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(

  return std::optional<SmallVector<Value>>(

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

LogicalResult AllocaOp::verify() {

      "requires an ancestor op with AutomaticAllocationScope trait");
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();

    auto memrefType = alloc.getType();

    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);

      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        newShapeConstants.push_back(constSizeArg.getZExtValue());

        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);

    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
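/// Erase alloc-like ops whose result is only ever deallocated or stored into
/// (the allocated buffer is never read). Illustrative sketch:
///
///   %0 = memref.alloc() : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
///
/// both ops can be removed.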
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))

  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);

  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
  bool printBlockTerminators = false;

  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;

                printBlockTerminators);

  result.regions.reserve(1);

  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);
void AllocaScopeOp::getSuccessorRegions(

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

  if (auto effect =
          interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);

  if (auto effect =
          interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
    if (isa<SideEffects::AutomaticAllocationScopeResource>(
            effect->getResource()))

    bool hasPotentialAlloca =

    if (hasPotentialAlloca) {

    if (!lastParentWithoutScope ||

      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;

    assert(containingRegion && "op must be contained in a region");

      return containingRegion->isAncestor(v.getParentRegion());

      toHoist.push_back(alloc);

    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}
OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();

  if (source.getAlignment() != getAlignment())

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {

FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,

  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
  return getMixedSize(builder, getLoc(), getMemref(), dim);

LogicalResult DistinctObjectsOp::verify() {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");

  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

LogicalResult DistinctObjectsOp::inferReturnTypes(

  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));

  setNameFn(getResult(), "cast");
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  if (!sourceType || !resultType)
    return false;

  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  if (sourceType.getRank() != resultType.getRank())
    return false;

  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))
      return false;

  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (inputs == outputs)
    return true;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;

        if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
          continue;
        if (!checkCompatible(aStride, bStrides[index]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }

  auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
  auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}

FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
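/// Erase copies whose source and target are the same buffer, e.g. (sketch):
///   memref.copy %buf, %buf : memref<4xf32> to memref<4xf32>
/// is a no-op.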
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
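/// Erase copies where the source or target memref provably holds zero
/// elements (some static dimension is 0), so the copy moves no data.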
struct FoldEmptyCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {

  results.add<FoldEmptyCopy, FoldSelfCopy>(context);

  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  auto indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)

  if (rankedSourceType.getRank() <= constantIndex)

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,

  setResultRange(getResult(),

static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}
static FailureOr<llvm::SmallBitVector>

                               MemRefType reducedType,

  int64_t rankReduction = originalType.getRank() - reducedType.getRank();
  if (rankReduction <= 0)
    return llvm::SmallBitVector(originalType.getRank());

  for (const auto &it : llvm::enumerate(sizes)) {
    if (auto cst = getConstantIntValue(it.value()))
      sourceSizes[it.index()] = *cst;
    else
      sourceSizes[it.index()] = ShapedType::kDynamic;
  }

  llvm::SmallBitVector usedSourceDims(originalType.getRank());

  for (int64_t resultSize : resultSizes) {
    bool matched = false;
    for (int64_t j = startJ; j < originalType.getRank(); ++j) {
      if (sourceSizes[j] == resultSize) {
        usedSourceDims.set(j);

  llvm::SmallBitVector unusedDims(originalType.getRank());
  for (int64_t i = 0; i < originalType.getRank(); ++i)
    if (!usedSourceDims.test(i))

    MemRefType originalType, MemRefType reducedType,
    llvm::SmallBitVector unusedDims) {

  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {

      currUnaccountedStrides[originalStride]--;

    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {

      unusedDims.reset(dim);

    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
      originalType.getRank())

static FailureOr<llvm::SmallBitVector>

  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())

  int64_t originalOffset, candidateOffset;

  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))

  auto hasNonTrivialStaticStride = [](ArrayRef<int64_t> strides) {
    if (strides.size() <= 1)
      return false;
    return llvm::any_of(strides.drop_back(),
                        [](int64_t s) { return !ShapedType::isDynamic(s); });
  };
  if (hasNonTrivialStaticStride(originalStrides) ||
      hasNonTrivialStaticStride(candidateStrides)) {
    FailureOr<llvm::SmallBitVector> strideBased =

                                       candidateStrides, unusedDims);
    if (succeeded(strideBased))
      return *strideBased;

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =

  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())

  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  unsigned unsignedIndex = index.getValue().getZExtValue();

  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {

    unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex);
    unsigned dynamicIdx = 0;
    for (OpFoldResult size : subview.getMixedSizes()) {
      if (llvm::isa<Attribute>(size))
        continue;
      if (dynamicIdx == dynamicResultDimIdx)
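/// Fold a dim of a memref.reshape into a load from the reshape's shape
/// operand. Illustrative sketch (assumed IR, not from this file):
///
///   %r = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %i : memref<?x?xf32>
///
/// becomes
///
///   %d = memref.load %shape[%i] : memref<2xindex>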
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");

    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {

      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");

    Location loc = dim.getLoc();
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}
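// Illustrative textual form of a (strided) dma_start, assumed for context:
//   memref.dma_start %src[%i], %dst[%j], %num, %tag[%k], %stride, %per_stride
//       : memref<...>, memref<...>, memref<1xi32>
// The operand order built above mirrors this: source + indices, destination +
// indices, number of elements, tag + indices, then the optional stride pair.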
void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (types.size() != 3)
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  if (!getStride().getType().isIndex() ||
      !getNumElementsPerStride().getType().isIndex())
    return emitOpError(
        "expected stride and num elements per stride to be of type index");
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {

void DmaStartOp::setMemrefsAndIndices(RewriterBase &rewriter, Value newSrc,

  SmallVector<Value> newOperands;
  newOperands.push_back(newSrc);
  llvm::append_range(newOperands, newSrcIndices);
  newOperands.push_back(newDst);
  llvm::append_range(newOperands, newDstIndices);

  newOperands.push_back(getTagMemRef());
  llvm::append_range(newOperands, getTagIndices());

  newOperands.push_back(getStride());
  newOperands.push_back(getNumElementsPerStride());

  rewriter.modifyOpInPlace(*this, [&]() { (*this)->setOperands(newOperands); });

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {

LogicalResult DmaWaitOp::verify() {

  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());

  inferredReturnTypes.push_back(memrefType);

  inferredReturnTypes.push_back(indexType);

  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");

  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {

    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");

        llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());

    atLeastOneReplacement = true;

  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;
    }

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {

  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {

  SmallVector<OpFoldResult> values(1, offsetOfr);
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
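// Illustrative textual form of generic_atomic_rmw, assumed for context:
//   %x = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current: f32):
//     %new = arith.addf %current, %cst : f32
//     memref.atomic_yield %new : f32
//   }
// The region takes the current value as its single block argument and must
// yield a value of the same type, which the verifier below enforces.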
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  body.walk([&](Operation *nestedOp) {

          "body of 'memref.generic_atomic_rmw' should contain "
          "only operations with no side effects");

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand memref;

  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;

  Region *body = result.addRegion();

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';

std::optional<SmallVector<Value>> GenericAtomicRMWOp::updateMemrefAndIndices(
    RewriterBase &rewriter, Value newMemref, ValueRange newIndices) {

  getMemrefMutable().assign(newMemref);
  getIndicesMutable().assign(newIndices);

  return std::nullopt;
}

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  if (!op.isExternal()) {

    if (op.isUninitialized())
      p << "uninitialized";

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

    initialValue = UnitAttr::get(parser.getContext());

  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";

LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {

      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
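// Illustrative pairing of global and get_global, assumed for context:
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
//   %0 = memref.get_global @cst : memref<2xf32>
// get_global must reference a global whose type matches its result type,
// which the symbol-use verifier below checks.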
LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)

           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {

  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();

      getGlobalOp, getGlobalOp.getNameAttr());

  auto splatAttr =
      dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());

  return splatAttr.getSplatValue<Attribute>();
}

std::optional<SmallVector<Value>>
LoadOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,

  getMemrefMutable().assign(newMemref);
  getIndicesMutable().assign(newIndices);

  return std::nullopt;
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;

    return uaT.getElementType() == ubT.getElementType();

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {

  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());

bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                               PtrLikeTypeInterface src) {
  return isa<BaseMemRefType>(tgt) &&
         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
}

MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
    OpBuilder &b, PtrLikeTypeInterface tgt,

  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}

bool MemorySpaceCastOp::isSourcePromotable() {
  return getDest().getType().getMemorySpace() == nullptr;
}
void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});

ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand memrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
  IntegerAttr localityHint;

  StringRef readOrWrite, cacheType;

  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),

LogicalResult PrefetchOp::verify() {

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {

std::optional<SmallVector<Value>>
PrefetchOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,

  getMemrefMutable().assign(newMemref);
  getIndicesMutable().assign(newIndices);

  return std::nullopt;
}
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              Value source, OpFoldResult offset,
                              ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
      sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
  SmallVector<OpFoldResult> strideValues =
      llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      });
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ArrayRef<Value> sizes, ArrayRef<Value> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
      strides, [](Value v) -> OpFoldResult { return v; });
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
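// Illustrative textual form, assumed for context:
//   %m = memref.reinterpret_cast %src to offset: [0], sizes: [%sz],
//       strides: [1] : memref<?xf32> to memref<?xf32, strided<[1]>>
// The verifier below checks that any static sizes, strides, and offset in the
// result type agree with the corresponding operands/attributes.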
LogicalResult ReinterpretCastOp::verify() {

  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;

  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }

OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {

      return prev.getSource();

      return prev.getSource();

      return prev.getSource();

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getMixedStrides();
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
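/// Replace reinterpret_cast(extract_strided_metadata(x)) with x (or a cast of
/// x) when the reinterpret_cast provably reconstructs the exact offset,
/// sizes, and strides of x, making the round-trip a no-op.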
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {

  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    auto isReinterpretCastNoop = [&]() -> bool {

      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {

      op.getSourceMutable().assign(extractStridedMetadata.getSource());

    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getResult().getType(),
                                          extractStridedMetadata.getSource());
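/// Constant-fold the offset, sizes, and strides of a reinterpret_cast when
/// more static information can be inferred from its result type than is
/// present in its operands, rewriting to a new op with the constified values.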
struct ReinterpretCastOpConstantFolder
    : public OpRewritePattern<ReinterpretCastOp> {

  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    unsigned srcStaticCount = llvm::count_if(
        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                   op.getMixedStrides()),
        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });

    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();

      offsets[0] = op.getMixedOffsets()[0];

    for (auto it : llvm::zip(op.getMixedSizes(), sizes)) {
      auto &srcSizeOfr = std::get<0>(it);
      auto &sizeOfr = std::get<1>(it);

        sizeOfr = srcSizeOfr;

    if (srcStaticCount ==
        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
      return failure();

    auto newReinterpretCast = ReinterpretCastOp::create(
        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);

void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
              ReinterpretCastOpConstantFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}
LogicalResult ExpandShapeOp::reifyResultShapes(

  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};

    bool allowMultipleDynamicDimsPerGroup) {

  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  for (const auto &it : llvm::enumerate(reassociation)) {

    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))

               << expandedDim << " is out of bounds";

      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)

              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;

    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)

           << ") must be dynamic if and only if reassociation group is "

    if (!foundDynamic) {

      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])

               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";

  if (collapsedShape.empty()) {

    for (int64_t d : expandedShape)

            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {

           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {

      getReassociationIndices());

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {

      getReassociationIndices());
static FailureOr<StridedLayoutAttr>

  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {

    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =

  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {

    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  FailureOr<StridedLayoutAttr> computedLayout =

  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =

  return *outputShape;
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =

  build(builder, result, llvm::cast<MemRefType>(resultType), src,

        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =

  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);

  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {

  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);

  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {

  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);

  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
}
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

      resultType.getShape(),
      getReassociationIndices(),

  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()

  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;

  auto cast = op.getSrc().getDefiningOp<CastOp>();

  if (!CastOp::canFoldIntoConsumerOp(cast))
    return failure();

  for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {

    if (!sizeOpt.has_value()) {
      newOutputShapeSizes.push_back(ShapedType::kDynamic);

    newOutputShapeSizes.push_back(sizeOpt.value());
    newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());

  Value castSource = cast.getSource();
  auto castSourceType = llvm::cast<MemRefType>(castSource.getType());

      op.getReassociationIndices();
  for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
    auto newOutputShapeSizesSlice =
        ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
    bool newOutputDynamic =
        llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
    if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
      return rewriter.notifyMatchFailure(
          op, "folding cast will result in changing dynamicity in "
              "reassociation group");
  }

  FailureOr<MemRefType> newResultTypeOrFailure =
      ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
                                         reassociationIndices);

  if (failed(newResultTypeOrFailure))
    return rewriter.notifyMatchFailure(
        op, "could not compute new expanded type after folding cast");

  if (*newResultTypeOrFailure == op.getResultType()) {
    rewriter.modifyOpInPlace(
        op, [&]() { op.getSrcMutable().assign(castSource); });
  } else {
    Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
                                        *newResultTypeOrFailure, castSource,
                                        reassociationIndices, newOutputShape);
  }

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
      ExpandShapeOpMemRefCastFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
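/// Compute the strided layout of the collapsed type. For each reassociation
/// group the collapsed stride is the stride of the innermost (last) dimension
/// of that group; the remaining dimensions of the group must be contiguous
/// for the collapse to be valid, which the checks below verify.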
static FailureOr<StridedLayoutAttr>

                          bool strict = false) {

  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  resultStrides.reserve(reassociation.size());

    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);

      resultStrides.push_back(ShapedType::kDynamic);

  unsigned resultStrideIndex = resultStrides.size() - 1;

    for (int64_t idx : llvm::reverse(trailingReassocs)) {

      if (srcShape[idx - 1] == 1)

      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();

  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {

  if (srcType.getLayout().isIdentity())
    return true;

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());

    for (int64_t srcDim : group)

    resultShape.push_back(groupSize.asInteger());

  if (srcType.getLayout().isIdentity()) {

    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  FailureOr<StridedLayoutAttr> computedLayout =

  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);

  build(b, result, resultType, src, attrs);
}

LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

      srcType.getShape(), getReassociationIndices(),

  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {

    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {

    FailureOr<StridedLayoutAttr> computedLayout =

    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  auto cast = op.getOperand().getDefiningOp<CastOp>();

  if (!CastOp::canFoldIntoConsumerOp(cast))
    return failure();

  Type newResultType = CollapseShapeOp::computeCollapsedType(
      llvm::cast<MemRefType>(cast.getOperand().getType()),
      op.getReassociationIndices());

  if (newResultType == op.getResultType()) {
    rewriter.modifyOpInPlace(
        op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
  } else {
    auto newOp =
        CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
                                op.getReassociationIndices());
  }

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                memref::DimOp, MemRefType>,
      CollapseShapeOpMemRefCastFolder>(context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {

      adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {

      adaptor.getOperands());

FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2997void ReshapeOp::getAsmResultNames(
2999 setNameFn(getResult(),
"reshape");
LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

std::optional<SmallVector<Value>>
StoreOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
                                ValueRange newIndices) {
  getMemrefMutable().assign(newMemref);
  getIndicesMutable().assign(newIndices);
  return std::nullopt;
}

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(), {});
}
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}
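
// Worked example (illustrative, not from the original source): for a source
// memref<16x16xf32> (strides [16, 1], offset 0) with offsets [4, 8],
// sizes [8, 2], and strides [1, 1]:
//   targetOffset  = 0 + 4 * 16 + 8 * 1 = 72
//   targetStrides = [16 * 1, 1 * 1] = [16, 1]
// so the inferred result type is:
//   memref<8x2xf32, strided<[16, 1], offset: 72>>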
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  MemRefType inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
  assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected result rank to be no larger than the inferred rank");
  if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
    return inferredType;

  // Compute which dimensions are dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
      computeRankReductionMask(inferredType.getShape(), resultShape);
  assert(dimsToProject.has_value() && "invalid rank reduction");

  // Compute the layout and result type.
  auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
  SmallVector<int64_t> rankReducedStrides;
  rankReducedStrides.reserve(resultShape.size());
  for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
    if (!dimsToProject->contains(idx))
      rankReducedStrides.push_back(value);
  }
  return MemRefType::get(resultShape, inferredType.getElementType(),
                         StridedLayoutAttr::get(inferredLayout.getContext(),
                                                inferredLayout.getOffset(),
                                                rankReducedStrides),
                         inferredType.getMemorySpace());
}
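
// Illustrative example (not from the original source): rank-reducing a unit
// dimension. For a source memref<8x16x4xf32> with offsets [0, 0, 0], sizes
// [1, 16, 4], and strides [1, 1, 1], requesting resultShape [16, 4] drops the
// leading unit dimension and its stride:
//
//   %0 = memref.subview %arg[0, 0, 0] [1, 16, 4] [1, 1, 1]
//       : memref<8x16x4xf32> to memref<16x4xf32, strided<[4, 1]>>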
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
/// Build a SubViewOp with mixed static and dynamic entries and custom result
/// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
/// Build a SubViewOp with mixed static and dynamic entries and inferred
/// result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
/// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues =
      llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      });
  SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
      sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
  SmallVector<OpFoldResult> strideValues =
      llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      });
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}
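
// Usage sketch (illustrative, not from the original source): the static-entry
// builder wraps each int64_t in an IntegerAttr OpFoldResult, so no SSA values
// are created for the offsets/sizes/strides:
//
//   OpBuilder b(ctx);
//   auto sv = memref::SubViewOp::create(b, loc, source, /*offsets=*/{0, 4},
//                                       /*sizes=*/{4, 4}, /*strides=*/{1, 1});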
/// Build a SubViewOp with static entries and custom result type. If the type
/// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues =
      llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      });
  SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
      sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
  SmallVector<OpFoldResult> strideValues =
      llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      });
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}
/// Build a SubViewOp with dynamic entries and custom result type. If the type
/// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
      offsets, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> sizeValues =
      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
      strides, [](Value v) -> OpFoldResult { return v; });
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return getSource(); }

/// Return true if `t1` and `t2` have equal offsets (both dynamic or of the
/// same static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}
/// Return true if `t1` and `t2` have equal strides (both dynamic or of the
/// same static value). Dimensions of `t1` that are dropped in `t2` are
/// ignored.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            SubViewOp op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes), but got "
           << op.getType();
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError(
               "expected result and source memory spaces to match, but got ")
           << op.getType();
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout), but "
              "got "
           << op.getType();
  }
  llvm_unreachable("unexpected subview verification result");
}
/// Verifier for SubViewOp.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  ArrayRef<int64_t> staticStrides = getStaticStrides();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify result type against inferred type.
  SliceVerificationResult result =
      isRankReducedType(expectedType, subViewType);
  if (result != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(result, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // The only thing left to verify now are the strides. First, compute the
  // unused dimensions due to rank reductions.
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(expectedType, subViewType,
                                     getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Strides must match.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out-of-bounds with
  // respect to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
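
// Illustrative example (not from the original source): the bounds check above
// rejects slices that run past the base memref, e.g.
//
//   %0 = memref.subview %arg[8, 0] [16, 4] [1, 1]
//       : memref<16x4xf32> to memref<16x4xf32, strided<[4, 1], offset: 32>>
//
// is invalid because offset 8 plus size 16 exceeds the base dimension of 16.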
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
/// Return the list of Range (i.e. offset, size and stride). Each Range entry
/// contains either the dynamic value or a ConstantIndexOp constructed with
/// `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`. Use this signature if `sourceType` is updated
/// together with the result type. In this case, it is important to compute
/// the dropped dimensions using `currentSourceType`, whose strides align with
/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}

FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}
/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are one, and the sizes
/// match the static source shape. In such cases, the subview can be folded
/// into its source.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}
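
// Illustrative example (not from the original source): a subview that
// isTrivialSubViewOp classifies as a no-op, since all offsets are 0, all
// strides are 1, and the sizes match the static source shape:
//
//   %0 = memref.subview %arg[0, 0] [16, 16] [1, 1]
//       : memref<16x16xf32> to memref<16x16xf32>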
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments. This
/// essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If the subview has any constant operand, return to let the constant
    // argument folding pattern kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = SubViewOp::create(
        rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
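
// Illustrative example (not from the original source): the pattern above
// pushes a memref.cast past its consuming subview,
//
//   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
//   %1 = memref.subview %0[0, 0] [3, 4] [1, 1]
//       : memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
//
// becoming a subview of %V followed by a cast back to the original result
// type.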
/// Canonicalize subview ops that are no-ops, either by replacing them with
/// their source directly or via a plain cast of the source.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace
/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    MemRefType nonReducedType = resTy;

    // Directly return the non-rank-reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank-reduced type.
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape, targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getResult().getType(), newOp);
  }
};
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)) to subview(x) if the inner subview is a no-op:
  // all offsets are zero, all strides are one, and the sizes match the source
  // sizes.
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
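
// Illustrative example (not from the original source) of the second fold
// above: a subview with zero offsets, unit strides, and sizes identical to
// its producing subview is a no-op, e.g.
//
//   %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
//       : memref<4x4xf32, strided<[16, 1]>> to memref<4x4xf32, strided<[16, 1]>>
//
// folds to %0 when the source and result types match.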
FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
void SubViewOp::inferStridedMetadataRanges(
    ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
    SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
  auto isUninitialized =
      +[](IntegerValueRange range) { return range.isUninitialized(); };

  // Bail early if any of the operand ranges are not yet known.
  SmallVector<IntegerValueRange> offsetOperands =
      getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
  if (llvm::any_of(offsetOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> sizeOperands =
      getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
  if (llvm::any_of(sizeOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> stridesOperands =
      getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
  if (llvm::any_of(stridesOperands, isUninitialized))
    return;

  StridedMetadataRange sourceRange =
      ranges[getSourceMutable().getOperandNumber()];
  if (sourceRange.isUninitialized())
    return;

  ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();

  // Get the dropped dims.
  llvm::SmallBitVector droppedDims = getDroppedDims();

  // Compute the new offset, sizes and strides.
  ConstantIntRanges offset = sourceRange.getOffsets()[0];
  SmallVector<ConstantIntRanges> strides, sizes;

  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
    bool dropped = droppedDims.test(i);
    // Fold the i-th offset contribution into the running offset.
    ConstantIntRanges off =
        intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
    offset = intrange::inferAdd({offset, off});

    // Skip dropped dimensions; otherwise compose strides and collect sizes.
    if (dropped)
      continue;
    strides.push_back(
        intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
    sizes.push_back(sizeOperands[i].getValue());
  }

  setMetadata(getResult(),
              StridedMetadataRange::getRanked(
                  SmallVector<ConstantIntRanges>({std::move(offset)}),
                  std::move(sizes), std::move(strides)));
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}
// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // First check for identity permutation; we can fold it away if the input
  // and result types are already identical.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose ops into one by composing their
  // permutation maps.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
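
// Illustrative example (not from the original source): two consecutive
// transposes compose into one. With the permutation (d0, d1) -> (d1, d0)
// applied twice, the composed permutation is the identity, so
//
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0) : ... to ...
//   %2 = memref.transpose %1 (d0, d1) -> (d1, d0) : ... to ...
//
// folds to %0 (given matching types).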
FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}
Value ViewOp::getViewSource() { return getSource(); }

OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
    return getViewSource();

  return {};
}

SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  Builder b(getContext());
  for (int64_t dim : getType().getShape()) {
    if (ShapedType::isDynamic(dim)) {
      result.push_back(getSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(dim));
    }
  }
  return result;
}
/// Given a memref type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static MemRefType
foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
                            SmallVectorImpl<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
    if (ShapedType::isStatic(dimSize))
      continue;
    Value dynamicSize = dynamicSizes[ctr++];
    std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
    if (cst.has_value()) {
      // Dynamic size must be non-negative.
      if (cst.value() < 0) {
        foldedDynamicSizes.push_back(dynamicSize);
        continue;
      }
      staticShape[dim] = cst.value();
    } else {
      foldedDynamicSizes.push_back(dynamicSize);
    }
  }

  return MemRefType::Builder(type).setShape(staticShape);
}
namespace {
/// Fold constant dimension sizes of a memref.view into a more static result
/// type, inserting a cast back to the original type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    MemRefType resultType = viewOp.getType();
    MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
        resultType, viewOp.getSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedMemRefType == resultType)
      return failure();

    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
                                    viewOp.getSource(), viewOp.getByteShift(),
                                    foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
    return success();
  }
};

/// Fold a memref.cast on the source of a memref.view into the view itself.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(
        viewOp, viewOp.getType(), memrefCastOp.getSource(),
        viewOp.getByteShift(), viewOp.getSizes());
    return success();
  }
};
} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicRMWOp::verify() {
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maximumf:
  case arith::AtomicRMWKind::minimumf:
  case arith::AtomicRMWKind::mulf:
    if (!llvm::isa<FloatType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::xori:
  case arith::AtomicRMWKind::andi:
    if (!llvm::isa<IntegerType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}
OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}

FailureOr<std::optional<SmallVector<Value>>>
AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(), getResult());
}

std::optional<SmallVector<Value>>
AtomicRMWOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
                                    ValueRange newIndices) {
  getMemrefMutable().assign(newMemref);
  getIndicesMutable().assign(newIndices);
  return std::nullopt;
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"