#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               value.cast<ArrayAttr>());
  auto tensorType = value.getType().cast<RankedTensorType>();
  // ...
  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
    if (tensorType.isDynamicDim(i)) {
      Value size = builder.create<tensor::DimOp>(loc, value, i);
      result.push_back(size);
    } else {
      result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
    }
  }
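  // Added illustration (not in the original listing): for a value of type
  // tensor<?x4xf32>, the loop above produces the mixed size list
  // {%dim, IndexAttr(4)}, where %dim is a freshly built tensor.dim on
  // dimension 0. A hypothetical caller-side sketch:
  //
  //   SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, v);
  //   // Static extents arrive as attributes, dynamic ones as SSA values.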
  // ...
  assert(tensorType && "expected tensor type");
  // ...
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  // ...
  return destOp.getTiedOpOperand(opResult)->get();

  // ...
  if (!tensorType.hasStaticShape()) {
    // ...
  } else {
    for (int64_t sz : tensorType.getShape())
      // ...
  }
  // ...
  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
  // ...
  result.push_back(*destination);
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = 0;

  for (const auto &size : enumerate(mixedSizes)) {
    // Only a static unit size can correspond to a dropped dimension.
    bool isStaticUnitSize =
        // ...
        size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;

    if (shapePos == static_cast<int64_t>(reducedShape.size())) {
      // The reduced shape is fully consumed; all remaining sizes must be
      // dropped unit dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(size.index());
      continue;
    }

    if (!isStaticUnitSize) {
      // ...
      continue;
    }

    if (reducedShape[shapePos] == 1) {
      // ...
      continue;
    }

    droppedDims.set(size.index());
  }

  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
         "dimension mismatch");
  setNameFn(getResult(), "cast");
  auto sourceType = source.dyn_cast<RankedTensorType>();
  auto targetType = target.dyn_cast<RankedTensorType>();

  // ...
  if (!sourceType || !targetType)
    // ...

  // ...
  if (sourceType.getElementType() != targetType.getElementType())
    // ...

  // ...
  if (sourceType.getRank() != targetType.getRank())
    // ...

  // ...
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      // ...
  }
  // ...
                                    castOp.getSource().getType());
    // ...
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    // ...
      operand.set(castOp.getOperand());
  // ...
  if (inputs.size() != 1 || outputs.size() != 1)
    // ...
  Type a = inputs.front(), b = outputs.front();
  int64_t rank = one.getRank();
  if (rank != two.getRank())
    // ...

  // ...
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
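  // Added illustration: joinShapes combines static knowledge per dimension,
  // e.g. joining tensor<?x4xf32> with tensor<2x?xf32> yields tensor<2x4xf32>,
  // while a static mismatch such as 2 vs. 3 in the same position produces no
  // join at all.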
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      // ...

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    // ...

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      // ...

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
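    // Added example of the fold performed above (schematic IR):
    //
    //   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
    //   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x4xf32>
    //
    // The intermediate cast adds no shape knowledge, so the chain collapses
    // into a single cast %0 : tensor<4x4xf32> to tensor<?x4xf32>.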
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
    // ...
        tensorCast.getType().getShape() == tensorCast.getSource()
                                               .getType()
                                               .cast<RankedTensorType>()
                                               .getShape())
      // ...

    // ...
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    // ...
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
        extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
// ...
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
  setNameFn(getResult(), "dim");
// ...
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
// ...
std::optional<int64_t> DimOp::getConstantIndex() {
  // ...
}
// ...
  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    // ...
  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
  // ...

  // Folding requires a ranked source with a known, in-range index.
  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
  // ...
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    // ...

  // A static dimension folds directly to its size.
  if (!tensorType.isDynamicDim(index.getInt())) {
    // ...
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The static case was handled above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the generate op that corresponds to this dim.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // ...
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // ...
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }
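  // Added example: with %c1 = arith.constant 1 : index,
  //   %d = tensor.dim %t, %c1 : tensor<?x8xf32>
  // folds to the constant index 8 via getIndexAttr(8) above, while a query on
  // the dynamic dimension 0 is left to the defining-op cases that follow.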
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    // ...
    Value newSource = castOp.getOperand();
// ...
  results.add<DimOfCastOp>(context);
  assert(all_of(staticShape,
                [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
// ...
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
// ...
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
// ...
  if (getType().getNumDynamicDims() !=
      // ...
    return emitOpError("incorrect number of dynamic sizes, has ")
           // ...
           << getType().getNumDynamicDims();
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      // ...
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
// ...
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  // ...
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      // ...
}
// ...
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      // ...
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
    // ...
                                     op.getType().getShape().end());
    // ...
    bool changedType = false;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (op.getType().isDynamicDim(i)) {
        Value dynamicSize = op.getDynamicSizes()[ctr++];
        // ...
        if (cst.has_value()) {
          staticShape[i] = *cst;
          // ...
        } else {
          dynamicSizes.push_back(dynamicSize);
        }
      }
    }
    // ...
    auto tensorType = RankedTensorType::get(
        staticShape, op.getType().getElementType(), op.getType().getEncoding());
    // ...
    rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      // ...
    if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      // ...
                emptyTensorOp.getDynamicSize(*maybeConstantIndex));
// ...
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    // ...
    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
    // ...
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      // ...
      if (auto attr = currDim.dyn_cast<Attribute>()) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != attr.cast<IntegerAttr>().getInt()) {
          // ...
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // ...
      if (!ShapedType::isDynamic(newDim)) {
        // ...
      }

      // ...
      newMixedSizes.push_back(currDim);
    }
// ...
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    // ...
    if (!tensorCast.getSource().getType().isa<RankedTensorType>())
      // ...
        extract, tensorCast.getSource(), extract.getIndices());
// ...
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}
// ...
  auto tensorType = getTensor().getType().cast<RankedTensorType>();
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
// ...
  if (Attribute tensor = adaptor.getTensor())
    // ...
      return splatTensor.getSplatValue<Attribute>();

  // ...
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !indice.isa<IntegerAttr>())
      // ...
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }
  // Fold extract(from_elements(...)) by flattening the indices into a
  // row-major linear index into the element list.
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           // ...
    // ...
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses (possible in invalid, never-executed IR).
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        // ...
    return fromElementsOp.getElements()[flatIndex];
  }
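  // Added arithmetic check for the flattening above: for tensor<2x3xf32> and
  // indices (1, 2), the loop computes flatIndex = 2*1 + 1*3 = 5, i.e. the
  // standard row-major position index0*3 + index1.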
  // ...
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }
// ...
  results.add<ExtractFromTensorCast>(context);
// ...
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
// ...
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    // ...
}
// ...
struct ExtractElementFromIndexCast
    // ...
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    // ...
    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());
// ...
  results.add<ExtractElementFromIndexCast>(context);
// ...
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           // ...
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
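  // Added illustration: for a source tensor<4x5x6xf32> and gather_dims = [1],
  // the trailing slice shape produced by this loop is 4x1x6; the rank-reduced
  // variant drops the gathered dimension entirely, giving 4x6. The leading
  // result dims (not shown here) come from the batch shape of the indices.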
                                              StringRef gatherOrScatter,
                                              StringRef sourceOrDest) {
  // ...
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    // ...
        << "_dims overflow " << sourceOrDest << " rank";
  for (int64_t val : dims) {
    // ...
        << "_dims value must be non-negative";
    // ...
        << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      // ...
          << "_dims values must be strictly increasing";
  }
// ...
  int64_t sourceRank = getSourceType().getRank();
  // ...
                                     "gather", "source")))
    // ...

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       // ...
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           // ...
  }
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}
// ...
  auto destType = getDest().getType().cast<RankedTensorType>();
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
// ...
    if (scalar == splatDest.getSplatValue<Attribute>())
      // ...
// ...
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}
// ...
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          // ...
    }
  }
  RankedTensorType resultTy = getType().cast<RankedTensorType>();
  if (getNumOperands() != resultTy.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
// ...
  RankedTensorType resultTy = getType().cast<RankedTensorType>();
  // ...
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    // ...
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // ...
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    // ...
        "body must be terminated with a `yield` operation of the tensor "
        // ...
// ...
void GenerateOp::build(
    // ...
  build(b, result, resultTy, dynamicExtents);

  // ...
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  // ...
  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                // ...
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      // ...

    // ...
    auto operandsIt = tensorFromElements.getDynamicExtents().begin();
    for (int64_t dim : resultType.getShape()) {
      if (!ShapedType::isDynamic(dim)) {
        newShape.push_back(dim);
        continue;
      }
      // ...
        newShape.push_back(ShapedType::kDynamic);
        newOperands.push_back(*operandsIt++);
        continue;
      // ...
      newShape.push_back(index.getSExtValue());
      // ...
    }

    if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
      // ...

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        // ...
        newOp.getBody().begin());
// ...
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  // ...
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    // ...
    Block *body = &tensorFromElements.getBody().front();
    // ...
      rewriter.clone(op, mapping);
// ...
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}
// ...
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
// ...
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
// ...
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  // ...
      getShape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      // ...
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      // ...
          "length of shape operand differs from the result's tensor rank");
  }
// ...
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}
int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  // ...
    if (llvm::is_contained(it.value(), resultDim))
      // ...
  llvm_unreachable("could not find reassociation group");
}
// ...
      getReassociationIndices());
// ...
      getReassociationIndices());
// ...
RankedTensorType CollapseShapeOp::inferCollapsedType(
    // ...
  return inferCollapsedType(
      // ...
          type.getContext(), reassociation)));
}

// ...
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    // ...
  auto shape = type.getShape();
  // ...
  newShape.reserve(reassociation.size());
  // ...
  unsigned currentDim = 0;
  // ...
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    // A band collapses to a dynamic size if it contains any dynamic dim,
    // otherwise to the product of its static sizes.
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    // ...
  return RankedTensorType::get(newShape, type.getElementType());
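// Added illustration of inferCollapsedType: collapsing tensor<2x3x?x5xf32>
// with reassociation [[0, 1], [2, 3]] yields tensor<6x?xf32>; the first band
// is fully static (2*3 = 6), while the second contains a dynamic dim and thus
// collapses to ShapedType::kDynamic.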
  auto resultType = inferCollapsedType(
      // ...
  build(b, result, resultType, src, attrs);
// ...
  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    // ...
  }
// ...
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
// ...
                                   RankedTensorType expandedType,
                                   RankedTensorType collapsedType) {
  // ...
  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  // ...
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  auto srcType = getSrcType();
  auto resultType = getResultType();
  if (srcType.getRank() >= resultType.getRank())
    return emitOpError("expected rank expansion, but found source rank ")
           << srcType.getRank() << " >= result rank " << resultType.getRank();
// ...
  auto srcType = getSrcType();
  auto resultType = getResultType();
  if (srcType.getRank() <= resultType.getRank())
    return emitOpError("expected rank reduction, but found source rank ")
           << srcType.getRank() << " <= result rank " << resultType.getRank();
// ...
template <typename TensorReshapeOp>
// ...
        reshapeOp.getResultType(), attr.getRawData());
// ...
template <typename TensorReshapeOp>
// ...
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    // ...
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
// ...
template <typename TensorReshapeOp>
// ...
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    // ...
    auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
    // ...
    if (!shapedTy.hasStaticShape())
      // ...
        fromElements.getElements());
  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                // ...
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    // ...
    RankedTensorType srcType =
        castOp.getSource().getType().cast<RankedTensorType>();
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      // ...
      collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      // ...
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
      // ...
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
// ...
    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
    // ...
    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      // ...

    // ...
    TensorType resultType = expandShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      // ...

    // ...
    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);

    // Product of the static sizes of the other dims in *dim's reassociation
    // group.
    int64_t product = 1;
    // ...
    for (int64_t d : grp) {
      // ...
      assert(!resultType.isDynamicDim(d) && "expected static dim");
      product *= resultType.getDimSize(d);
    }

    // ...
    rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      // ...

    // ...
    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      // ...

    // ...
    TensorType resultType = collapseShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      // ...
        collapseShapeOp.getReassociationIndices()[*dim];
    // ...
      srcDimSizes.push_back(rewriter.create<DimOp>(
          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
      // ...
      product = product ? product * syms.back() : syms.back();
// ...
              FoldReshapeWithConstant<ExpandShapeOp>,
              FoldReshapeWithSplat<ExpandShapeOp>,
              FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
              FoldDimOfCollapseShape>(context);
// ...
              FoldReshapeWithConstant<CollapseShapeOp>,
              FoldReshapeWithSplat<CollapseShapeOp>,
              FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
// ...
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}
// ...
RankedTensorType ExtractSliceOp::inferResultType(
    // ...
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceShapedTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes,
                               sourceShapedTensorType.getElementType());
}

RankedTensorType ExtractSliceOp::inferResultType(
    // ...
  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}
// ...
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    // ...
      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
          .cast<RankedTensorType>();
  int rankDiff = inferredType.getRank() - desiredResultRank;
  // ...
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        // ...
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    // ...
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  // ...
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    // ...
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
                            RankedTensorType resultType, Value source,
    // ...
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  // ...
      ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                      staticSizes, staticStrides)
          .cast<RankedTensorType>();
  // ...
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        // ...
// ...
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// ...
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// ...
                            RankedTensorType resultType, Value source,
    // ...
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
// ...
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// ...
template <typename OpTy>
// ...
                                            OpTy op, Type expectedType) {
  auto memrefType = expectedType.cast<ShapedType>();
  // ...
    return op.emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  // ...
    return op.emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  // ...
    return op.emitError("expected element type to be ")
           << memrefType.getElementType();
  // ...
    llvm_unreachable("unexpected extract_slice op verification result");
// ...
  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    // ...
  auto maybeRankReductionMask =
      // ...
  if (!maybeRankReductionMask)
    // ...
// ...
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  // ...
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
  // ...
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      // ...

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    // ...
    auto sliceOpType = sliceOp.getType();
    RankedTensorType resultType =
        ExtractSliceOp::inferCanonicalRankReducedResultType(
            sliceOpType.getRank(), sliceOp.getSourceType(),
            sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
            sliceOp.getMixedStrides());
    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    if (newResult.getType() != sliceOpType)
      newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
1876 template <
typename IterTy,
typename ElemTy>
1881 assert(offsets.size() == sizes.size());
1882 assert(offsets.size() == strides.size());
1883 if (offsets.empty())
1886 int64_t offset = offsets.front();
1887 int64_t size = sizes.front();
1888 int64_t stride = strides.front();
1889 if (offsets.size() == 1) {
1890 for (int64_t i = 0; i < size; ++i, offset += stride)
1891 outValues->push_back(*(values + offset));
1896 for (int64_t i = 0; i < size; ++i, offset += stride) {
1897 auto begin = values + offset * counts.front();
1898 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1899 offsets.drop_front(), sizes.drop_front(),
1900 strides.drop_front(), outValues);
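// Added worked trace (not in the original): for a 2x3 source laid out as
// [0, 1, 2, 3, 4, 5] with counts = [3, 1], offsets = [0, 1], sizes = [2, 2],
// strides = [1, 1], the recursion visits rows 0 and 1 (begin = values + 0*3
// and values + 1*3) and the leaf loop copies elements at offsets 1 and 2 of
// each row, producing [1, 2, 4, 5].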
class ConstantOpExtractSliceFolder final
    // ...
  ConstantOpExtractSliceFolder(MLIRContext *context,
                               // ...
        controlFn(std::move(controlFn)) {}
// ...
    auto sourceType = op.getSource().getType().cast<ShapedType>();
    auto resultType = op.getResult().getType().cast<ShapedType>();
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      // ...

    // ...
    int64_t count = sourceType.getNumElements();
    // ...

    // All offsets, sizes, and strides must be static.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      // ...
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      // ...
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      // ...

    // ...
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      // ...
      counts.push_back(count);
    }
    // ...
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    // ...
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
// ...
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
// ...
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
      // ...
// ...
                         ExtractSliceOp newOp) {
  Value replacement = newOp.getResult();
  if (replacement.getType() != op.getType())
    replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                  // ...
// ...
              ExtractSliceOpCastFolder>(context);
                                 ShapedType shapedType) {
  // ...
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    // ...
// ...
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
  // ...
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();
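// Added example of the extract-after-insert fold above (schematic IR):
//   %i = tensor.insert_slice %s into %t[0, 0] [2, 2] [1, 1]
//        : tensor<2x2xf32> into tensor<4x4xf32>
//   %e = tensor.extract_slice %i[0, 0] [2, 2] [1, 1]
//        : tensor<4x4xf32> to tensor<2x2xf32>
// Since both ops use identical offsets/sizes/strides and matching types,
// %e folds directly to %s.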
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  // ...
    auto resultType = getResult().getType().cast<ShapedType>();
    if (resultType.hasStaticShape())
      return splat.resizeSplat(resultType);
  // ...
  if (getSourceType() == getType() &&
      // ...
    return this->getSource();
  // ...
}
// ...
  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
  unsigned rank = rankedTensorType.getRank();
  // ...
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
// ...
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        // ...
// ...
  build(b, result, source, dest, offsets, sizes, strides, attrs);
// ...
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
// ...
                     ShapedType *expectedType = nullptr) {
  // ...
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  // ...
    *expectedType = expected;
// ...
  ShapedType expectedType;
  // ...
      getStaticSizes(), getStaticStrides(), &expectedType);
// ...
    auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
    // ...
    if (!prevInsertOp ||
        prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
        !prevInsertOp.isSameAs(insertOp, isSame))
      // ...
    insertOp.getDestMutable().assign(prevInsertOp.getDest());
// ...
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
  // ...
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    // ...
  return extractOp.getSource();
// ...
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      // ...
    return this->getSource();
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] =
          builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
    } else {
      reifiedReturnShapes[0][dim] =
          // ...
    }
  }
// ...
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    // ...
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      // ...
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        // ...
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    // ...
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
2305 template <
typename InsertOpTy>
2306 struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
2311 if (llvm::any_of(insertSliceOp.getOperands(), [](
Value operand) {
2312 return matchPattern(operand, matchConstantIndex());
2316 auto getSourceOfCastOp = [](
Value v) -> std::optional<Value> {
2317 auto castOp = v.getDefiningOp<tensor::CastOp>();
2319 return std::nullopt;
2320 return castOp.getSource();
2322 std::optional<Value> sourceCastSource =
2323 getSourceOfCastOp(insertSliceOp.getSource());
2324 std::optional<Value> destCastSource =
2325 getSourceOfCastOp(insertSliceOp.getDest());
2326 if (!sourceCastSource && !destCastSource)
2330 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2331 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2332 auto srcType = src.
getType().template cast<ShapedType>();
2333 auto dstType = dst.getType().template cast<ShapedType>();
2335 insertSliceOp.getStaticSizes(),
2336 insertSliceOp.getStaticStrides()) !=
2341 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2342 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2345 bool isParallelInsert =
2346 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2347 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2348 replacement = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2349 insertSliceOp.getDestType(),
2378 template <
typename InsertOpTy>
2379 struct InsertSliceOpSourceCastInserter final
2385 RankedTensorType srcType = insertSliceOp.getSourceType();
2386 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2389 srcType.getShape().end());
2390 for (int64_t i = 0; i < srcType.getRank(); ++i) {
2391 if (std::optional<int64_t> constInt =
2393 newSrcShape[i] = *constInt;
2396 RankedTensorType newSrcType =
2397 RankedTensorType::get(newSrcShape, srcType.getElementType());
2398 if (srcType == newSrcType ||
2400 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2412 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2415 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2417 insertSliceOp, cast, insertSliceOp.getDest(),
2418 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2419 insertSliceOp.getMixedStrides());
2431 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2432 InsertSliceOpCastFolder<InsertSliceOp>,
2433 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2440 auto rankedTensorType = dest.
getType().
cast<RankedTensorType>();
2441 unsigned rank = rankedTensorType.getRank();
2445 return b.
createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2454 setNameFn(getResult(),
"padded");
                           Type typeToInfer, Type typeToInferFrom) {}

// ...
    std::optional<OpAsmParser::UnresolvedOperand> optOperand,
    Type &typeToInfer, Type typeToInferFrom) {
  // ...
    typeToInfer = typeToInferFrom;
// ...
  auto sourceType = getSource().getType().cast<RankedTensorType>();
  auto resultType = getResult().getType().cast<RankedTensorType>();
  // ...
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    // ...
        << resultType << " does not match the inferred type "
        // ...
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    // ...
        << resultType << " does not match the inferred type "
        // ...
  }
2499 auto ®ion = getRegion();
2500 unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
2503 return emitError(
"expected the block to have ") << rank <<
" arguments";
2507 if (!en.value().isIndex())
2508 return emitOpError(
"expected block argument ")
2509 << (en.index() + 1) <<
" to be an index";
2514 if (yieldOp.getValue().getType() !=
2516 return emitOpError(
"expected yield type to match shape element type");
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        // ...
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!(resultShape.empty() || resultShape.size() == rank))
    return RankedTensorType();

  // ...
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
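// Added worked example: padding a source of type tensor<4x?xf32> with
// staticLow = [1, 0] and staticHigh = [2, 3] infers tensor<7x?xf32>:
// dimension 0 is fully static (4 + 1 + 2 = 7), while dimension 1 stays
// dynamic.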
  auto sourceType = source.getType().cast<RankedTensorType>();
  // ...
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  build(b, result, resultType, source, low, high,
        // ...
// ...
  auto sourceType = source.getType().cast<RankedTensorType>();
  unsigned rank = sourceType.getRank();
  // ...
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        // ...
// ...
  auto sourceType = source.getType().cast<RankedTensorType>();
  // ...
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  // ...
  assert(resultType.isa<RankedTensorType>());
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        // ...
// ...
  build(b, result, resultType, source, low, high, nofold, attrs);
// ...
  int sourceRank = source.getType().cast<RankedTensorType>().getRank();
  // ...
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
// ...
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  // ...
    for (const auto &en : enumerate(paddingWidths))
      // ...
        paddedDims.set(en.index());
  // ...
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      // ...
    if (padTensorOp.getNofold())
      // ...
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
// ...
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    // ...
    auto newResultType = PadOp::inferResultType(
        castOp.getSource().getType().cast<RankedTensorType>(),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      // ...
      padTensorOp.getSourceMutable().assign(castOp.getSource());
      // ...
    } else {
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
          // ...
      // ...
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
      // ...
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
// ...
    if (!padTensorOp.getResult().hasOneUse())
      // ...
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    // ...
        tensorCastOp.getDest().getType()))
      // ...

    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        // ...

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    // ...
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      // ...
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    // ...

    // The chain must not be rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      // ...
                                         "cannot fold rank-reducing chain");
    }

    // Both extract_slice ops must have unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // Low padding must be zero on both pads.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      // ...
                                         "cannot fold PadOps with low padding");
    }

    // Both pads must use the same constant padding value.
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        // ...
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // The padded dimensions of the two pads must be disjoint.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // ...
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          // ...
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          // ...
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    // ...

    // ...
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
      // ...
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      // ...
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    // ...

    // ...
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    // ...

    // Create a new extract_slice on the composed source and a new pad with
    // the combined high padding.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        // ...
    // ...
        newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    Value input = padTensorOp.getSource();
    if (!input.getType().isa<RankedTensorType>())
      // ...
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    // ...
    auto outputDims = oldResultType.getShape();

    // Extract the static info from the low and high operands.
    for (auto operand : padTensorOp.getLow()) {
      // ...
        constOperandsLow.push_back(ShapedType::kDynamic);
        // ...
      constOperandsLow.push_back(intOp.getExtValue());
    }
    for (auto operand : padTensorOp.getHigh()) {
      // ...
        constOperandsHigh.push_back(ShapedType::kDynamic);
        // ...
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    // ...
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      // ...

    // ...
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    // Calculate the output sizes with the static information.
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // ...
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      // ...

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = rewriter.create<PadOp>(
        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
        padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
        // ...
    // ...
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
// ...
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding>(context);
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  // ...
  Value padValue = yieldOp.getValue();
  // ...
}
// ...
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      // ...
// ...
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       // ...
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  // ...
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// ...
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        // ...
// ...
  build(b, result, source, dest, offsets, sizes, strides, attrs);
// ...
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3066 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3067 return this->
emitError(
"expected ParallelCombiningOpInterface parent, got:")
3068 << *(getOperation()->getParentOp());
3070 ShapedType expectedType;
3073 getStaticSizes(), getStaticStrides(), &expectedType);
3077 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3079 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3080 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3081 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3088 void ScatterOp::getAsmResultNames(
3090 setNameFn(getResult(),
"scatter");
  int64_t destRank = getDestType().getRank();
  // ...
                                     "scatter", "dest")))
    // ...
    return emitOpError("requires 'unique' attribute to be set");
  // ...
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type "
                       // ...
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           // ...
  }
// ...
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
// ...
  auto constOperand = adaptor.getInput();
  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
    // ...
                      Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      // ...
    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                                  // ...
// ...
    RankedTensorType sourceType = packOp.getSourceType();
    RankedTensorType destType = packOp.getDestType();
    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
      // ...
    auto reassociation =
        // ...
    Value expanded = insertExpand(
        rewriter, packOp.getLoc(), packOp.getSource(), destType,
        // ...
// ...
  patterns.add<SimplifyPackToExandShape>(patterns.getContext());
3186 template <
typename OpTy>
3190 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3191 "applies to only pack or unpack operations");
3192 int64_t destRank = op.getDestRank();
3194 ShapedType resultType = op.getResult().getType().template cast<ShapedType>();
3195 for (
auto dim : llvm::seq<int64_t>(0, destRank)) {
3196 if (resultType.isDynamicDim(dim)) {
3197 reifiedReturnShapes[0][dim] =
3198 builder.
createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
3200 reifiedReturnShapes[0][dim] =
3207 template <
typename OpTy>
3209 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3210 "applies to only pack or unpack operations");
3214 assert(tiles.size() == dimsToTile.size() &&
3215 "tiles must match indices of dimension to block");
3217 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3218 dimAndTileMapping[dimsToTile[i]] = tiles[i];
3219 return dimAndTileMapping;
3222 template <
typename OpTy>
3224 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3225 "applies to only pack or unpack operations");
3228 unsigned dynamicValIndex = 0;
3229 for (int64_t staticTile : op.getStaticInnerTiles()) {
3230 if (!ShapedType::isDynamic(staticTile))
3233 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3235 return mixedInnerTiles;
3238 template <
typename OpTy>
3240 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3241 "applies to only pack or unpack operations");
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    // ...
  for (int64_t dim : dimsPos)
    uniqued.insert(dim);
  if (dimsPosSize != uniqued.size())
    // ...
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
// ...
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
3283 template <
typename OpTy>
3285 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3286 "applies to only pack or unpack operations");
3287 Operation *op = packOrUnPack.getOperation();
3291 return llvm::any_of(
3297 if (hasZeros(mixedTiles))
3298 return op->
emitError(
"invalid zero tile factor");
3301 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
3302 ? packOrUnPack.getSourceType()
3303 : packOrUnPack.getDestType();
3304 size_t unpackedRank = unpackedType.getRank();
3308 return op->
emitError(
"invalid inner_dims_pos vector");
3310 return op->
emitError(
"invalid outer_dims_perm vector");
3311 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3312 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
3316 if (mixedTiles.size() > unpackedRank) {
3317 return op->
emitError(
"tiling factors must be less than or equal to the "
3318 "input rank for pack or output rank for unpack");
3320 if (mixedTiles.size() != innerDimsPos.size()) {
3322 "tiling factors must equal the number of dimensions to tile");
3325 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3326 ? packOrUnPack.getDestType()
3327 : packOrUnPack.getSourceType();
3328 size_t packedRank = packedType.getRank();
3330 if (unpackedRank + mixedTiles.size() != packedRank) {
3332 "packed rank must equal unpacked rank + tiling factors");
3338 ShapedType expectedPackedType = PackOp::inferPackedType(
3339 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3340 if (!
areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3341 return op->
emitError(
"the shape of output is not large enough to hold the "
3342 "packed data. Expected at least ")
3343 << expectedPackedType <<
", got " << packedType;
3346 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3348 [](std::tuple<int64_t, OpFoldResult> it) {
3349 std::optional<int64_t> constTileSize =
3350 getConstantIntValue(std::get<1>(it));
3351 int64_t shape = std::get<0>(it);
3352 if (!constTileSize) {
3355 return ShapedType::isDynamic(shape);
3357 if (ShapedType::isDynamic(shape)) {
3364 return shape == constTileSize.value();
3366 return op->emitError(
"mismatch in inner tile sizes specified and shaped of "
3367 "tiled dimension in the packed type");
struct PackOrUnPackTransposeResult {
  // ...
};

template <typename OpTy>
static PackOrUnPackTransposeResult
// ...
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      // ...
  metadata.innerTiles =
      // ...
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          // ...
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           // ...
           "invalid inner permutation");
    // ...
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           // ...
           "invalid outer permutation");
    // ...
  }
// ...
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}
                   std::optional<Value> paddingValue,
                   // ...
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  // ...
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
        // ...
// ...
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      // ...
    // ...
    if (inputShape[pos] % (*constantTile) != 0)
      // ...
  }
// ...
  auto paddingValue = getPaddingValue();
  // ...
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          // ...
    return emitOpError("invalid tile factor provided. Only full tiles are "
                       "supported when padding_value is not set");
// ...
  for (auto o : ofrs) {
    // ...
    if (o.dyn_cast<Value>())
      result.push_back(ShapedType::kDynamic);
    // ...
  }
// ...
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      // ...
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      // ...
    }
    resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
                                            innerTileSizes[tiledDim.index()]);
  // ...
  if (!outerDimsPerm.empty())
    // ...
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
// ...
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  // ...
  if (!outerDimsPerm.empty())
    // ...
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
// ...
      innerDimsPos, outerDimsPerm);
  // ...
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
      // ...
  }
// ...
ShapedType PackOp::inferPackedType(ShapedType sourceType,
                                   // ...
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
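// Added worked example for inferPackedType: packing tensor<128x256xf32> with
// inner_dims_pos = [0, 1] and inner_tiles = [32, 16] (no outer_dims_perm)
// tiles the outer dims to ceilDiv(128, 32) = 4 and ceilDiv(256, 16) = 16 and
// then appends the tile sizes, giving tensor<4x16x32x16xf32>.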
  for (auto [index, value] :
       // ...
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
    // ...
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    // ...
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  // ...
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
// ...
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
3642 template <
typename OpTy>
3644 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3645 "applies to only pack or unpack operations");
3646 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3648 : op.getSourceType();
3650 for (
auto [dimDest,
tile] : llvm::zip(
3651 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3653 if (!constTileSize || ShapedType::isDynamic(dimDest))
3660 if (
auto paddingValue = getPaddingValue())
3675 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3677 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
3683 auto packTiles = packOp.getMixedTiles();
3684 auto unPackTiles = unPackOp.getMixedTiles();
3685 if (packTiles.size() != unPackTiles.size())
3687 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
3696 UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
3697 if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
3699 if (packOp.getPaddingValue() ||
3703 rewriter.
replaceOp(packOp, unPackOp.getSource());
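    // Added example of the pack/unpack cancellation above (schematic IR,
    // eliding the full op syntax): in the chain
    //   %u = tensor.unpack %p ... : tensor<2x8xf32> -> tensor<16xf32>
    //   %q = tensor.pack %u ...  : tensor<16xf32> -> tensor<2x8xf32>
    // with identical inner_dims_pos, inner_tiles, and outer_dims_perm (and no
    // padding value), %q is replaced by the original packed value %p.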
3711 void UnPackOp::getAsmResultNames(
3713 setNameFn(getResult(),
"unpack");
3750 assert(innerDimsPos.size() == innerTiles.size() &&
3751 "number of tile sizes specified must match the specified number of "
3752 "original dimensions to be tiled");
3756 build(builder, state, dest.
getType(), source, dest,
3757 outerDimsPerm.empty() ?
nullptr

UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}

/// Fold an unpack(pack(x)) with identical layout and tiles back to x.
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
    return failure();
  if (packOp.getPaddingValue() ||
      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
      !haveSameTiles(packOp, unPackOp))
    return failure();
  rewriter.replaceOp(unPackOp, packOp.getSource());
  return success();
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the tensor.cast has a source that is more static than the consuming op.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // InsertSliceOp has its own logic about folding tensor.cast ops.
    if (isa<InsertSliceOp>(op.getOperation()))
      return failure();

    // If no operand comes from a tensor::CastOp that can be folded, fail.
    bool hasTensorCastOperand =
        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
          if (opOperand.get().isa<BlockArgument>())
            return false;
          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    for (OpOperand &opOperand : op->getOpOperands()) {
      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
      if (op.isDpsInit(&opOperand) &&
          !newOperands.back().getType().isa<MemRefType>())
        newResultTypes.push_back(newOperands.back().getType());
    }

    // Clone the op with the new operands and cast results back where needed.
    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
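
// Illustrative fold performed by the pattern above (hypothetical IR; the
// consumer op name is invented):
//   %0 = tensor.cast %t : tensor<8x16xf32> to tensor<?x?xf32>
//   %1 = some.dps_op ... outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
// becomes the same op running directly on %t (result type tensor<8x16xf32>),
// followed by a tensor.cast of the result back to tensor<?x?xf32>.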

void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"