#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();
  if (!tensorType.hasStaticShape()) {

  for (int64_t sz : tensorType.getShape())

  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());

  if (llvm::isa<TensorType>(opResult.getType())) {
    if (failed(destination))
      return failure();
    result.push_back(*destination);
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;

    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);

    if (!isStaticUnitSize) {

    if (reducedShape[shapePos] == 1) {

    droppedDims.set(idx);

  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType

  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];

      if (cst.has_value()) {

        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);

        staticShape[i] = *cst;

      foldedDynamicSizes.push_back(dynamicSize);
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const override {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());

  results.add<ChainedTensorBitcast>(context);
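// Illustrative sketch (not from the original source): ChainedTensorBitcast
// collapses two back-to-back bitcasts into one, e.g.:
//
//   %a = tensor.bitcast %t : tensor<4xui32> to tensor<4xi32>
//   %b = tensor.bitcast %a : tensor<4xi32> to tensor<4xf32>
// becomes
//   %b = tensor.bitcast %t : tensor<4xui32> to tensor<4xf32>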
  setNameFn(getResult(), "cast");
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  if (!sourceType || !targetType)
    return false;

  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  if (sourceType.getRank() != targetType.getRank())
    return false;

  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

                                    castOp.getSource().getType());
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);

  newOperands.reserve(op->getNumOperands());

  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();

    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();

    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();

      operand.set(castOp.getOperand());

  return success(folded);
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);

  if (aT.getElementType() != bT.getElementType())
    return false;
  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
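// Worked example (illustrative, not from the original source): joining
// tensor<?x8xf32> with tensor<4x?xf32> takes the static extent from whichever
// side has one, yielding tensor<4x8xf32>; joining tensor<4x8xf32> with
// tensor<5x8xf32> fails (returns a null type) because static sizes clash.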
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const override {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const override {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())

        extractOperand.getStaticSizes(), extractOperand.getType().getShape());

    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());

  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
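// Illustrative example of the ChainedTensorCast rewrite (assumed IR, not from
// the original source): two casts whose endpoints join cleanly are merged:
//
//   %a = tensor.cast %t : tensor<4x8xf32> to tensor<?x8xf32>
//   %b = tensor.cast %a : tensor<?x8xf32> to tensor<?x?xf32>
// becomes
//   %b = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>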
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  for (int64_t i = 0, e = concatRank; i < e; ++i) {

    for (auto tensorType : tensorTypes)

  for (auto tensorType : tensorTypes)

  sizes[dim] = concatSize.asInteger();
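// Worked example (illustrative): concatenating tensor<2x4xf32> and
// tensor<3x4xf32> along dim 0 sums the concatenated extents (2 + 3 = 5) and
// joins the remaining ones, giving tensor<5x4xf32>; a dynamic extent along
// the concatenation dim makes the result extent dynamic as well.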
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  for (int64_t i = 0, e = resultRank; i < e; ++i) {

    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =

      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;

  for (auto tensorType : inputTypes)

  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);

  concatOffsets.reserve(numInputs);

      outputShape = inputShape;
      concatOffsets.push_back(zero);

      concatOffsets.push_back(outputShape[concatDim]);

          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});

    inputShapes.emplace_back(std::move(inputShape));

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();

    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);

  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
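// Sketch of the decomposition (illustrative IR, assuming static shapes):
//
//   %r = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
// lowers to
//   %e  = tensor.empty() : tensor<5x4xf32>
//   %i0 = tensor.insert_slice %a into %e[0, 0][2, 4][1, 1]
//   %i1 = tensor.insert_slice %b into %i0[2, 0][3, 4][1, 1]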
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  for (int64_t i = 0; i < rank; ++i) {

    if (!getType().isDynamicDim(i)) {

    } else if (!inferredResultType.isDynamicDim(i)) {

          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }

  if (getType().isDynamicDim(dim)) {

          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));

    reifiedReturnShapes[0][dim] =
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");

  if (inputs.size() == 1 && inputs[0].getType() == getResultType())

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    LogicalResult matched = failure();

    for (auto [operandIdx, operandType] :

      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

        auto castOp =
            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
                                    concatOp.getOperand(operandIdx));

        concatOp->setOperand(operandIdx, castOp->getResult(0));

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

        concatOp.getResultType())) {

      auto newConcatOp = rewriter.create<ConcatOp>(
          concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());

  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
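// Illustrative example (not from the original source): InferConcatResultType
// tightens a needlessly dynamic result, inserting a cast to keep uses
// type-consistent:
//
//   %c = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<?x?xf32>
// becomes
//   %n = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//   %c = tensor.cast %n : tensor<5x4xf32> to tensor<?x?xf32>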
  setNameFn(getResult(), "dim");

  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)

  setResultRange(getResult(),

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());

    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        ++dynExtents;

    return Value{*dynExtents};
  }

  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {

    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    Value extract =
        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);

  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
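// Illustrative example of DimOfCastOp (assumed IR): the dim query moves to
// the cast source, since casting never changes the runtime shape:
//
//   %c = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
//   %d = tensor.dim %c, %c0 : tensor<?x?xf32>
// becomes
//   %d = tensor.dim %t, %c0 : tensor<4x?xf32>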
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);

  build(builder, result, tensorType, dynamicSizes);

  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");

  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {

        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                          foldedDynamicSizes);

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();

        emptyTensorOp.getDynamicSize(*maybeConstantIndex));
  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());

    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);

      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {

          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      newMixedSizes.push_back(currDim);

        resultType.getElementType());

  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
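// Illustrative example (not from the original source): with %c4 a known
// constant, ReplaceEmptyTensorStaticShapeDims folds the dynamic extent into
// the type:
//
//   %c4 = arith.constant 4 : index
//   %0 = tensor.empty(%c4) : tensor<?x8xf32>
// becomes
//   %0 = tensor.empty() : tensor<4x8xf32>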
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const override {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();

    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();
    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());

    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });

          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());

    if (collapseOp.getReassociationIndices().empty()) {

      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();

          rewriter, extractOp.getLoc(), zeroAffineMap,

      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(

        extractOp, collapseOp.getSrc(), sourceIndices);
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");

  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");

  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();
  if (Attribute tensor = adaptor.getTensor()) {

    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");

    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }

    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||

    return fromElementsOp.getElements()[flatIndex];
  }

  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  results.add<ExtractFromTensorCast>(context);
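// Illustrative example of ExtractFromTensorCast (assumed IR): the element
// read bypasses the cast entirely:
//
//   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
//   %e = tensor.extract %c[%i] : tensor<?xf32>
// becomes
//   %e = tensor.extract %t[%i] : tensor<4xf32>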
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");

  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const override {
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

  results.add<ExtractElementFromIndexCast>(context);
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");

RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
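// Worked example (illustrative): gathering from a tensor<4x5xf32> source with
// indices of type tensor<3x1xindex> and gather_dims([0]) keeps the leading
// indices dims (3), emits a unit dim for the gathered source dim, and copies
// the rest: tensor<3x1x5xf32>, or tensor<3x5xf32> in the rank-reduced form.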
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
  int64_t sourceRank = getSourceType().getRank();

  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "

           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");

  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");

  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
    if (scalar == splatDest.getSplatValue<Attribute>())
      return dest;
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");

  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =

  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());

  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

void GenerateOp::build(

  build(b, result, resultTy, dynamicExtents);

  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();

  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);

struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const override {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements)
      return failure();

    Block *body = &tensorFromElements.getBody().front();

      rewriter.clone(op, mapping);

  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())

  return IntegerAttr();
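// Illustrative example (not from the original source): any ranked operand
// folds to a constant, regardless of dynamic extents:
//
//   %r = tensor.rank %t : tensor<4x?xf32>
// becomes
//   %r = arith.constant 2 : index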
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

  return emitOpError("element types of source and destination tensor "
                     "types should be the same");

  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // A 1-D reshape to an identical type is a no-op even with dynamic shapes.
  if (sourceTy.getRank() == 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}
  auto [staticOutputShape, dynamicOutputShape] =

  build(builder, result, cast<RankedTensorType>(resultType), src,

        dynamicOutputShape, staticOutputShape);

  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);

  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);

      getReassociationIndices());

      getReassociationIndices());
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t> newShape;
  newShape.reserve(reassociation.size());

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }
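// Worked example (illustrative): collapsing tensor<2x3x4xf32> with
// reassociation [[0, 1], [2]] multiplies each band (2 * 3 = 6, then 4),
// giving tensor<6x4xf32>; a dynamic extent anywhere in a band makes that
// collapsed extent dynamic.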
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));

  build(b, result, resultType, src, attrs);

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {

        reshapeOp.getResultType(), attr.getRawData());

template <typename TensorReshapeOp>
struct FoldReshapeWithSplat : OpRewritePattern<TensorReshapeOp> {

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());

template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any constant values; nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);

                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
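// Illustrative example of the reshape fold (assumed IR): a collapse_shape
// undone by an expand_shape with the same reassociation folds away:
//
//   %c = tensor.collapse_shape %t [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
//   %e = tensor.expand_shape %c [[0, 1]] output_shape [2, 3]
//       : tensor<6xf32> into tensor<2x3xf32>
// folds to %t.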
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");

RankedTensorType ExtractSliceOp::inferResultType(

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());

RankedTensorType ExtractSliceOp::inferResultType(

  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {

  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
    RankedTensorType resultType, Value source,

  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());

  if (!resultType)
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));

  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

    RankedTensorType resultType, Value source,

  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
  RankedTensorType sourceType = getSourceType();

  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());

  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();

  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    if (!sliceResult.isValid)
      return failure();

    Location loc = sliceOp.getLoc();
    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
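// Worked example (illustrative): for a 2x4 source, counts is {4, 1}. Slicing
// offsets = {0, 1}, sizes = {2, 2}, strides = {1, 2} first recurses over rows
// at linear offsets 0 and 4, then copies columns 1 and 3 of each row, so
// elements {1, 3, 5, 7} land in outValues in row-major order.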
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {

    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    int64_t count = sourceType.getNumElements();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    }

  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
      mixedStrides);

                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);

              ExtractSliceOpCastFolder>(context);
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {

  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();

  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();

Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}
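// Usage sketch (illustrative, assuming a builder `b` and a value %t of type
// tensor<1x4x1xf32>): requesting target type tensor<4xf32> emits
//
//   %s = tensor.extract_slice %t[0, 0, 0][1, 4, 1][1, 1, 1]
//       : tensor<1x4x1xf32> to tensor<4xf32>
//
// i.e. a zero-offset, unit-stride, full-size slice that only drops the unit
// dims.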
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");

  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

  build(b, result, source, dest, offsets, sizes, strides, attrs);

  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
    RankedTensorType srcType, RankedTensorType dstType,

  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;

  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);

  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());

  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();

  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {

        mixedOffsets, mixedSizes, mixedStrides);

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {

      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
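// Illustrative example (not from the original source): once a size operand
// becomes a known constant, the folder rewrites toward static offsets/sizes
// and casts the source to the canonical rank-reduced type, e.g.:
//
//   %i = tensor.insert_slice %s into %d[0, 0][%c4, 8][1, 1]
//       : tensor<?x8xf32> into tensor<16x8xf32>
// becomes
//   %c = tensor.cast %s : tensor<?x8xf32> to tensor<4x8xf32>
//   %i = tensor.insert_slice %c into %d[0, 0][4, 8][1, 1]
//       : tensor<4x8xf32> into tensor<16x8xf32>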
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {

            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

        mixedSizes, insertSliceOp.getMixedStrides());

    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        mixedSizes, insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();

    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {

        newSrcShape[i] = *constInt;
      }
    }

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
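// Illustrative example of InsertSliceOpSourceCastInserter (assumed IR): when
// the static sizes pin down the source shape, a cast makes that information
// explicit:
//
//   %i = tensor.insert_slice %s into %d[0, 0][4, 8][1, 1]
//       : tensor<?x?xf32> into tensor<16x8xf32>
// becomes
//   %c = tensor.cast %s : tensor<?x?xf32> to tensor<4x8xf32>
//   %i = tensor.insert_slice %c into %d[0, 0][4, 8][1, 1]
//       : tensor<4x8xf32> into tensor<16x8xf32>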
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
  setNameFn(getResult(), "padded");

                           Type typeToInfer, Type typeToInferFrom) {}

static ParseResult
parseInferType(OpAsmParser &parser,
               std::optional<OpAsmParser::UnresolvedOperand> optOperand,
               Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();

    return emitError("expected the block to have ") << rank << " arguments";

    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";

  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }
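// Worked example (illustrative): a tensor<4x5xf32> source with staticLow =
// {0, 1} and staticHigh = {2, 1} infers tensor<6x7xf32> (4 + 0 + 2 and
// 5 + 1 + 1); if any of the three terms for a dim is dynamic, that result
// dim is inferred as dynamic unless resultShape pins it.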
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  build(b, result, resultType, source, low, high,

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());

  if (!resultType)
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);

  assert(llvm::isa<RankedTensorType>(resultType));
  build(b, result, resultType, source, dynamicLow, dynamicHigh,

  build(b, result, resultType, source, low, high, nofold, attrs);

  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();

  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dim is padded by both PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two ExtractSliceOps.
    for (auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two ExtractSliceOps.
    for (auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two PadOps.
    for (auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new ExtractSliceOp, PadOp pair that does both paddings at once.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
/// Pattern to fold a tensor.pad whose padding operands are constants into a
/// pad with a fully static result type, followed by a tensor.cast back to the
/// original result type.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(staticLow[i] == ShapedType::kDynamic ||
                                     staticHigh[i] == ShapedType::kDynamic ||
                                     inputDims[i] == ShapedType::kDynamic
                                 ? ShapedType::kDynamic
                                 : inputDims[i] + staticLow[i] + staticHigh[i]);
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // Bail out if nothing became more static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = rewriter.create<PadOp>(
        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
        newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);
    return success();
  }
};
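// Illustrative sketch (shapes assumed; %h is an assumed arith.constant
// 2 : index): a pad whose dynamic padding operand is actually constant gets
// a static result type, with a cast restoring the old type for users:
//
//   %p = tensor.pad %src low[0, 0] high[%h, 0] { tensor.yield %cst }
//       : tensor<4x8xf32> to tensor<?x8xf32>
//
// becomes:
//
//   %q = tensor.pad %src low[0, 0] high[2, 0] { tensor.yield %cst }
//       : tensor<4x8xf32> to tensor<6x8xf32>
//   %p = tensor.cast %q : tensor<6x8xf32> to tensor<?x8xf32>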
/// Folds a chain of two tensor.pad ops that pad with the same constant value.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two PadOps by summing them per
    // dimension.
    auto addPaddings = [&](SmallVector<OpFoldResult> consumerPaddings,
                           SmallVector<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
        newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
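// Illustrative sketch (shapes assumed; both pads yield the same %val): the
// low/high bounds of the two pads are summed per dimension:
//
//   %1 = tensor.pad %0 low[0, 1] high[0, 2] { tensor.yield %val }
//       : tensor<1x2xf32> to tensor<1x5xf32>
//   %res = tensor.pad %1 low[0, 2] high[3, 0] { tensor.yield %val }
//       : tensor<1x5xf32> to tensor<4x7xf32>
//
// folds into:
//
//   %res = tensor.pad %0 low[0, 3] high[3, 2] { tensor.yield %val }
//       : tensor<1x2xf32> to tensor<4x7xf32>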
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
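// Usage sketch (driver setup assumed, not part of this file): these patterns
// are typically picked up by a canonicalizer run or applied directly through
// the greedy pattern driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
//   if (failed(applyPatternsGreedily(module, std::move(patterns))))
//     signalPassFailure();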
/// Return the padding value of the PadOp if it is constant. In this context,
/// "constant" means an actual constant or a value defined outside of the
/// PadOp's own block.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if the yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // A value defined inside the PadOp block is not constant.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else: the yield value is defined outside of the PadOp block.
  return padValue;
}

OpFoldResult PadOp::fold(FoldAdaptor) {
  // An identity pad (same static type, foldable) is a no-op.
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
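// Illustrative sketch (shapes assumed): a pad whose result type equals its
// statically shaped source type pads by zero everywhere and folds away:
//
//   %p = tensor.pad %src low[0, 0] high[0, 0] { tensor.yield %cst }
//       : tensor<4x8xf32> to tensor<4x8xf32>
//   // ==> %p folds to %src.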
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}

// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// Build a ParallelInsertSliceOp with mixed static and dynamic entries packed
// into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that the offsets/sizes/strides do not run out-of-bounds with
  // respect to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
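// Illustrative sketch (loop structure and shapes assumed): the op may only
// appear inside a parallel-combining terminator such as
// scf.forall.in_parallel, where it ties into a result of the parent loop:
//
//   %res = scf.forall (%i) in (4) shared_outs(%o = %dest)
//       -> (tensor<32xf32>) {
//     %off = affine.apply affine_map<(d0) -> (d0 * 8)>(%i)
//     %tile = ... : tensor<8xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tile into %o[%off] [8] [1]
//           : tensor<8xf32> into tensor<32xf32>
//     }
//   }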
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  // Use GatherOp::inferResultType on the `dest` type and verify that the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }
  return success();
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
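// Illustrative sketch (value assumed): a splat of a constant scalar into a
// statically shaped tensor folds to a dense splat constant:
//
//   %c = arith.constant 1.0 : f32
//   %t = tensor.splat %c : tensor<2x3xf32>
//   // ==> %t folds to dense<1.0> : tensor<2x3xf32>.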
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // InsertSliceOp has its own logic about folding tensor.cast ops; LoopLike
  // ops are handled separately as well.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  // If no operand comes from a foldable tensor.cast, fail.
  return hasFoldableTensorCastOperand(op);
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the cast makes the operand more static.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes;
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the folded operands and updated result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    // Cast results whose type changed back to the original type for users.
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
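// Illustrative sketch (`consumer` stands in for an arbitrary DPS op): a cast
// that hides static shape information is pulled out of the consumer:
//
//   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
//   %2 = consumer %1 ... : tensor<?x?xf32> ...
//
// folds into:
//
//   %2 = consumer %0 ... : tensor<8x16xf32> ...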
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"