#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
  return op;
if (complex::ConstantOp::isBuildableWith(value, type))
  return complex::ConstantOp::create(builder, loc, type,
                                     llvm::cast<ArrayAttr>(value));
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
if (tensorType.isDynamicDim(dim))
  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
for (int64_t i = 0; i < tensorType.getRank(); ++i)
auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
assert(tensorType && "expected tensor type");
auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
return destOp.getTiedOpOperand(opResult)->get();
if (!tensorType.hasStaticShape()) {
for (int64_t sz : tensorType.getShape())
tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
if (llvm::isa<TensorType>(opResult.getType())) {
result.push_back(*destination);
if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
    return rtp1.getShape() == rtp2.getShape() &&
           rtp1.getElementType() == rtp2.getElementType();
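// Computes which entries of `mixedSizes` were dropped by rank reduction: scan
// the sizes back-to-front against `reducedShape`; a static size of 1 with no
// matching dim left in the reduced shape is a dropped (rank-reduced) dim.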
llvm::SmallBitVector droppedDims(mixedSizes.size());
int64_t shapePos = reducedShape.size() - 1;
for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
  size_t idx = mixedSizes.size() - size.index() - 1;
  bool isStaticUnitSize =
      isa<Attribute>(size.value()) &&
      llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
  assert(isStaticUnitSize && "expected unit dim");
  droppedDims.set(idx);
  if (!isStaticUnitSize) {
  if (reducedShape[shapePos] == 1) {
  droppedDims.set(idx);
assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
assert(type.getNumDynamicDims() == dynamicSizes.size() &&
       "incorrect number of dynamic sizes");
for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
  if (type.isDynamicDim(i)) {
    Value dynamicSize = dynamicSizes[ctr++];
    if (cst.has_value()) {
      if (cst.value() < 0) {
        foldedDynamicSizes.push_back(dynamicSize);
      staticShape[i] = *cst;
    foldedDynamicSizes.push_back(dynamicSize);
if (inputs.size() != 1 || outputs.size() != 1)
Type a = inputs.front(), b = outputs.front();
auto aT = dyn_cast<TensorType>(a);
auto bT = dyn_cast<TensorType>(b);
if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
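// Folds a chain of two bitcasts into one. A sketch of the rewrite
// (illustrative IR; all types must share the same element bit width):
//   %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>
//   %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xf32>
// folds to:
//   %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xf32>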
LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
auto tensorBitcastOperand =
    tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
if (!tensorBitcastOperand)
auto resultType = cast<TensorType>(tensorBitcast.getType());
rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                       tensorBitcastOperand.getOperand());
results.add<ChainedTensorBitcast>(context);
setNameFn(getResult(), "cast");
auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
auto targetType = llvm::dyn_cast<RankedTensorType>(target);
if (!sourceType || !targetType)
if (sourceType.getElementType() != targetType.getElementType())
if (sourceType.getRank() != targetType.getRank())
if (sourceType.getEncoding() != targetType.getEncoding())
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
  if (ShapedType::isStatic(std::get<0>(t)) &&
      ShapedType::isDynamic(std::get<1>(t)))
castOp.getSource().getType());
if (llvm::isa<BlockArgument>(opOperand.get()))
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
newOperands.reserve(op->getNumOperands());
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
  if (op.isDpsInit(&opOperand) &&
      !llvm::isa<MemRefType>(newOperands.back().getType()))
    newResTy[dpsInitIdx++] = newOperands.back().getType();
auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
operand.set(castOp.getOperand());
return success(folded);
if (inputs.size() != 1 || outputs.size() != 1)
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<TensorType>(a);
auto bT = llvm::dyn_cast<TensorType>(b);
if (aT.getElementType() != bT.getElementType())
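// joinShapes: compute the most static shape compatible with both operands,
// taking each dim from whichever side has it static. E.g. (illustrative)
// joining tensor<?x4xf32> with tensor<2x?xf32> yields tensor<2x4xf32>;
// mismatched static sizes mean the shapes do not join.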
int64_t rank = one.getRank();
if (rank != two.getRank())
for (int64_t i = 0; i < rank; ++i) {
  if (one.isDynamicDim(i)) {
    join.push_back(two.getDimSize(i));
  if (two.isDynamicDim(i)) {
    join.push_back(one.getDimSize(i));
  if (one.getDimSize(i) != two.getDimSize(i))
  join.push_back(one.getDimSize(i));
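// Folds chains of tensor.cast ops as long as no static information is lost.
// A sketch of the rewrite (illustrative IR):
//   %1 = tensor.cast %0 : tensor<2x4xf32> to tensor<?x4xf32>
//   %2 = tensor.cast %1 : tensor<?x4xf32> to tensor<?x?xf32>
// folds to:
//   %2 = tensor.cast %0 : tensor<2x4xf32> to tensor<?x?xf32>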
LogicalResult matchAndRewrite(CastOp tensorCast,
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
if (!tensorCastOperand)
    llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
auto resultType = llvm::cast<TensorType>(tensorCast.getType());
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                    tensorCastOperand.getOperand());
LogicalResult matchAndRewrite(CastOp tensorCast,
auto extractOperand =
    tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
auto rankedResultType =
    llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
if (!rankedResultType)
rankedResultType.getShape() ==
    llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
extractOperand.getStaticSizes(), extractOperand.getType().getShape());
for (size_t i = 0, e = sizes.size(); i < e; i++) {
  if (dimMask && dimMask->count(i))
  int64_t dim = rankedResultType.getShape()[dimIndex++];
  if (ShapedType::isDynamic(dim))
  sizes[i] = rewriter.getIndexAttr(dim);
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
    tensorCast, rankedResultType, extractOperand.getSource(),
    extractOperand.getMixedOffsets(), sizes,
    extractOperand.getMixedStrides());
results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
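// Result type inference for concat: every dim of the inputs must agree except
// the concat dim, whose sizes are summed with saturation (any dynamic input
// size makes it dynamic). E.g. (illustrative) concatenating tensor<2x4xf32>
// and tensor<3x4xf32> along dim 0 gives tensor<5x4xf32>.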
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
    return llvm::cast<RankedTensorType>(type);
  int64_t concatRank = tensorTypes[0].getRank();
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    for (auto tensorType : tensorTypes)
  for (auto tensorType : tensorTypes)
  sizes[dim] = concatSize.asInteger();
FailureOr<RankedTensorType> resultType =
    inferResultType(dim, inputs.getTypes());
assert(succeeded(resultType) && "failed to infer concatenation result type");
build(builder, result, *resultType, dim, inputs);
if (getInputs().size() < 1)
  return emitOpError("requires at least one input");
for (auto input : getInputs())
  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
RankedTensorType resultType = getResultType();
int64_t resultRank = getRank();
if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
      return type.getRank() != resultRank;
  return emitOpError("rank of concatenated inputs must match result rank");
Type resultElementType = resultType.getElementType();
if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
      return type.getElementType() != resultElementType;
  return emitOpError("inputs and result element type must match");
int64_t dim = getDim();
if (dim >= resultRank)
  return emitOpError("concatenation dim must be less than the tensor rank");
for (int64_t i = 0, e = resultRank; i < e; ++i) {
  for (auto tensorType : inputTypes) {
    FailureOr<SaturatedInteger> maybeSize =
      return emitOpError("static concatenation size mismatch along ")
             << "non-concatenated dimension " << i;
for (auto tensorType : inputTypes)
sizes[dim] = concatSize.asInteger();
auto inferredResultType =
for (auto [inferredSize, actualSize] :
     llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                    ShapedType::isDynamic(actualSize);
  if (!hasDynamic && inferredSize != actualSize)
    return emitOpError("result type ")
           << resultType << " does not match inferred shape "
           << inferredResultType << " static sizes";
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();
  inputShapes.reserve(numInputs);
  concatOffsets.reserve(numInputs);
  outputShape = inputShape;
  concatOffsets.push_back(zero);
  concatOffsets.push_back(outputShape[concatDim]);
      builder, loc, addExpr,
      {outputShape[concatDim], inputShape[concatDim]});
  inputShapes.emplace_back(std::move(inputShape));
  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
  int64_t rank = getType().getRank();
  offsets[concatDim] = concatOffsets[index];
  auto insertSlice = tensor::InsertSliceOp::create(
      builder, loc, input, replacement, offsets, inputShapes[index], strides);
  replacement = insertSlice.getResult();
  replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
int64_t dim = getDim();
RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
Value init = inputs[0];
int64_t rank = getType().getRank();
for (int64_t i = 0; i < rank; ++i) {
  if (!getType().isDynamicDim(i)) {
  } else if (!inferredResultType.isDynamicDim(i)) {
        builder.getIndexAttr(inferredResultType.getDimSize(i)));
    reifiedReturnShapes[0][i] =
        tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
if (getType().isDynamicDim(dim)) {
    builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
reifiedReturnShapes[0][dim] =
void ConcatOp::getAsmResultNames(
  setNameFn(getResult(), "concat");
if (inputs.size() == 1 && inputs[0].getType() == getResultType())
LogicalResult matchAndRewrite(ConcatOp concatOp,
if (concatOp.getInputs().size() != 1)
    concatOp.getInputs()[0]);
LogicalResult matchAndRewrite(ConcatOp concatOp,
int64_t dim = concatOp.getDim();
RankedTensorType inferredResultType =
    ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
LogicalResult matched = failure();
for (auto [operandIdx, operandType] :
  inferredOperandShape[dim] =
      cast<RankedTensorType>(operandType).getDimSize(dim);
  inferredOperandShape, inferredResultType.getElementType());
  CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                 concatOp.getOperand(operandIdx));
  concatOp->setOperand(operandIdx, castOp->getResult(0));
LogicalResult matchAndRewrite(ConcatOp concatOp,
int64_t dim = concatOp.getDim();
RankedTensorType inferredResultType =
    ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
    concatOp.getResultType())) {
ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                 concatOp->getOperands());
.add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
setNameFn(getResult(), "dim");
build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {
auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
if (!rankedSourceType)
setResultRange(getResult(),
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= tensorType.getRank())
if (!tensorType.isDynamicDim(index.getInt())) {
  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
Operation *definingOp = getSource().getDefiningOp();
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    llvm::cast<RankedTensorType>(fromElements.getResult().getType());
  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
  auto dynExtents = fromElements.getDynamicExtents().begin();
  for (auto dim : resultType.getShape().take_front(index.getInt()))
    if (ShapedType::isDynamic(dim))
  return Value{*dynExtents};
unsigned unsignedIndex = index.getValue().getZExtValue();
if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
      sliceOp.isDynamicSize(unsignedIndex)) {
    return {sliceOp.getDynamicSize(unsignedIndex)};
LogicalResult matchAndRewrite(DimOp dimOp,
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
Value newSource = castOp.getOperand();
LogicalResult matchAndRewrite(DimOp dimOp,
auto source = dimOp.getSource();
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
auto resultIndex = cast<OpResult>(source).getResultNumber();
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
    dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
LogicalResult matchAndRewrite(DimOp dim,
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
if (extract.getType() != dim.getType())
    arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
assert(none_of(staticShape, ShapedType::isDynamic) &&
       "expected only static sizes");
build(builder, result, staticShape, elementType, ValueRange{}, encoding);
build(builder, result, tensorType, dynamicSizes);
build(builder, result, staticShape, elementType, dynamicSizes, encoding);
return emitOpError("incorrect number of dynamic sizes, has ")
       << getType().getNumDynamicDims();
for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {
LogicalResult matchAndRewrite(EmptyOp op,
    op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
if (foldedTensorType == op.getType())
auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                             foldedDynamicSizes);
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
if (!emptyTensorOp || !maybeConstantIndex)
auto emptyTensorType = emptyTensorOp.getType();
if (*maybeConstantIndex < 0 ||
    *maybeConstantIndex >= emptyTensorType.getRank() ||
    !emptyTensorType.isDynamicDim(*maybeConstantIndex))
    emptyTensorOp.getDynamicSize(*maybeConstantIndex));
LogicalResult matchAndRewrite(CastOp castOp,
auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
newMixedSizes.reserve(currMixedSizes.size());
assert(resultShape.size() == currMixedSizes.size() &&
       "mismatch in result shape and sizes of empty op");
for (auto it : llvm::zip(resultShape, currMixedSizes)) {
  int64_t newDim = std::get<0>(it);
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
    if (ShapedType::isDynamic(newDim) ||
        newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          producer, "mismatch in static value of shape of empty tensor "
                    "result and cast result");
    newMixedSizes.push_back(attr);
  if (ShapedType::isStatic(newDim)) {
    newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
  newMixedSizes.push_back(currDim);
    resultType.getElementType());
results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
            ReplaceEmptyTensorStaticShapeDims>(context);
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
    extract, tensorCast.getSource(), extract.getIndices());
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
    extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseOp.getSrcType().hasStaticShape())
auto sourceSizes = collapseOp.getSrcType().getShape();
    extractOp.getIndices().end());
for (auto [index, group] :
     llvm::zip(indices, collapseOp.getReassociationIndices())) {
  assert(!group.empty() && "association indices groups cannot be empty");
  auto groupSize = group.size();
  if (groupSize == 1) {
    sourceIndices.push_back(index);
      llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
  auto delinearize = affine::AffineDelinearizeIndexOp::create(
      rewriter, extractOp.getLoc(), index, basis, true);
  llvm::append_range(sourceIndices, delinearize.getResults());
if (collapseOp.getReassociationIndices().empty()) {
    cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
    rewriter, extractOp.getLoc(), zeroAffineMap,
for (int64_t i = 0; i < srcRank; i++) {
  sourceIndices.push_back(
    extractOp, collapseOp.getSrc(), sourceIndices);
void ExtractOp::getAsmResultNames(
  setNameFn(getResult(), "extracted");
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices for extract_element");
auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
    llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
  return insertOp.getScalar();
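// Constant folding for extract: a splat tensor yields its splat value; for a
// from_elements producer the indices are linearized row-major, i.e.
// flatIndex = sum_i indices[i] * stride_i, where stride_i is the product of
// the trailing dim sizes, as accumulated in the loop below.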
if (Attribute tensor = adaptor.getTensor()) {
  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
    return splatTensor.getSplatValue<Attribute>();
  if (isa<DenseResourceElementsAttr>(tensor))
for (Attribute indice : adaptor.getIndices()) {
  if (!indice || !llvm::isa<IntegerAttr>(indice))
  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
  auto rank = tensorType.getRank();
  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= tensorType.getDimSize(i);
  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
  return fromElementsOp.getElements()[flatIndex];
if (Attribute tensor = adaptor.getTensor()) {
  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValues<Attribute>()[indices];
results.add<ExtractFromTensorCast>(context);
void FromElementsOp::getAsmResultNames(
  setNameFn(getResult(), "from_elements");
assert(!elements.empty() && "expected at least one element");
    {static_cast<int64_t>(elements.size())}, elements.front().getType());
build(builder, result, resultType, elements);
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
struct ExtractElementFromIndexCast
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
auto newExtract = tensor::ExtractOp::create(
    rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
results.add<ExtractElementFromIndexCast>(context);
void GatherOp::getAsmResultNames(
  setNameFn(getResult(), "gather");
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
resultShape.reserve(resultShape.size() + sourceType.getRank());
for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
  if (llvm::binary_search(gatherDims, idx)) {
    resultShape.push_back(1);
  resultShape.push_back(sourceType.getDimSize(idx));
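// Shared verifier for gather_dims/scatter_dims: the dims must be non-empty,
// no longer than the source/dest rank, match the size of the last indices
// dim, be non-negative, in bounds, and strictly increasing.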
static LogicalResult
    StringRef gatherOrScatter, StringRef sourceOrDest) {
  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
int64_t numGatherDims = dims.size();
if (numGatherDims > rank)
  << "_dims overflow " << sourceOrDest << " rank";
if (indices.empty() || indices.back() != numGatherDims)
  << "_dims length must match the size of last dimension of indices";
for (int64_t val : dims) {
  << "_dims value must be non-negative";
  << "_dims value must be smaller than " << sourceOrDest << " rank";
for (int64_t i = 1; i < numGatherDims; ++i) {
  if (dims[i - 1] >= dims[i])
    << "_dims values must be strictly increasing";
int64_t sourceRank = getSourceType().getRank();
    getIndicesType().getShape(), sourceRank,
    "gather", "source")))
RankedTensorType expectedResultType = GatherOp::inferResultType(
    getSourceType(), getIndicesType(), gatherDims, false);
RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
    getSourceType(), getIndicesType(), gatherDims, true);
if (getResultType() != expectedResultType &&
    getResultType() != expectedRankReducedResultType) {
  return emitOpError("result type "
         << expectedResultType << " or its rank-reduced variant "
         << expectedRankReducedResultType << " (got: " << getResultType()
if (OpFoldResult reshapedSource = reshapeConstantSource(
        llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
  return reshapedSource;
void InsertOp::getAsmResultNames(
  setNameFn(getResult(), "inserted");
auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices");
if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
  if (scalar == splatDest.getSplatValue<Attribute>())
void GenerateOp::getAsmResultNames(
  setNameFn(getResult(), "generated");
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
  if (getType().isDynamicDim(dim)) {
    reifiedReturnShapes[0][dim] = getOperand(idx++);
  reifiedReturnShapes[0][dim] =
RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
if (getNumOperands() != resultType.getNumDynamicDims())
  return emitError("must have as many index operands as dynamic extents "
                   "in the result type");
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  if (!llvm::all_of(getBody().getArgumentTypes(),
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
  if (yieldOp.getValue().getType() != resultTy.getElementType())
        "body must be terminated with a `yield` operation of the tensor "
void GenerateOp::build(
  build(b, result, resultTy, dynamicExtents);
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
LogicalResult matchAndRewrite(GenerateOp generateOp,
    generateOp.getType(), generateOp.getDynamicExtents(),
    foldedDynamicSizes);
if (foldedTensorType == generateOp.getType())
auto loc = generateOp.getLoc();
    GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    newOp.getBody().begin());
    generateOp.getType(), newOp);
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
Block *body = &tensorFromElements.getBody().front();
rewriter.clone(op, mapping);
results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
return IntegerAttr();
void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");
int64_t numElements = 1;
for (auto dim : type.getShape())
  return emitOpError("element types of source and destination tensor "
                     "types should be the same");
auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
if (resultRankedType) {
  if (operandRankedType && resultRankedType.hasStaticShape() &&
      operandRankedType.hasStaticShape()) {
      return emitOpError("source and destination tensor should have the "
                         "same number of elements");
  if (ShapedType::isDynamic(shapeSize))
    return emitOpError("cannot use shape operand with dynamic length to "
                       "reshape to statically-ranked tensor type");
  if (shapeSize != resultRankedType.getRank())
      "length of shape operand differs from the result's tensor rank");
if (OpFoldResult reshapedSource = reshapeConstantSource(
        llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
  return reshapedSource;
if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
  getSourceMutable().assign(reshapeOpProducer.getSource());
auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());
if (!sourceTy || !resultTy || sourceTy != resultTy)
if (sourceTy.getRank() <= 1)
if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
  auto elements = fromElements.getElements();
      sourceTy.getRank() == static_cast<int64_t>(elements.size());
  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
    auto element = elements[id];
      dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
    if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
      dynamicNoop &= dimOp.getSource() == source;
          cst.has_value() && cst.value() == static_cast<int64_t>(id);
    dynamicNoop = false;
void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapsed");
void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expanded");
int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  if (llvm::is_contained(it.value(), resultDim))
  llvm_unreachable("could not find reassociation group");
FailureOr<SmallVector<OpFoldResult>>
    RankedTensorType expandedType,
std::optional<SmallVector<OpFoldResult>> outputShape =
return *outputShape;
auto [staticOutputShape, dynamicOutputShape] =
build(builder, result, cast<RankedTensorType>(resultType), src,
      dynamicOutputShape, staticOutputShape);
auto tensorResultTy = cast<RankedTensorType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
    builder, result.location, tensorResultTy, reassociation, inputShape);
if (succeeded(outputShape)) {
  outputShapeOrEmpty = *outputShape;
build(builder, result, tensorResultTy, src, reassociation,
      outputShapeOrEmpty);
    getReassociationIndices());
    getReassociationIndices());
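// Collapsed type inference: each reassociation group contributes one result
// dim that is the product of the grouped source sizes, or dynamic if any of
// them is dynamic. E.g. (illustrative) collapsing tensor<2x3x?xf32> with
// reassociation [[0, 1], [2]] yields tensor<6x?xf32>.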
RankedTensorType CollapseShapeOp::inferCollapsedType(
  return inferCollapsedType(
      type.getContext(), reassociation)));
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
  auto shape = type.getShape();
  newShape.reserve(reassociation.size());
  unsigned currentDim = 0;
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
auto resultType = inferCollapsedType(
    llvm::cast<RankedTensorType>(src.getType()),
build(b, result, resultType, src, attrs);
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
    RankedTensorType expandedType,
    RankedTensorType collapsedType) {
auto maps = op.getReassociationMaps();
RankedTensorType expectedType =
    CollapseShapeOp::inferCollapsedType(expandedType, maps);
  return op.emitOpError("expected collapsed type to be ")
         << expectedType << ", but got " << collapsedType;
auto srcType = getSrcType();
auto resultType = getResultType();
if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
  return emitOpError("expected number of static shape dims to be equal to "
                     "the output rank (")
         << resultType.getRank() << ") but found "
         << getStaticOutputShape().size() << " inputs instead";
if ((int64_t)getOutputShape().size() !=
    llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
  return emitOpError("mismatch in dynamic dims in output_shape and "
                     "static_output_shape: static_output_shape has ")
         << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
         << " dynamic dims while output_shape has " << getOutputShape().size()
template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
    reshapeOp.getResultType(), attr.getRawData());
template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
    reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
    reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
if (!shapedTy.hasStaticShape())
    fromElements.getElements());
LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
RankedTensorType srcType =
    llvm::cast<RankedTensorType>(castOp.getSource().getType());
RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
    srcType, collapseShapeOp.getReassociationMaps());
if (newResultType == collapseShapeOp.getResultType()) {
  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                     newResultType, castOp.getSource(),
                                     collapseShapeOp.getReassociation());
    collapseShapeOp, collapseShapeOp.getResultType(), newOp);
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    expandOp.getReassociationIndices();
auto outputIt = expandOp.getOutputShape().begin();
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
  for (uint64_t outDim : innerReassoc) {
    if (ShapedType::isStatic(newOutputShape[outDim]))
    Value val = *outputIt;
    if (ShapedType::isDynamic(castSrcShape[inputDim])) {
      dynamicOutputShape.push_back(val);
    newOutputShape[outDim] = cst.getSExtValue();
    dynamicOutputShape.push_back(val);
if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
  for (auto outDim : reassoc[inDim]) {
    auto ofr = newOutputShape[outDim];
    if (ShapedType::isDynamic(ofr)) {
      newInputShape[inDim] = ShapedType::kDynamic;
    newInputShape[inDim] *= ofr;
    newInputShape, expandOp.getSrcType().getElementType());
    newOutputShape, expandOp.getSrcType().getElementType());
auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
auto newExpand = ExpandShapeOp::create(
    rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
    expandOp.getReassociationIndices(), outputOfr);
    newExpand.getResult());
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
    tensor::DimOp, RankedTensorType>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithSplat<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
void ExtractSliceOp::getAsmResultNames(
  setNameFn(getResult(), "extracted_slice");
RankedTensorType ExtractSliceOp::inferResultType(
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
      sourceTensorType.getEncoding());
RankedTensorType ExtractSliceOp::inferResultType(
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
      sourceTensorType.getEncoding());
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  auto shape = inferredType.getShape();
  llvm::SmallBitVector dimsToProject =
  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
    if (!dimsToProject.test(pos))
      projectedShape.push_back(shape[pos]);
  return inferredType;
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
    RankedTensorType resultType, Value source,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
    sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
    RankedTensorType resultType, Value source,
build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
    RankedTensorType expectedType) {
  return op->emitError("expected rank to be smaller or equal to ")
         << "the other rank. ";
  return op->emitError("expected type to be ")
         << expectedType << " or a rank-reduced version. (size mismatch) ";
  return op->emitError("expected element type to be ")
         << expectedType.getElementType();
llvm_unreachable("unexpected extract_slice op verification result");
RankedTensorType sourceType = getSourceType();
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
    sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
    sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
    getStaticStrides(), true);
  return getOperation()->emitError(boundsResult.errorMessage);
auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
assert(sourceTensorType && "not a ranked tensor type");
auto sourceShape = sourceTensorType.getShape();
if (sourceShape.equals(desiredShape))
auto maybeRankReductionMask =
if (!maybeRankReductionMask)
reifiedReturnShapes.resize(1);
reifiedReturnShapes[0].reserve(getType().getRank());
for (const auto &size : enumerate(mixedSizes)) {
  if (droppedDims.test(size.index()))
  reifiedReturnShapes[0].push_back(size.value());
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
      return matchPattern(operand, matchConstantIndex());
auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
    sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
    sliceOp.getStaticStrides());
Value newResult = ExtractSliceOp::create(
    rewriter, loc, sliceOp.getType(), castOp.getSource(),
    sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
    sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
    sliceOp.getStaticStrides());
template <typename IterTy, typename ElemTy>
assert(offsets.size() == sizes.size());
assert(offsets.size() == strides.size());
if (offsets.empty())
int64_t offset = offsets.front();
int64_t size = sizes.front();
int64_t stride = strides.front();
if (offsets.size() == 1) {
  for (int64_t i = 0; i < size; ++i, offset += stride)
    outValues->push_back(*(values + offset));
for (int64_t i = 0; i < size; ++i, offset += stride) {
  auto begin = values + offset * counts.front();
  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                offsets.drop_front(), sizes.drop_front(),
                                strides.drop_front(), outValues);
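// Folds extract_slice of a dense constant into a smaller dense constant by
// copying the selected hyperrectangle element by element. `counts` below
// holds suffix products of the source shape (per-dim strides) so that
// sliceElements can recurse one dim at a time over (offset, size, stride).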
class ConstantOpExtractSliceFolder final
ConstantOpExtractSliceFolder(MLIRContext *context,
    controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(ExtractSliceOp op,
auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
int64_t count = sourceType.getNumElements();
auto offsets = op.getStaticOffsets();
if (llvm::is_contained(offsets, ShapedType::kDynamic))
auto sizes = op.getStaticSizes();
if (llvm::is_contained(sizes, ShapedType::kDynamic))
auto strides = op.getStaticStrides();
if (llvm::is_contained(strides, ShapedType::kDynamic))
counts.reserve(shape.size());
for (int64_t v : shape) {
  counts.push_back(count);
if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
  outValues.reserve(sourceType.getNumElements());
  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
      elems.begin(), counts, offsets, sizes, strides, &outValues);
} else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
  outValues.reserve(sourceType.getNumElements());
  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
      elems.begin(), counts, offsets, sizes, strides, &outValues);
patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
    op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
    ExtractSliceOp newOp) {
Value replacement = newOp.getResult();
if (replacement.getType() != op.getType())
  replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
    ExtractSliceOpCastFolder>(context);
static LogicalResult
    ShapedType shapedType) {
auto shape = shapedType.getShape();
for (auto it : llvm::zip(op.getMixedSizes(), shape))
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
    insertOp.isSameAs(extractOp, isSame))
  return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
    return reshapedSource;
  if (getSourceType() == getType() &&
    return this->getSource();
auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
unsigned rank = rankedTensorType.getRank();
return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                              offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
  setNameFn(getResult(), "inserted_slice");
build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
build(b, result, source, dest, offsets, sizes, strides, attrs);
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
    RankedTensorType srcType, RankedTensorType dstType,
RankedTensorType expected = ExtractSliceOp::inferResultType(
    dstType, staticOffsets, staticSizes, staticStrides);
*expectedType = expected;
RankedTensorType expectedType;
    getStaticSizes(), getStaticStrides(), &expectedType);
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), true);
  return getOperation()->emitError(boundsResult.errorMessage);
auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
if (!prevInsertOp ||
    prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
    !prevInsertOp.isSameAs(insertOp, isSame))
insertOp.getDestMutable().assign(prevInsertOp.getDest());
auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
    !extractOp.isSameAs(insertOp, isSame))
return extractOp.getSource();
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
    getSourceType() == getType() &&
  return this->getSource();
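// Folds insert_slice operands that are statically known (e.g. produced by
// constants) into the static offset/size/stride attributes, inserting a cast
// of the source when the inferred source type becomes more static.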
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
    mixedOffsets, mixedSizes, mixedStrides);
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
    insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
    mixedOffsets, mixedSizes, mixedStrides);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
  if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
  toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                    sourceType, toInsert);
    insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
    mixedSizes, mixedStrides);
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
      return matchPattern(operand, matchConstantIndex());
auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
  auto castOp = v.getDefiningOp<tensor::CastOp>();
    return std::nullopt;
  return castOp.getSource();
std::optional<Value> sourceCastSource =
    getSourceOfCastOp(insertSliceOp.getSource());
std::optional<Value> destCastSource =
    getSourceOfCastOp(insertSliceOp.getDest());
if (!sourceCastSource && !destCastSource)
    (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)
    staticSizes, srcType.getShape(), true);
if (!rankReductionMask.has_value())
int64_t rankReducedIdx = 0;
for (auto [idx, size] : enumerate(staticSizes)) {
  if (!rankReductionMask.value().contains(idx) &&
      !srcType.isDynamicDim(rankReducedIdx)) {
        rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
    size = srcType.getDimSize(rankReducedIdx++);
    staticSizes, insertSliceOp.getStaticStrides()) !=
    mixedSizes, insertSliceOp.getMixedStrides());
    InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                       insertSliceOp.getMixedOffsets(), mixedSizes,
                       insertSliceOp.getMixedStrides());
bool isParallelInsert =
    std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
  replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                       insertSliceOp.getDestType(),
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
RankedTensorType srcType = insertSliceOp.getSourceType();
if (srcType.getRank() != insertSliceOp.getDestType().getRank())
for (int64_t i = 0; i < srcType.getRank(); ++i) {
  if (std::optional<int64_t> constInt =
    newSrcShape[i] = *constInt;
    newSrcShape, srcType.getElementType(), srcType.getEncoding());
if (srcType == newSrcType ||
    !tensor::CastOp::areCastCompatible(srcType, newSrcType))
if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                    newSrcType, insertSliceOp.getSource());
    insertSliceOp, cast, insertSliceOp.getDest(),
    insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
    insertSliceOp.getMixedStrides());
results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
            InsertSliceOpCastFolder<InsertSliceOp>,
            InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
unsigned rank = rankedTensorType.getRank();
return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
setNameFn(getResult(), "padded");
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
    PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
  return emitError("failed to infer expectedType from sourceType ")
         << sourceType << ", specified resultType is " << resultType;
if (resultType.getRank() != expectedType.getRank()) {
    << resultType << " does not match the inferred type "
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
  if (expectedType.isDynamicDim(i))
    << resultType << " does not match the inferred type "
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
    return emitError("expected the block to have ") << rank << " arguments";
  if (!en.value().isIndex())
    return emitOpError("expected block argument ")
           << (en.index() + 1) << " to be an index";
  if (yieldOp.getValue().getType() !=
    return emitOpError("expected yield type to match shape element type");
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high,
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
unsigned rank = sourceType.getRank();
build(b, result, resultType, source, staticVector, staticVector, low, high,
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
assert(llvm::isa<RankedTensorType>(resultType));
build(b, result, resultType, source, dynamicLow, dynamicHigh,
build(b, result, resultType, source, low, high, nofold, attrs);
int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
tensor::YieldOp::create(b, result.location, constantPadValue);
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  for (const auto &en : enumerate(paddingWidths))
    paddedDims.set(en.index());
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
LogicalResult matchAndRewrite(PadOp padTensorOp,
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
if (padTensorOp.getNofold())
    padTensorOp, padTensorOp.getResult().getType(),
    padTensorOp.getSource());
LogicalResult matchAndRewrite(PadOp padTensorOp,
auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
auto newResultType = PadOp::inferResultType(
    llvm::cast<RankedTensorType>(castOp.getSource().getType()),
    padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
    padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
  padTensorOp.getSourceMutable().assign(castOp.getSource());
auto newOp = PadOp::create(
    rewriter, padTensorOp->getLoc(), newResultType,
    padTensorOp.getSource(), padTensorOp.getStaticLow(),
    padTensorOp.getStaticHigh(), padTensorOp.getLow(),
    padTensorOp.getHigh(), padTensorOp.getNofold(),
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    padTensorOp, padTensorOp.getResultType(), newOp);
LogicalResult matchAndRewrite(PadOp padTensorOp,
if (!padTensorOp.getResult().hasOneUse())
    dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    tensorCastOp.getDest().getType()))
auto replacementOp = PadOp::create(
    rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
    padTensorOp.getSource(), padTensorOp.getStaticLow(),
    padTensorOp.getStaticHigh(), padTensorOp.getLow(),
    padTensorOp.getHigh(), padTensorOp.getNofold(),
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
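// Folds pad(extract_slice(pad(extract_slice(x)))) chains where the two pads
// affect disjoint dims: matching zero-offset/zero-padding pairs let the two
// slices be merged into one extract_slice followed by a single pad.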
LogicalResult matchAndRewrite(PadOp padOp,
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
if (!outerPadOp || outerPadOp.getNofold())
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
int64_t rank = padOp.getSourceType().getRank();
if (outerSliceOp.getSourceType().getRank() != rank) {
      "cannot fold rank-reducing chain");
if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      padOp, "cannot fold non-unit stride ExtractSliceOps");
if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      "cannot fold PadOps with low padding");
Value innerValue = padOp.getConstantPaddingValue();
Value outerValue = outerPadOp.getConstantPaddingValue();
if (!innerValue || !outerValue ||
    innerAttr != outerAttr) {
      padOp, "cannot fold PadOps with different padding values");
llvm::SmallBitVector innerDims = padOp.getPaddedDims();
llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
if (innerDims.anyCommon(outerDims)) {
      padOp, "cannot fold PadOps with common padding dimensions");
OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
if (!innerDims.test(en.index()) &&
  en.value() = outerOffset;
if (!outerDims.test(en.index()) &&
  en.value() = innerOffset;
      padOp, "cannot find zero-offset and zero-padding pair");
if (!outerDims.test(en.index()))
OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
assert(ShapedType::isStatic(sourceSize) &&
       "expected padded dimension to have a static size");
      padOp, "cannot fold since the inner ExtractSliceOp size does not "
            "match the size of the outer padding");
en.value() = outerSliceOp.getMixedSizes()[en.index()];
if (innerDims.test(en.index()))
  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
if (outerDims.test(en.index()))
  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
auto newSliceOp = ExtractSliceOp::create(
    rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
    newSizes, innerSliceOp.getMixedStrides());
auto newPadOp = PadOp::create(
    rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
    padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
    newPadOp.getRegion().begin());
rewriter.replaceOp(padOp, newPadOp.getResult());
/// Pattern to fold constant padding bounds into the op's static attributes,
/// casting the result back to the original (more dynamic) type.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();
    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    // Fill dynamic low/high entries with the newly found constants.
    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }
    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(staticLow[i] == ShapedType::kDynamic ||
                                     staticHigh[i] == ShapedType::kDynamic ||
                                     inputDims[i] == ShapedType::kDynamic
                                 ? ShapedType::kDynamic
                                 : inputDims[i] + staticLow[i] + staticHigh[i]);
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // Bail out if nothing became more static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);
    return success();
  }
};
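
// Illustrative sketch (hypothetical shapes and values, not from the source):
// with %c3 = arith.constant 3 : index, FoldStaticPadding turns
//
//   %0 = tensor.pad %t low[%c3, 0] high[2, %c3] { tensor.yield %cst }
//          : tensor<8x8xf32> to tensor<?x?xf32>
//
// into a pad with fully static bounds plus a cast back to the old type:
//
//   %1 = tensor.pad %t low[3, 0] high[2, 3] { tensor.yield %cst }
//          : tensor<8x8xf32> to tensor<13x11xf32>
//   %0 = tensor.cast %1 : tensor<13x11xf32> to tensor<?x?xf32>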
/// Folds a chain of tensor.pad ops with the same constant padding value into
/// one pad whose low/high bounds are the per-dimension sums.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold())
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the padding values differ or are not constant.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Sum the low/high paddings of the two PadOps pairwise.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
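
// Illustrative sketch (hypothetical shapes, not from the source): two pads
// with the same constant padding value %cst collapse into one whose bounds
// are the per-dimension sums:
//
//   %0 = tensor.pad %t low[1, 0] high[0, 2] { tensor.yield %cst }
//          : tensor<4x4xf32> to tensor<5x6xf32>
//   %1 = tensor.pad %0 low[0, 1] high[2, 0] { tensor.yield %cst }
//          : tensor<5x6xf32> to tensor<7x7xf32>
// -->
//   %1 = tensor.pad %t low[1, 1] high[2, 2] { tensor.yield %cst }
//          : tensor<4x4xf32> to tensor<7x7xf32>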
// In PadOp::reifyResultShapes: static result dims reify to constants; each
// dynamic dim is computed as source size + low pad + high pad.
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    // ... (elided: take tensor.dim of the source and bind d0, d1, d2)
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
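
// Illustrative reification example for the loop above (hypothetical values):
// for %p = tensor.pad %t low[%l, 0] high[2, 0]
//            : tensor<?x8xf32> to tensor<?x8xf32>,
// dim 1 reifies to the index attribute 8, while dim 0 reifies to the folded
// affine apply (d0, d1, d2) -> (d0 + d1 + d2) over (tensor.dim %t at 0, %l, 2).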
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  // ... (elided: bail out if there is no yield)
  Value padValue = yieldOp.getValue();
  // ... (elided: only a value defined outside the PadOp block qualifies)
}

// In PadOp::fold: a pad whose (static) result type equals its source type
// pads nothing and folds away to its source, unless marked nofold:
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
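
// Illustrative sketch (hypothetical types): an all-zero pad to the same type
// folds away entirely:
//
//   %0 = tensor.pad %t low[0, 0] high[0, 0] { tensor.yield %cst }
//          : tensor<4x4xf32> to tensor<4x4xf32>
// -->
//   %t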
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// ParallelInsertSliceOp builder bodies (surrounding signatures elided).
// The OpFoldResult-based builder splits mixed offsets/sizes/strides into
// static attributes plus dynamic values, then delegates:
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
// The attribute-carrying builder simply forwards:
  build(b, result, source, dest, offsets, sizes, strides, attrs);
// The Value-based builder wraps its operands as OpFoldResults and delegates:
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
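
// Hypothetical usage sketch (the names `b`, `loc`, `source`, `dest`, and
// `someIndexValue` are assumptions, not from the source): the OpFoldResult
// builder lets callers freely mix attribute and SSA-value components.
//   SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), someIndexValue};
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(8), b.getIndexAttr(8)};
//   SmallVector<OpFoldResult> strides(2, b.getIndexAttr(1));
//   ParallelInsertSliceOp::create(b, loc, source, dest, offsets, sizes,
//                                 strides);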
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result/source/dest types against the static offsets/sizes/strides.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that the slice is in-bounds with respect to the destination.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
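
// For instance (hypothetical shapes), offsets [60], sizes [8], strides [1]
// into a tensor<64xf32> destination would be rejected by the in-bounds check
// above: the last accessed index 60 + (8 - 1) * 1 = 67 exceeds dimension 64.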
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}

// (elided accessor: the destination is the operand updated in parallel)
  return getDestMutable();
}
Operation *ParallelInsertSliceOp::getIteratingParent() {
  // The iterating parent is the loop-like op that wraps the combining op.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
// Delegating SplatOp builders (two signatures elided); both forward the
// precomputed aggregate type to the ODS-generated builder:
  build(builder, result, aggregateType, element, dynamicSizes);
// The OpFoldResult-based overload first splits sizes into static and dynamic:
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

// In SplatOp::reifyResultShapes: static dims reify to index attributes; each
// dynamic dim takes the next value from getDynamicSizes().
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      // ... (elided)
    }
  }

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};
  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};
  return SplatElementsAttr::get(getType(), {constOperand});
}
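
// Illustrative sketch (hypothetical types): splatting a constant scalar into
// a statically shaped tensor folds to a splat elements attribute:
//
//   %cst = arith.constant 1.0 : f32
//   %0 = tensor.splat %cst : tensor<2x3xf32>
// -->
//   dense<1.0> : tensor<2x3xf32>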
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // InsertSliceOp has its own logic about folding tensor.cast ops.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;
  // ... (elided: further exclusions, e.g. linalg relayout ops)
  if (isa<linalg::RelayoutOpInterface>(*op))
    return false;
  return hasFoldableTensorCastOperand(op);
}

// In FoldTensorCastProducerOp::matchAndRewrite: clone the op with the more
// static operand/result types, then cast each changed result back.
  auto newOp = clone(rewriter, op, newResultTypes, newOperands);

  SmallVector<Value, 4> replacements;
  replacements.reserve(newOp->getNumResults());
  for (auto [oldResult, newResult] :
       llvm::zip(op->getResults(), newOp->getResults())) {
    if (newResult.getType() != oldResult.getType()) {
      replacements.push_back(tensor::CastOp::create(
          rewriter, op->getLoc(), oldResult.getType(), newResult));
    } else {
      replacements.push_back(newResult);
    }
  }
  rewriter.replaceOp(op, replacements);
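
// Illustrative sketch ("some_dps_op" is a stand-in, not a real op): a
// tensor.cast feeding a DPS op is absorbed, and the result is cast back so
// existing uses keep their original type:
//
//   %c = tensor.cast %t : tensor<4x8xf32> to tensor<?x8xf32>
//   %r = "some_dps_op"(%c) : (tensor<?x8xf32>) -> tensor<?x8xf32>
// -->
//   %n = "some_dps_op"(%t) : (tensor<4x8xf32>) -> tensor<4x8xf32>
//   %r = tensor.cast %n : tensor<4x8xf32> to tensor<?x8xf32>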
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
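
// Hedged usage sketch of how the patterns registered above are typically
// exercised; the driver entry point is applyPatternsGreedily in recent MLIR
// (older revisions name it applyPatternsAndFoldGreedily), so treat the exact
// API as an assumption. Requires
// "mlir/Transforms/GreedyPatternRewriteDriver.h".
//
//   static LogicalResult canonicalizeTensorOps(Operation *root) {
//     MLIRContext *ctx = root->getContext();
//     RewritePatternSet patterns(ctx);
//     // Dialect-wide patterns (e.g. FoldTensorCastProducerOp) ...
//     for (Dialect *dialect : ctx->getLoadedDialects())
//       dialect->getCanonicalizationPatterns(patterns);
//     // ... plus per-op canonicalizations such as PadOp's.
//     for (RegisteredOperationName opName : ctx->getRegisteredOperations())
//       opName.getCanonicalizationPatterns(patterns, ctx);
//     return applyPatternsGreedily(root, std::move(patterns));
//   }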
static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext *getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl<Value> &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes for the type.
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's scalar.
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, ArrayRef<int64_t> indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSliceOp's destination to the first one's.
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the InsertSliceOp's source.
static int64_t getNumElements(ShapedType type)
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector<Value> &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes, turn all dynamic sizes that have a constant value into static dimension sizes.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape, ArrayRef<OpFoldResult> mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor.insert_slice op.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef<int64_t> getShape(Type type)
Returns the shape of the given type.