#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));

  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));

  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  if (!tensorType.hasStaticShape()) {
  for (int64_t sz : tensorType.getShape())
      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());

  if (llvm::isa<TensorType>(opResult.getType())) {
    if (failed(destination))
      return failure();
    result.push_back(*destination);
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();

  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    bool isStaticUnitSize =
        llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
    if (!isStaticUnitSize) {
    if (reducedShape[shapePos] == 1) {
    droppedDims.set(idx);

  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      if (cst.has_value()) {
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
        staticShape[i] = *cst;
        foldedDynamicSizes.push_back(dynamicSize);

  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());

  results.add<ChainedTensorBitcast>(context);
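// Illustrative example (assumed IR, not from the original file):
//   %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>
//   %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xsi32>
// ChainedTensorBitcast rewrites this into a single op reading the source:
//   %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xsi32>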
  setNameFn(getResult(), "cast");

  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
  if (!sourceType || !targetType)
    return false;
  if (sourceType.getElementType() != targetType.getElementType())
    return false;
  if (sourceType.getRank() != targetType.getRank())
    return false;
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }
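// Illustrative example (assumed behavior, based on the checks above): a
// target of tensor<8x16xf32> preserves the static information of a source
// tensor<8x?xf32>, while a target of tensor<8x?xf32> does not preserve the
// static information of tensor<8x16xf32>, since the static 16 turns dynamic.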
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());

    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    operand.set(castOp.getOperand());
  return success(folded);

  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (aT.getElementType() != bT.getElementType())
    return false;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());

  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
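// Illustrative example (assumed IR, not from the original file):
//   %1 = tensor.extract_slice %0[0, 0] [2, %sz] [1, 1]
//        : tensor<8x16xf32> to tensor<2x?xf32>
//   %2 = tensor.cast %1 : tensor<2x?xf32> to tensor<2x8xf32>
// TensorCastExtractSlice folds the cast's static size into the slice:
//   %2 = tensor.extract_slice %0[0, 0] [2, 8] [1, 1]
//        : tensor<8x16xf32> to tensor<2x8xf32>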
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();

  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);
  concatOffsets.reserve(numInputs);
      outputShape = inputShape;
      concatOffsets.push_back(zero);
      concatOffsets.push_back(outputShape[concatDim]);
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    inputShapes.emplace_back(std::move(inputShape));

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  }
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  for (int64_t i = 0; i < rank; ++i) {
    if (!getType().isDynamicDim(i)) {
    } else if (!inferredResultType.isDynamicDim(i)) {
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    reifiedReturnShapes[0][dim] =

void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getType(),
                                        concatOp.getInputs()[0]);

  results.add<SingleInputConcatOp>(context);
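// Illustrative example (assumed IR): a single-input concat such as
//   %1 = tensor.concat dim(0) %0 : (tensor<4x8xf32>) -> tensor<?x8xf32>
// is replaced by a cast:
//   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>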
  setNameFn(getResult(), "dim");

  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())

  if (!tensorType.isDynamicDim(index.getInt())) {
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the GenerateOp that corresponds to this extent.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    Value newSource = castOp.getOperand();

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);

  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
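// Illustrative example (assumed IR): DimOfCastOp rewrites
//   %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
//   %d = tensor.dim %1, %c0 : tensor<?x?xf32>
// into
//   %d = tensor.dim %0, %c0 : tensor<4x?xf32>
// which can subsequently fold to a constant index of 4.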
  assert(all_of(staticShape,
                [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);

  build(builder, result, tensorType, dynamicSizes);

  build(builder, result, staticShape, elementType, dynamicSizes, encoding);

    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                          foldedDynamicSizes);
  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      newMixedSizes.push_back(currDim);
    }

    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());

  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
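// Illustrative example (assumed IR): with %c8 = arith.constant 8 : index,
// ReplaceEmptyTensorStaticShapeDims rewrites
//   %0 = tensor.empty(%c8, %sz) : tensor<?x?xf32>
// into
//   %e = tensor.empty(%sz) : tensor<8x?xf32>
//   %0 = tensor.cast %e : tensor<8x?xf32> to tensor<?x?xf32>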
static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
                                          TensorType result,
                                          std::optional<Attribute> cst = std::nullopt) {
  if (source && source.isSplat() && result.hasStaticShape() &&

struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");

  if (Attribute tensor = adaptor.getTensor()) {
    // All elements of a splat attribute are identical.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses; invalid code may never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  results.add<ExtractFromTensorCast>(context);
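// Illustrative example (assumed IR): ExtractFromTensorCast rewrites
//   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
//   %e = tensor.extract %1[%i] : tensor<?xf32>
// into an extract from the uncast source:
//   %e = tensor.extract %0[%i] : tensor<4xf32>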
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))

struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

  results.add<ExtractElementFromIndexCast>(context);

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
        resultShape.push_back(1);
    resultShape.push_back(sourceType.getDimSize(idx));
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }

  int64_t sourceRank = getSourceType().getRank();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "must be ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()

  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");

  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
    if (scalar == splatDest.getSplatValue<Attribute>())

void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =

  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
                                newOp.getBody().begin());
        generateOp.getType(), newOp);

struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();

    Block *body = &tensorFromElements.getBody().front();
      rewriter.clone(op, mapping);

  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
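// Illustrative example (assumed IR): StaticTensorGenerate promotes constant
// extents into the result type, so a tensor.generate with extent
// %c16 = arith.constant 16 : index producing tensor<?xindex> becomes a
// tensor.generate producing tensor<16xindex>, followed by a tensor.cast back
// to tensor<?xindex>.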
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
  return IntegerAttr();

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;

    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }

  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // A 1D reshape between identical types is trivially a no-op.
  if (sourceTy.getRank() == 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }
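// Illustrative example (assumed IR): a reshape whose shape operand is built
// from tensor.dim ops of the source, in dimension order, is a no-op:
//   %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
//   %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
//   %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
//   %r = tensor.reshape %src(%shape)
// folds to %src when source and result types match.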
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
  return *outputShape;

  auto [staticOutputShape, dynamicOutputShape] =
  build(builder, result, cast<RankedTensorType>(resultType), src,
        dynamicOutputShape, staticOutputShape);

  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);

      getReassociationIndices());

      getReassociationIndices());
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
  }

  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
  build(b, result, resultType, src, attrs);

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
        reshapeOp.getResultType(), attr.getRawData());

template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());

template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

        fromElements.getElements());

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      collapseShapeOp.getSrcMutable().assign(castOp.getSource());
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();

    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      return failure();

    RankedTensorType resultType = expandShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      return failure();

    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);

    for (int64_t d : grp) {
        assert(!resultType.isDynamicDim(d) && "expected static dim");
        product *= resultType.getDimSize(d);
    }

    Value srcDimSz =
        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
        dimOp, expr.floorDiv(product), srcDimSz);

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      return failure();

    RankedTensorType resultType = collapseShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      return failure();

    ReassociationIndices group =
        collapseShapeOp.getReassociationIndices()[*dim];
      srcDimSizes.push_back(rewriter.create<DimOp>(
          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();

    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    auto outputIt = expandOp.getOutputShape().begin();
    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        Value val = *outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

          newOutputShape[outDim] = cst.getSExtValue();
          dynamicOutputShape.push_back(val);

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

        newInputShape, expandOp.getSrcType().getElementType());
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
        newExpand.getResult());
  results.add<
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
      FoldDimOfCollapseShape>(context);

          tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
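// Illustrative example (assumed IR): foldReshapeOp cancels inverse pairs, so
//   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [3, 4]
//        : tensor<12xf32> into tensor<3x4xf32>
//   %2 = tensor.collapse_shape %1 [[0, 1]] : tensor<3x4xf32> into tensor<12xf32>
// folds to %0.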
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType ExtractSliceOp::inferResultType(
  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    // Best-effort rank reduction: drop unit dims in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
  return inferredType;

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
                           RankedTensorType resultType, Value source,
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

                           RankedTensorType resultType, Value source,
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }

  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)

  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();

    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    if (newResult.getType() != sliceOp.getType())
      newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
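// Illustrative example (assumed IR): the cast folder rewrites
//   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x16xf32>
//   %2 = tensor.extract_slice %1[0, 0] [4, 4] [1, 1]
//        : tensor<?x16xf32> to tensor<4x4xf32>
// to slice %0 directly, inserting a cast on the result only when the
// resulting type differs from the original slice type.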
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
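// Worked example (hypothetical values): for a 2x4 source, counts = {4, 1}.
// Extracting offsets = {1, 0}, sizes = {1, 2}, strides = {1, 2} first
// advances the iterator by offset * counts.front() = 4 elements (row 1), then
// the innermost call copies elements 0 and 2 of that row.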
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    int64_t count = sourceType.getNumElements();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    }
    return ExtractSliceOp::inferCanonicalRankReducedResultType(
        op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,

                               ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);

              ExtractSliceOpCastFolder>(context);

static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))

  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();

OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();

  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}

  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

  build(b, result, source, dest, offsets, sizes, strides, attrs);

  build(b, result, source, dest, offsetValues, sizeValues, strideValues);

    RankedTensorType srcType, RankedTensorType dstType,
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;

  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);

  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());

  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))

  return extractOp.getSource();

  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      // For ParallelInsertSliceOp, the insertion point must be just before
      // the enclosing ParallelCombiningOp.
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = IntegerAttr::get(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        mixedSizes, insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        newSrcShape[i] = *constInt;
      }
    }

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
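// Illustrative example (assumed IR): InsertSliceOpConstantArgumentFolder
// turns dynamic offsets/sizes that are bound to constants into static
// attributes, e.g. an insert with sizes [%c4, 8] becomes sizes [4, 8], with
// a tensor.cast of the source inserted when its canonical type changes.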
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);

  setNameFn(getResult(), "padded");

                    Type typeToInfer, Type typeToInferFrom) {}

    std::optional<OpAsmParser::UnresolvedOperand> optOperand,
    Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
    return emitError("expected the block to have ") << rank << " arguments";

    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";

  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  build(b, result, resultType, source, low, high,

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  build(b, result, resultType, source, staticVector, staticVector, low, high,

  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  assert(llvm::isa<RankedTensorType>(resultType));
  build(b, result, resultType, source, dynamicLow, dynamicHigh,

  build(b, result, resultType, source, low, high, nofold, attrs);

  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);

llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
    for (const auto &en : enumerate(paddingWidths))
        paddedDims.set(en.index());
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      padTensorOp.getSourceMutable().assign(castOp.getSource());
    } else {
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

          padTensorOp, padTensorOp.getResultType(), newOp);

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();

    // Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
        en.value() = outerOffset;
      if (!outerDims.test(en.index()) &&
        en.value() = innerOffset;
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");

      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      en.value() = outerSliceOp.getMixedSizes()[en.index()];

      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];

    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    for (auto operand : padTensorOp.getLow()) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
      constOperandsLow.push_back(intOp.getExtValue());
    for (auto operand : padTensorOp.getHigh()) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
      constOperandsHigh.push_back(intOp.getExtValue());

    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    // Calculate the output sizes with the static information.
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
             staticHigh[i] == ShapedType::kDynamic ||
             inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = rewriter.create<PadOp>(
        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
        newLows, newHighs, padTensorOp.getNofold(),

    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return failure();
    }
    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }

        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
        newLowPad, newHighPad, padOp.getNofold(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());

  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
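// Illustrative example (assumed IR): FoldConsecutiveConstantPadding merges
//   %1 = tensor.pad %0 low[0, 1] high[1, 0] { tensor.yield %cst }
//   %2 = tensor.pad %1 low[1, 0] high[0, 1] { tensor.yield %cst }
// into a single pad with low[1, 1] and high[1, 1], provided both pads use
// the same constant padding value and neither is marked nofold.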
3635 Value PadOp::getConstantPaddingValue() {
3636 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3639 Value padValue = yieldOp.getValue();
3651 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries: the
// OpFoldResult offsets/sizes/strides are dispatched into static attributes
// and dynamic values before forwarding to the generated builder.
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
// ...
// The Range-based overload unpacks offsets/sizes/strides and forwards:
  build(b, result, source, dest, offsets, sizes, strides, attrs);
// ...
// The ValueRange-based overload wraps each Value as an OpFoldResult:
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  return produceSliceErrorMsg(result, *this, expectedType);
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }
  return success();
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() !=
      static_cast<int64_t>(getDynamicSizes().size()))
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value as a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind each tiled dimension with its tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
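// Illustrative, standalone sketch (not part of TensorOps.cpp): the helper
// above is a plain zip of inner_dims_pos with the tile sizes. A hypothetical
// constants-only rendering, `dimAndTileMapping`, using std::map:
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

std::map<int64_t, int64_t>
dimAndTileMapping(const std::vector<int64_t> &dimsToTile,
                  const std::vector<int64_t> &tiles) {
  assert(dimsToTile.size() == tiles.size());
  std::map<int64_t, int64_t> mapping;
  for (size_t i = 0; i < dimsToTile.size(); ++i)
    mapping[dimsToTile[i]] = tiles[i]; // dimension -> tile factor
  return mapping;
}
// Example: inner_dims_pos = [0, 2], tiles = [8, 16]  =>  {0: 8, 2: 16}.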
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}
template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}

/// Returns true if `dimsPos` is invalid. It is invalid when:
///   a) it contains duplicates,
///   b) at least one dimension is out of bound (a dim must be >= 0 and
///      < `rank`), or
///   c) the number of elements in `dimsPos` is > `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  llvm::DenseSet<int64_t> uniqued;
  for (int64_t dim : dimsPos)
    uniqued.insert(dim);
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
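// Illustrative, standalone sketch (not part of TensorOps.cpp): the same three
// invalidity conditions as above (too many entries, duplicates, out of
// range), using std::set. `isInvalidPosSpec` is a hypothetical name.
#include <cstdint>
#include <set>
#include <vector>

bool isInvalidPosSpec(const std::vector<int64_t> &dimsPos, size_t rank) {
  if (dimsPos.size() > rank)
    return true;                                  // (c) more entries than rank
  std::set<int64_t> uniqued(dimsPos.begin(), dimsPos.end());
  if (uniqued.size() != dimsPos.size())
    return true;                                  // (a) duplicates
  for (int64_t d : dimsPos)
    if (d < 0 || d >= static_cast<int64_t>(rank))
      return true;                                // (b) out of bounds
  return false;
}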
/// Returns true if each dimension of `sourceShape` is no larger than the
/// corresponding dimension of `limitShape` (dynamic dims always pass).
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(
        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
  };

  // Verify tiles. Do not allow zero tile factors.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and their number must match inner_dims_pos.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  if (unpackedRank + mixedTiles.size() != packedRank) {
    return op->emitError(
        "packed rank must equal unpacked rank + tiling factors");
  }

  // Verify that the packed shape is large enough to hold the expected packed
  // data, and that every static inner tile matches the packed type.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  return success();
}
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
ArrayRef<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}

SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}
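// Illustrative, standalone sketch (not part of TensorOps.cpp): the
// divisibility test at the heart of requirePaddingValue. A static dim needs
// padding iff its tile size does not divide it evenly; dynamic dims and
// tiles are conservatively skipped. `needsPadding` is a hypothetical name.
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // assumed sentinel for "dynamic"

bool needsPadding(const std::vector<int64_t> &shape,
                  const std::vector<int64_t> &tiledDims,
                  const std::vector<int64_t> &tileSizes) {
  for (size_t i = 0; i < tiledDims.size(); ++i) {
    int64_t dim = shape[tiledDims[i]];
    int64_t tile = tileSizes[i];
    if (dim == kDynamic || tile == kDynamic)
      continue; // cannot decide statically
    if (dim % tile != 0)
      return true; // partial tile: a padding value is required
  }
  return false;
}
// Example: needsPadding({5, 8}, {0}, {2}) is true; with shape {6, 8}, false.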
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify the padding value, and bail out if a tile does not divide its
  // dimension fully while no padding value is provided.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
/// all Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}
/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
/// the packed type, so the two methods agree on which dims are dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }

  // Swap tile loops if outer_dims_perm is available.
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);

  // Append the inner tile dimensions.
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}

SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  SmallVector<int64_t> resultTypeShape =
      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
                               asShapeWithAnyValueAsDynamic(innerTileSizes),
                               innerDimsPos, outerDimsPerm);

  // Fix up `resultDims` so that a dim is a Value exactly when the result type
  // shape says it is dynamic; callers rely on that invariant.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }
  return resultDims;
}
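// Illustrative, standalone sketch (not part of TensorOps.cpp): the packed
// shape computation above, constants only. Each tiled dim is ceil-divided by
// its tile (tiles are verified non-zero), outer dims are permuted with
// result[i] = shape[perm[i]], and the tile sizes are appended as the
// innermost dims. `kDynamic` and `packedShape` are hypothetical names.
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // assumed sentinel for "dynamic"

std::vector<int64_t> packedShape(std::vector<int64_t> shape,
                                 const std::vector<int64_t> &innerDimsPos,
                                 const std::vector<int64_t> &tiles,
                                 const std::vector<int64_t> &outerPerm) {
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t &d = shape[innerDimsPos[i]];
    if (d == kDynamic)
      continue;
    d = tiles[i] == kDynamic ? kDynamic : (d + tiles[i] - 1) / tiles[i];
  }
  if (!outerPerm.empty()) {
    std::vector<int64_t> permuted(shape.size());
    for (size_t i = 0; i < outerPerm.size(); ++i)
      permuted[i] = shape[outerPerm[i]];
    shape = permuted;
  }
  shape.insert(shape.end(), tiles.begin(), tiles.end());
  return shape;
}
// Example: packedShape({128, 256}, {0, 1}, {32, 16}, {}) == {4, 16, 32, 16}.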
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;
  // Without a padding value, a partial tile is undefined behavior, so the op
  // is only speculatable when all tiles and tiled dims are constant (the
  // verifier has then already ruled out partial tiles).
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;
  return Speculation::Speculatable;
}
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // The outer dims permutation is optional: treat a missing permutation as
  // equal to the identity permutation.
  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
         isIdentityPermutation(unPackOp.getOuterDimsPerm());
}

static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}
/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
/// Returns true if `srcShape` or `destShape` differs from the shapes in
/// `packOp`, and populates both with the inferred (more static) shapes.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(packOp.getInnerDimsPos().begin(),
                   packOp.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
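// Illustrative, standalone sketch (not part of TensorOps.cpp): the
// refinement loop above in a simplified, permutation-free form. Whenever
// exactly one of two corresponding dims is static, copy it to the other
// side. `kDynamic` and `refineShapes` are hypothetical names.
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // assumed sentinel for "dynamic"

bool refineShapes(std::vector<int64_t> &src, std::vector<int64_t> &dest) {
  bool changed = false;
  for (size_t i = 0; i < src.size() && i < dest.size(); ++i) {
    bool srcDyn = src[i] == kDynamic, destDyn = dest[i] == kDynamic;
    if (srcDyn == destDyn)
      continue; // both static (nothing to refine) or both dynamic (unknown)
    int64_t size = srcDyn ? dest[i] : src[i];
    src[i] = dest[i] = size;
    changed = true;
  }
  return changed;
}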
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold the optional padding_value operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast back to the original result type if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
                    std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dims.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  // The tiled dims are pushed all the way to the innermost positions; what
  // is left on the outer positions must be all 1s for this to be a pad.
  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}
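// Illustrative, standalone sketch (not part of TensorOps.cpp): the
// pad-likeness test above reduces to two conditions: the inner dims are
// tiled in identity order, and every outer dim of the packed type is 1, so
// the op only adds/removes unit dims plus optional padding. `looksLikePad`
// is a hypothetical name.
#include <cstdint>
#include <numeric>
#include <vector>

bool looksLikePad(const std::vector<int64_t> &packedShape,
                  const std::vector<int64_t> &innerDimsPos) {
  std::vector<int64_t> ordered(innerDimsPos.size());
  std::iota(ordered.begin(), ordered.end(), 0);
  if (ordered != innerDimsPos)
    return false; // tiled dims must be the leading dims, in order
  size_t numOuter = packedShape.size() - innerDimsPos.size();
  for (size_t i = 0; i < numOuter; ++i)
    if (packedShape[i] != 1)
      return false; // every outer dim must be a unit dim
  return true;
}
// Example: packing tensor<16x16> into tensor<1x1x16x16> looks like a pad.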
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}
ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
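// Illustrative, standalone sketch (not part of TensorOps.cpp): the size
// arithmetic above, constants only. Undo the outer permutation (out[perm[i]]
// = in[i]), then multiply each tiled outer dim by its tile size.
// `unpackedSizes` is a hypothetical name.
#include <cstdint>
#include <vector>

std::vector<int64_t> unpackedSizes(std::vector<int64_t> outerDims,
                                   const std::vector<int64_t> &innerDimsPos,
                                   const std::vector<int64_t> &tiles,
                                   const std::vector<int64_t> &outerPerm) {
  if (!outerPerm.empty()) {
    std::vector<int64_t> inverted(outerDims.size());
    for (size_t i = 0; i < outerPerm.size(); ++i)
      inverted[outerPerm[i]] = outerDims[i];
    outerDims = inverted;
  }
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    outerDims[innerDimsPos[i]] *= tiles[i]; // full size = numTiles * tileSize
  return outerDims;
}
// Example: unpackedSizes({4, 16}, {0, 1}, {32, 16}, {}) == {128, 256}.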
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}
/// Returns true if `srcShape` or `destShape` differs from the shapes in
/// `op`, and populates both with the inferred (more static) shapes.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(
        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = rewriter.create<tensor::CastOp>(loc, newDestType,
                                             unPackOp.getDest());
    }
    Value newOp = rewriter.create<UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // InsertSliceOp has its own logic about folding tensor.cast ops. Exclude
  // DPS ops that are also LoopLike from this interface, as they may need
  // special handling of attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  // If no operand comes from a foldable tensor::CastOp, there is nothing to
  // do.
  bool hasTensorCastOperand =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
        if (llvm::isa<BlockArgument>(opOperand.get()))
          return false;
        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
        return castOp && canFoldIntoConsumerOp(castOp);
      });
  return hasTensorCastOperand;
}
static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
                                         SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}
    // (Body of the pattern that folds a producing tensor.cast into a
    // consuming tensor::PackOp; the precondition check and operand rewriting
    // above are elided in this listing.)
    // Update the tile sizes: folding a cast may have made a tile size static.
    SmallVector<OpFoldResult> newMixedTileSizes;
    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
                                 .getShape()
                                 .take_back(op.getMixedTiles().size()),
                             op.getMixedTiles())) {
      int64_t shape = std::get<0>(it);
      if (shape == ShapedType::kDynamic) {
        newMixedTileSizes.push_back(std::get<1>(it));
        continue;
      }

      if (Attribute attr =
              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
        // Already a constant.
        newMixedTileSizes.push_back(std::get<1>(it));
      } else {
        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
        assert(tileSize == shape && "tile size and dim size don't match!");
        newMixedTileSizes.push_back(
            getAsIndexOpFoldResult(rewriter.getContext(), shape));
      }
    }

    // Clone op.
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());

    // Replace op, casting the result back to the original type if needed.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
    // (Body of the generic pattern that folds producing tensor.cast ops into
    // a consuming DestinationStyleOpInterface op.)
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    // Cast the new results back to the expected types and replace the op.
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(results.getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"