#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
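// The snippet below is from the dialect's constant-materialization hook:
// constants are materialized as arith.constant when possible, with
// complex.constant as the fallback for complex element types.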
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination-style interface, query the tied init operand.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();
  if (!tensorType.hasStaticShape()) {

  for (int64_t sz : tensorType.getShape())

      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
    if (llvm::isa<TensorType>(opResult.getType())) {
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
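// getDroppedDims (below) walks the slice sizes back-to-front against the
// rank-reduced shape and marks every static unit size with no matching
// result dim. Worked example (assumed inputs, for illustration only): with
// reducedShape = [4, 8] and mixedSizes = [1, 4, 1, 8], the static unit sizes
// at positions 0 and 2 have no counterpart in the reduced shape, so the
// returned bit vector has bits 0 and 2 set.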
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }
    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }
    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }
    // Otherwise: the dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Fold constant dynamic sizes into the static shape.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
  results.add<ChainedTensorBitcast>(context);
  setNameFn(getResult(), "cast");
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If the cast turns any dimension from static to dynamic, static
  // information is lost.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;
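// joinShapes (below) computes the most static shape compatible with both
// operands, e.g. (illustrative) joining tensor<?x8> with tensor<4x?> yields
// [4, 8], while a static mismatch such as 4 vs. 5 has no join and fails.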
  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast only if folding the chain loses no
    // information: the join of all three shapes must match the join of just
    // the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);
    if (!firstJoin)
      return failure();

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
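// Illustrative fold performed by this pattern (assumed IR):
//   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>
//   %2 = tensor.cast %1 : tensor<?x8xf32> to tensor<4x?xf32>
// becomes a single tensor.cast %0 : tensor<4x8xf32> to tensor<4x?xf32>,
// provided folding the chain loses no static information.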
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold a cast to an unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
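// Example of the inference above (illustrative): concatenating
// tensor<3x?xf32> and tensor<5x8xf32> along dim 0 sums the concatenated dim
// (3 + 5 = 8) and joins the others (? joined with 8 is 8), giving
// tensor<8x8xf32>.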
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");
  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, resultType.getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Sum the sizes of all inputs along the concatenated dimension.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // The inferred operand shape is the result shape, except along the
      // concatenated dimension, where the operand keeps its own size.
      SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      if (!preservesStaticInformation(inferredOperandType,
                                      cast<RankedTensorType>(operandType))) {
        matched = success();
        auto castOp =
            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
                                    concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(
            concatOp, [&, idx = operandIdx] {
              concatOp->setOperand(idx, castOp->getResult(0));
            });
      }
    }
    return matched;
  }
  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type is already at least as static as the inferred type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp = rewriter.create<ConcatOp>(
        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);
    return success();
  }
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
  setNameFn(getResult(), "dim");

  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out-of-bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the generate that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }
  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return failure();

    Location loc = dim.getLoc();
    Value extract =
        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
  assert(all_of(staticShape,
                [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);

  build(builder, result, tensorType, dynamicSizes);

  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                          foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
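// Illustrative rewrite (assumed IR): %0 = tensor.empty(%c4) : tensor<?xf32>
// becomes %e = tensor.empty() : tensor<4xf32> followed by
// %0 = tensor.cast %e : tensor<4xf32> to tensor<?xf32>.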
  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. It must match the cast result.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The cast shape is static but the empty tensor dim is dynamic.
      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: Both are dynamic. Use the dynamic value from the empty op.
      newMixedSizes.push_back(currDim);
    }

    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

  // Verify that the number of indices matches the rank.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // A dense resource elements attribute cannot be folded element-wise.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }
  results.add<ExtractFromTensorCast>(context);
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);
    return success();
  }
};

  results.add<ExtractElementFromIndexCast>(context);
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
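// Example (illustrative): with source tensor<4x5x6xf32>, gather_dims = [1],
// and indices tensor<NxKx1xindex>, the leading result dims are the index
// dims minus the last (N, K); the trailing dims are [4, 1, 6], and the
// rank-reduced variant drops the unit dim, giving [4, 6].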
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "must be ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

  // Verify that the number of indices matches the rank.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");

  // Inserting the splat value into a matching splat is a no-op.
  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
    if (scalar == splatDest.getSplatValue<Attribute>())
      return dest;
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }

  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that the region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body region.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements)
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());
    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

  // Constant-fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

    return emitOpError("element types of source and destination tensor "
                       "types should be the same");
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of the source is another tensor.reshape, use the
  // producer's input instead; this may render the producer dead.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If source and result are identical 1-D tensors, the reshape is a no-op,
  // even if the tensor is dynamically shaped.
  if (sourceTy.getRank() == 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      // Static congruence for each dimension.
      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      // Dynamic congruence requires exactly matching tensor.dim calls.
      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
      getReassociationIndices());

      getReassociationIndices());
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that the reassociation is valid: only each map's rank is
  // needed.
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
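// Example (illustrative): collapsing tensor<2x3x4xf32> with reassociation
// [[0, 1], [2]] multiplies each band, giving tensor<6x4xf32>; any dynamic
// size inside a band makes the collapsed dim dynamic.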
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  // Verify the collapsed type against the type inferred from the
  // reassociation.
  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};
template <typename TensorReshapeOp>
struct FoldReshapeWithSplat : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};
  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        // If the cast's source dim is dynamic, do not infer any of the
        // corresponding expanded dims: tensor.expand_shape requires at least
        // one of them to stay dynamic when the input dim is dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Could not match any value: nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
              ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
              FoldReshapeWithSplat<ExpandShapeOp>,
              FoldReshapeWithFromElements<ExpandShapeOp>>(context);

                                  tensor::DimOp, RankedTensorType>,
              FoldReshapeWithConstant<CollapseShapeOp>,
              FoldReshapeWithSplat<CollapseShapeOp>,
              FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best-effort rank reduction: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // If no result type was provided, infer it from the static
  // offsets/sizes/strides.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// The remaining convenience builders forward to the builder above.
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
  RankedTensorType sourceType = getSourceType();

  // Verify the result type against the inferred type.
  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
  SliceVerificationResult result = isRankReducedType(expectedType, getType());
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out of bounds.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand: let the constant folder kick in first.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // The pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    if (!sliceResult.isValid)
      return failure();

    // Create the folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
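// Worked example (illustrative): for a 1-D attribute [a, b, c, d] with
// offset 1, size 2, stride 2, the recursion bottoms out in the rank-1 branch
// and copies elements 1 and 3, i.e. [b, d].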
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shapes are not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check for any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed from the sliced values.
    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    }
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
      mixedStrides);
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }

              ExtractSliceOpCastFolder>(context);
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  // The slice is an identity if all offsets are 0, all sizes match the
  // shape, and all strides are 1.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}

// Convenience builders forwarding to the main builder.
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

  build(b, result, source, dest, offsets, sizes, strides, attrs);

  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
static SliceVerificationResult
verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType,
                    ArrayRef<int64_t> staticOffsets,
                    ArrayRef<int64_t> staticSizes,
                    ArrayRef<int64_t> staticStrides,
                    RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice: use the same type inference.
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}

  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out of bounds.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());

  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // The pattern does not apply if the produced op would not verify.
    SliceVerificationResult result = verifyInsertSliceOp(
        insertSliceOp.getSourceType(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    if (result != SliceVerificationResult::Success)
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // For ParallelInsertSliceOp, the insertion point must be just before
      // the parallel combining parent.
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The cast source may carry static information not visible in the insert
    // slice's static sizes, so ignore dynamic dims when computing the rank
    // reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dims in the insert slice sizes with the corresponding static
    // dims from the cast source type.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // The pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
                            mixedSizes, insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        mixedSizes, insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
                                                    replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        insertSliceOp.getSourceType() == newSrcType ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is more static than srcType and cast-compatible with it:
    // insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // For ParallelInsertSliceOp, the insertion point must be just before the
    // parallel combining parent.
    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);

void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
static void printInferType(OpAsmPrinter &printer, Operation *op,
                           Value optOperand, Type typeToInfer,
                           Type typeToInferFrom) {}

static ParseResult
parseInferType(OpAsmParser &parser,
               std::optional<OpAsmParser::UnresolvedOperand> optOperand,
               Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // All block arguments must be index-typed.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
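// Example (illustrative): padding tensor<4x?xf32> with staticLow = [1, 2]
// and staticHigh = [2, 3] yields tensor<7x?xf32>: 4 + 1 + 2 = 7, and the
// dynamic dim stays dynamic.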
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        b.getBoolAttr(nofold));

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  // Infer a static result type when none was provided.
  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  assert(llvm::isa<RankedTensorType>(resultType));
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        b.getBoolAttr(nofold));

  build(b, result, resultType, source, low, high, nofold, attrs);
  // Build the region for the padding-value body.
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two ExtractSliceOps.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two ExtractSliceOps.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create the new ExtractSliceOp/PadOp pair and replace the old ops.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = rewriter.create<PadOp>(
        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
        newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
/// Folds a chain of two tensor.pad ops that use the same constant padding
/// value by summing their low and high paddings.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
        newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
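// Illustrative example (hypothetical IR): the two pads below share the
// constant padding value %val, so their paddings sum element-wise:
//
//   %1 = tensor.pad %0 low[0, 1] high[0, 2] { tensor.yield %val }
//       : tensor<1x2xf32> to tensor<1x5xf32>
//   %2 = tensor.pad %1 low[0, 2] high[3, 0] { tensor.yield %val }
//       : tensor<1x5xf32> to tensor<4x7xf32>
//
// folds to:
//
//   %2 = tensor.pad %0 low[0, 3] high[3, 2] { tensor.yield %val }
//       : tensor<1x2xf32> to tensor<4x7xf32>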
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
/// Return the padding value of the PadOp if it is a constant. In this
/// context, "constant" means an actual constant or a value defined outside
/// of the PadOp's own block.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if the yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if the yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else: the yield value is defined outside of the PadOp block.
  return padValue;
}

OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
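// For example (hypothetical IR), a pad whose result type equals its source
// type changes nothing and folds to %t, provided `nofold` is not set:
//
//   %0 = tensor.pad %t low[0, 0] high[0, 0] { tensor.yield %cst }
//       : tensor<4x5xf32> to tensor<4x5xf32>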
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// Build a ParallelInsertSliceOp with mixed static and dynamic entries packed
// into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that the slice is in bounds of the destination.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}
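// Usage sketch (hypothetical IR): the verifier requires exactly one dynamic
// size operand per '?' dimension of the result type, e.g.
//
//   %0 = tensor.splat %v[%m, %n] : tensor<?x8x?xf32>
//
// supplies two sizes for the two dynamic dimensions; any mismatch is
// rejected with "incorrect number of dynamic sizes".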
LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the result type is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second argument as
  // a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
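// For example (hypothetical IR), splatting a constant into a statically
// shaped tensor folds to a dense splat attribute:
//
//   %c = arith.constant 1.0 : f32
//   %0 = tensor.splat %c : tensor<4xf32>  // folds to dense<1.0> : tensor<4xf32>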
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
  // 2. Exclude DPS ops that are also LoopLike from this interface as they
  //    might need special handling of attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  return hasFoldableTensorCastOperand(op);
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the cast makes the operand more static, re-casting the results to their
/// original types where needed.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject relayout ops; they have their own cast-folding logic.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the folded operands and updated result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    // Cast the results back to the original types where they changed.
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
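// Illustrative effect (hypothetical IR; "some.dps_op" stands in for any
// destination-style op): a cast that discards static shape information is
// pulled into the consumer and re-applied on the result:
//
//   %dest = tensor.cast %d : tensor<4x8xf32> to tensor<?x8xf32>
//   %r = some.dps_op ins(...) outs(%dest : tensor<?x8xf32>) -> tensor<?x8xf32>
//
// becomes:
//
//   %r0 = some.dps_op ins(...) outs(%d : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %r = tensor.cast %r0 : tensor<4x8xf32> to tensor<?x8xf32>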
//===----------------------------------------------------------------------===//
// TensorDialect
//===----------------------------------------------------------------------===//

void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(results.getContext());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"