27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/MathExtras.h"
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
47 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
49 if (complex::ConstantOp::isBuildableWith(value, type))
50 return builder.
create<complex::ConstantOp>(loc, type,
51 llvm::cast<ArrayAttr>(value));
57 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
59 if (tensorType.isDynamicDim(dim))
60 return builder.
createOrFold<tensor::DimOp>(loc, value, dim);
67 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
69 for (int64_t i = 0; i < tensorType.getRank(); ++i)
76 auto tensorType = llvm::dyn_cast<TensorType>(opResult.
getType());
77 assert(tensorType &&
"expected tensor type");
81 auto destOp = opResult.
getDefiningOp<DestinationStyleOpInterface>();
83 return destOp.getTiedOpOperand(opResult)->get();
91 if (!tensorType.hasStaticShape()) {
99 for (int64_t sz : tensorType.getShape())
105 b.
create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
113 if (llvm::isa<TensorType>(opResult.getType())) {
115 if (failed(destination))
117 result.push_back(*destination);
124 if (
auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
125 if (
auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
126 return rtp1.getShape() == rtp2.getShape() &&
127 rtp1.getElementType() == rtp2.getElementType();
137 llvm::SmallBitVector droppedDims(mixedSizes.size());
138 int64_t shapePos = reducedShape.size() - 1;
140 for (
const auto &size :
enumerate(llvm::reverse(mixedSizes))) {
141 size_t idx = mixedSizes.size() - size.index() - 1;
143 bool isStaticUnitSize =
145 llvm::cast<IntegerAttr>(size.value().get<
Attribute>()).getInt() == 1;
150 assert(isStaticUnitSize &&
"expected unit dim");
151 droppedDims.set(idx);
156 if (!isStaticUnitSize) {
162 if (reducedShape[shapePos] == 1) {
168 droppedDims.set(idx);
171 assert(shapePos < 0 &&
"dimension mismatch");
178 static RankedTensorType
182 assert(type.getNumDynamicDims() ==
183 static_cast<int64_t
>(dynamicSizes.size()) &&
184 "incorrect number of dynamic sizes");
188 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189 if (type.isDynamicDim(i)) {
190 Value dynamicSize = dynamicSizes[ctr++];
192 if (cst.has_value()) {
194 if (cst.value() < 0) {
195 foldedDynamicSizes.push_back(dynamicSize);
198 staticShape[i] = *cst;
200 foldedDynamicSizes.push_back(dynamicSize);
214 if (inputs.size() != 1 || outputs.size() != 1)
216 Type a = inputs.front(), b = outputs.front();
217 auto aT = dyn_cast<TensorType>(a);
218 auto bT = dyn_cast<TensorType>(b);
222 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
235 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
237 auto tensorBitcastOperand =
238 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239 if (!tensorBitcastOperand)
242 auto resultType = cast<TensorType>(tensorBitcast.getType());
243 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244 tensorBitcastOperand.getOperand());
253 results.
add<ChainedTensorBitcast>(context);
261 setNameFn(getResult(),
"cast");
267 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
271 if (!sourceType || !targetType)
275 if (sourceType.getElementType() != targetType.getElementType())
279 if (sourceType.getRank() != targetType.getRank())
283 if (sourceType.getEncoding() != targetType.getEncoding())
287 for (
auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288 if (!ShapedType::isDynamic(std::get<0>(t)) &&
289 ShapedType::isDynamic(std::get<1>(t)))
325 castOp.getSource().getType());
360 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
362 operand.set(castOp.getOperand());
366 return success(folded);
370 if (inputs.size() != 1 || outputs.size() != 1)
372 Type a = inputs.front(), b = outputs.front();
373 auto aT = llvm::dyn_cast<TensorType>(a);
374 auto bT = llvm::dyn_cast<TensorType>(b);
378 if (aT.getElementType() != bT.getElementType())
394 int64_t rank = one.getRank();
395 if (rank != two.getRank())
400 for (int64_t i = 0; i < rank; ++i) {
401 if (one.isDynamicDim(i)) {
402 join.push_back(two.getDimSize(i));
405 if (two.isDynamicDim(i)) {
406 join.push_back(one.getDimSize(i));
409 if (one.getDimSize(i) != two.getDimSize(i))
411 join.push_back(one.getDimSize(i));
423 LogicalResult matchAndRewrite(CastOp tensorCast,
425 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
427 if (!tensorCastOperand)
431 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
432 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
433 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
447 auto newJoin =
joinShapes(sourceType, resultType);
448 if (firstJoin != newJoin)
451 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
452 tensorCastOperand.getOperand());
472 LogicalResult matchAndRewrite(CastOp tensorCast,
474 auto extractOperand =
475 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
478 auto rankedResultType =
479 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
480 if (!rankedResultType)
484 rankedResultType.getShape() ==
485 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
491 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
493 for (
size_t i = 0, e = sizes.size(); i < e; i++) {
494 if (dimMask && dimMask->count(i))
496 int64_t dim = rankedResultType.getShape()[dimIndex++];
497 if (ShapedType::isDynamic(dim))
499 sizes[i] = rewriter.getIndexAttr(dim);
502 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
503 tensorCast, rankedResultType, extractOperand.getSource(),
504 extractOperand.getMixedOffsets(), sizes,
505 extractOperand.getMixedStrides());
514 results.
add<ChainedTensorCast, TensorCastExtractSlice>(context);
521 RankedTensorType ConcatOp::inferResultType(int64_t dim,
TypeRange inputTypes) {
522 assert(!inputTypes.empty() &&
"cannot concatenate 0 tensors");
524 llvm::to_vector<4>(llvm::map_range(inputTypes, [](
Type type) {
525 return llvm::cast<RankedTensorType>(type);
527 int64_t concatRank = tensorTypes[0].getRank();
530 assert(dim >= 0 && dim < concatRank &&
"Invalid concatenation dim");
533 for (int64_t i = 0, e = concatRank; i < e; ++i) {
537 for (
auto tensorType : tensorTypes)
542 for (
auto tensorType : tensorTypes)
545 sizes[dim] = concatSize.asInteger();
551 FailureOr<RankedTensorType> resultType =
552 inferResultType(dim, inputs.
getTypes());
553 assert(succeeded(resultType) &&
"failed to infer concatenation result type");
554 build(builder, result, *resultType, dim, inputs);
558 if (getInputs().size() < 1)
559 return emitOpError(
"requires at least one input");
562 for (
auto input : getInputs())
563 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
565 RankedTensorType resultType = getResultType();
566 int64_t resultRank = getRank();
567 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
568 return type.getRank() != resultRank;
570 return emitOpError(
"rank of concatenated inputs must match result rank");
572 Type resultElementType = resultType.getElementType();
573 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
574 return type.getElementType() != resultElementType;
576 return emitOpError(
"inputs and result element type must match");
578 int64_t dim = getDim();
579 if (dim >= resultRank)
580 return emitOpError(
"concatenation dim must be less than the tensor rank");
583 for (int64_t i = 0, e = resultRank; i < e; ++i) {
587 for (
auto tensorType : inputTypes) {
588 FailureOr<SaturatedInteger> maybeSize =
590 if (failed(maybeSize))
591 return emitOpError(
"static concatenation size mismatch along ")
592 <<
"non-concatenated dimension " << i;
598 for (
auto tensorType : inputTypes)
601 sizes[dim] = concatSize.asInteger();
602 auto inferredResultType =
605 for (
auto [inferredSize, actualSize] :
606 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
607 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
608 ShapedType::isDynamic(actualSize);
609 if (!hasDynamic && inferredSize != actualSize)
610 return emitOpError(
"result type ")
611 << resultType <<
"does not match inferred shape "
612 << inferredResultType <<
" static sizes";
622 int64_t dim = getDim();
623 RankedTensorType inferredResultType = inferResultType(dim, inputs.
getTypes());
625 Value init = inputs[0];
626 int64_t rank =
getType().getRank();
633 for (int64_t i = 0; i < rank; ++i) {
636 if (!
getType().isDynamicDim(i)) {
638 }
else if (!inferredResultType.isDynamicDim(i)) {
641 builder.
getIndexAttr(inferredResultType.getDimSize(i)));
643 reifiedReturnShapes[0][i] =
644 builder.
create<tensor::DimOp>(init.
getLoc(), init, i).getResult();
648 if (
getType().isDynamicDim(dim)) {
656 builder.
createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
664 reifiedReturnShapes[0][dim] =
670 void ConcatOp::getAsmResultNames(
672 setNameFn(getResult(),
"concat");
677 if (inputs.size() == 1 && inputs[0].
getType() == getResultType())
687 LogicalResult matchAndRewrite(ConcatOp concatOp,
689 if (concatOp.getInputs().size() != 1)
692 concatOp.getInputs()[0]);
700 results.
add<SingleInputConcatOp>(context);
708 setNameFn(getResult(),
"dim");
714 Value indexValue = builder.
create<arith::ConstantIndexOp>(loc, index);
715 build(builder, result, source, indexValue);
718 std::optional<int64_t> DimOp::getConstantIndex() {
727 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().
getType());
728 if (!rankedSourceType)
739 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
744 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().
getType());
750 int64_t indexVal = index.getInt();
751 if (indexVal < 0 || indexVal >= tensorType.getRank())
755 if (!tensorType.isDynamicDim(index.getInt())) {
757 return builder.
getIndexAttr(tensorType.getShape()[index.getInt()]);
760 Operation *definingOp = getSource().getDefiningOp();
763 if (
auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
765 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
768 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
771 auto dynExtents = fromElements.getDynamicExtents().begin();
772 for (
auto dim : resultType.getShape().take_front(index.getInt()))
773 if (ShapedType::isDynamic(dim))
776 return Value{*dynExtents};
780 unsigned unsignedIndex = index.getValue().getZExtValue();
782 if (
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
785 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
786 sliceOp.isDynamicSize(unsignedIndex)) {
787 return {sliceOp.getDynamicSize(unsignedIndex)};
803 LogicalResult matchAndRewrite(DimOp dimOp,
805 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
808 Value newSource = castOp.getOperand();
819 LogicalResult matchAndRewrite(DimOp dimOp,
821 auto source = dimOp.getSource();
822 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
826 auto resultIndex = cast<OpResult>(source).getResultNumber();
827 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
830 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
840 LogicalResult matchAndRewrite(DimOp dim,
842 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
852 rewriter.
create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
853 if (extract.
getType() != dim.getType())
855 rewriter.
create<arith::IndexCastOp>(loc, dim.getType(), extract);
864 results.
add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
874 assert(all_of(staticShape,
875 [](int64_t sz) {
return !ShapedType::isDynamic(sz); }) &&
876 "expected only static sizes");
877 build(builder, result, staticShape, elementType,
ValueRange{}, encoding);
884 build(builder, result, tensorType, dynamicSizes);
893 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
897 if (
getType().getNumDynamicDims() !=
899 return emitOpError(
"incorrect number of dynamic sizes, has ")
901 <<
getType().getNumDynamicDims();
910 for (int64_t i = 0; i <
getType().getRank(); ++i) {
911 if (
getType().isDynamicDim(i)) {
920 Value EmptyOp::getDynamicSize(
unsigned idx) {
921 assert(
getType().isDynamicDim(idx) &&
"expected dynamic dim");
923 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
933 for (int64_t i = 0; i <
getType().getRank(); ++i) {
934 if (
getType().isDynamicDim(i)) {
958 LogicalResult matchAndRewrite(EmptyOp op,
962 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
965 if (foldedTensorType == op.getType())
968 auto newOp = rewriter.
create<EmptyOp>(op.
getLoc(), foldedTensorType,
978 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
980 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
981 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
982 if (!emptyTensorOp || !maybeConstantIndex)
984 if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
987 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1010 LogicalResult matchAndRewrite(CastOp castOp,
1014 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1019 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1023 newMixedSizes.reserve(currMixedSizes.size());
1024 assert(resultShape.size() == currMixedSizes.size() &&
1025 "mismatch in result shape and sizes of empty op");
1026 for (
auto it : llvm::zip(resultShape, currMixedSizes)) {
1027 int64_t newDim = std::get<0>(it);
1031 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1032 if (ShapedType::isDynamic(newDim) ||
1033 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1038 producer,
"mismatch in static value of shape of empty tensor "
1039 "result and cast result");
1041 newMixedSizes.push_back(attr);
1047 if (!ShapedType::isDynamic(newDim)) {
1048 newMixedSizes.push_back(rewriter.
getIndexAttr(newDim));
1054 newMixedSizes.push_back(currDim);
1059 resultType.getElementType());
1068 results.
add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1069 ReplaceEmptyTensorStaticShapeDims>(context);
1078 std::optional<Attribute> cst = std::nullopt) {
1079 if (source && source.
isSplat() && result.hasStaticShape() &&
1100 struct ExtractFromTensorCast :
public OpRewritePattern<tensor::ExtractOp> {
1103 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1105 auto tensorCast = extract.getTensor().
getDefiningOp<tensor::CastOp>();
1108 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1111 extract, tensorCast.getSource(), extract.getIndices());
1118 void ExtractOp::getAsmResultNames(
1120 setNameFn(getResult(),
"extracted");
1125 auto tensorType = llvm::cast<RankedTensorType>(getTensor().
getType());
1126 if (tensorType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1127 return emitOpError(
"incorrect number of indices for extract_element");
1134 if (
Attribute tensor = adaptor.getTensor())
1135 if (
auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1136 return splatTensor.getSplatValue<
Attribute>();
1140 for (
Attribute indice : adaptor.getIndices()) {
1141 if (!indice || !llvm::isa<IntegerAttr>(indice))
1143 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1147 if (
auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1148 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1149 auto rank = tensorType.getRank();
1150 assert(
static_cast<int64_t
>(indices.size()) == tensorType.getRank() &&
1154 for (
int i = rank - 1; i >= 0; --i) {
1155 flatIndex += indices[i] * stride;
1156 stride *= tensorType.getDimSize(i);
1160 if (
static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1163 return fromElementsOp.getElements()[flatIndex];
1167 if (
Attribute tensor = adaptor.getTensor()) {
1168 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1169 if (elementsAttr && elementsAttr.isValidIndex(indices))
1170 return elementsAttr.getValues<
Attribute>()[indices];
1178 results.
add<ExtractFromTensorCast>(context);
1185 void FromElementsOp::getAsmResultNames(
1187 setNameFn(getResult(),
"from_elements");
1192 assert(!elements.empty() &&
"expected at least one element");
1194 {
static_cast<int64_t
>(elements.size())}, elements.front().
getType());
1195 build(builder, result, resultType, elements);
1198 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1199 if (!llvm::is_contained(adaptor.getElements(),
nullptr))
1222 struct ExtractElementFromIndexCast
1226 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1229 auto indexCast = extract.getTensor().
getDefiningOp<arith::IndexCastOp>();
1235 auto newExtract = rewriter.
create<tensor::ExtractOp>(
1236 loc, elementTy, indexCast.getIn(), extract.getIndices());
1249 results.
add<ExtractElementFromIndexCast>(context);
1256 void GatherOp::getAsmResultNames(
1258 setNameFn(getResult(),
"gather");
1273 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1274 RankedTensorType indicesType,
1278 resultShape.reserve(resultShape.size() + sourceType.getRank());
1279 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1280 if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1282 resultShape.push_back(1);
1285 resultShape.push_back(sourceType.getDimSize(idx));
1290 static LogicalResult
1293 StringRef gatherOrScatter, StringRef sourceOrDest) {
1295 return op->
emitOpError(gatherOrScatter) <<
"_dims must be non-empty";
1297 int64_t numGatherDims = dims.size();
1298 if (numGatherDims > rank)
1300 <<
"_dims overflow " << sourceOrDest <<
" rank";
1301 if (indices.empty() || indices.back() != numGatherDims)
1303 <<
"_dims length must match the size of last dimension of indices";
1304 for (int64_t val : dims) {
1307 <<
"_dims value must be non-negative";
1310 <<
"_dims value must be smaller than " << sourceOrDest <<
" rank";
1312 for (int64_t i = 1; i < numGatherDims; ++i) {
1313 if (dims[i - 1] >= dims[i])
1315 <<
"_dims values must be strictly increasing";
1321 int64_t sourceRank = getSourceType().getRank();
1324 getIndicesType().
getShape(), sourceRank,
1325 "gather",
"source")))
1328 RankedTensorType expectedResultType = GatherOp::inferResultType(
1329 getSourceType(), getIndicesType(), gatherDims,
false);
1330 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1331 getSourceType(), getIndicesType(), gatherDims,
true);
1332 if (getResultType() != expectedResultType &&
1333 getResultType() != expectedRankReducedResultType) {
1334 return emitOpError(
"result type "
1337 << expectedResultType <<
" or its rank-reduced variant "
1338 << expectedRankReducedResultType <<
" (got: " << getResultType()
1347 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1349 return reshapedSource;
1357 void InsertOp::getAsmResultNames(
1359 setNameFn(getResult(),
"inserted");
1364 auto destType = llvm::cast<RankedTensorType>(getDest().
getType());
1365 if (destType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1366 return emitOpError(
"incorrect number of indices");
1374 if (
auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1375 if (scalar == splatDest.getSplatValue<
Attribute>())
1384 void GenerateOp::getAsmResultNames(
1386 setNameFn(getResult(),
"generated");
1393 for (
auto dim : llvm::seq<int64_t>(0,
getType().getRank())) {
1394 if (
getType().isDynamicDim(dim)) {
1395 reifiedReturnShapes[0][dim] = getOperand(idx++);
1397 reifiedReturnShapes[0][dim] =
1407 RankedTensorType resultType = llvm::cast<RankedTensorType>(
getType());
1408 if (getNumOperands() != resultType.getNumDynamicDims())
1409 return emitError(
"must have as many index operands as dynamic extents "
1410 "in the result type");
1414 LogicalResult GenerateOp::verifyRegions() {
1415 RankedTensorType resultTy = llvm::cast<RankedTensorType>(
getType());
1417 if (!llvm::all_of(getBody().getArgumentTypes(),
1419 return emitError(
"all body arguments must be index");
1420 if (getBody().getNumArguments() != resultTy.getRank())
1421 return emitError(
"must have one body argument per input dimension");
1424 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1426 if (yieldOp.getValue().getType() != resultTy.getElementType())
1428 "body must be terminated with a `yield` operation of the tensor "
1434 void GenerateOp::build(
1438 build(b, result, resultTy, dynamicExtents);
1443 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1447 b.
createBlock(bodyRegion, bodyRegion->
end(), argumentTypes, argumentLocs);
1460 LogicalResult matchAndRewrite(GenerateOp generateOp,
1464 generateOp.getType(), generateOp.getDynamicExtents(),
1465 foldedDynamicSizes);
1468 if (foldedTensorType == generateOp.getType())
1471 auto loc = generateOp.getLoc();
1473 rewriter.
create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1475 newOp.getBody().begin());
1477 generateOp.getType(), newOp);
1493 struct ExtractFromTensorGenerate :
public OpRewritePattern<tensor::ExtractOp> {
1496 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1498 auto tensorFromElements = extract.getTensor().
getDefiningOp<GenerateOp>();
1503 Block *body = &tensorFromElements.getBody().
front();
1506 rewriter.
clone(op, mapping);
1520 results.
add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1527 void RankOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1528 setNameFn(getResult(),
"rank");
1533 auto type = getOperand().getType();
1534 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1535 if (shapedType && shapedType.hasRank())
1537 return IntegerAttr();
1544 void ReshapeOp::getAsmResultNames(
1546 setNameFn(getResult(),
"reshape");
1550 int64_t numElements = 1;
1551 for (
auto dim : type.getShape())
1561 return emitOpError(
"element types of source and destination tensor "
1562 "types should be the same");
1566 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1567 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1569 if (resultRankedType) {
1570 if (operandRankedType && resultRankedType.hasStaticShape() &&
1571 operandRankedType.hasStaticShape()) {
1573 return emitOpError(
"source and destination tensor should have the "
1574 "same number of elements");
1576 if (ShapedType::isDynamic(shapeSize))
1577 return emitOpError(
"cannot use shape operand with dynamic length to "
1578 "reshape to statically-ranked tensor type");
1579 if (shapeSize != resultRankedType.getRank())
1581 "length of shape operand differs from the result's tensor rank");
1588 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1590 return reshapedSource;
1595 if (
auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1596 getSourceMutable().assign(reshapeOpProducer.getSource());
1600 auto source = getSource();
1601 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1602 auto resultTy = dyn_cast<RankedTensorType>(
getType());
1603 if (!sourceTy || !resultTy || sourceTy != resultTy)
1608 if (sourceTy.getRank() == 1)
1611 if (
auto fromElements =
getShape().getDefiningOp<tensor::FromElementsOp>()) {
1612 auto elements = fromElements.getElements();
1614 sourceTy.getRank() ==
static_cast<int64_t
>(elements.size());
1615 for (
int id = 0, s = elements.size();
id < s && dynamicNoop; ++id) {
1616 auto element = elements[id];
1619 dynamicNoop &= cst.value() == sourceTy.getDimSize(
id);
1623 if (
auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1624 dynamicNoop &= dimOp.getSource() == source;
1629 cst.has_value() && cst.value() ==
static_cast<int64_t
>(id);
1633 dynamicNoop =
false;
1648 void CollapseShapeOp::getAsmResultNames(
1650 setNameFn(getResult(),
"collapsed");
1653 void ExpandShapeOp::getAsmResultNames(
1655 setNameFn(getResult(),
"expanded");
1658 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1659 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1660 "invalid resultDim");
1662 if (llvm::is_contained(it.value(), resultDim))
1664 llvm_unreachable(
"could not find reassociation group");
1667 FailureOr<SmallVector<OpFoldResult>>
1669 RankedTensorType expandedType,
1672 std::optional<SmallVector<OpFoldResult>> outputShape =
1677 return *outputShape;
1684 auto [staticOutputShape, dynamicOutputShape] =
1686 build(builder, result, cast<RankedTensorType>(resultType), src,
1688 dynamicOutputShape, staticOutputShape);
1696 auto tensorResultTy = cast<RankedTensorType>(resultType);
1697 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1698 builder, result.
location, tensorResultTy, reassociation, inputShape);
1700 if (succeeded(outputShape)) {
1701 outputShapeOrEmpty = *outputShape;
1703 build(builder, result, tensorResultTy, src, reassociation,
1704 outputShapeOrEmpty);
1712 getReassociationIndices());
1720 getReassociationIndices());
1723 RankedTensorType CollapseShapeOp::inferCollapsedType(
1725 return inferCollapsedType(
1727 type.getContext(), reassociation)));
1733 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1735 auto shape = type.getShape();
1737 newShape.reserve(reassociation.size());
1742 unsigned currentDim = 0;
1744 unsigned dim = m.getNumResults();
1745 auto band = shape.slice(currentDim, dim);
1747 if (llvm::is_contained(band, ShapedType::kDynamic))
1748 size = ShapedType::kDynamic;
1750 for (
unsigned d = 0; d < dim; ++d)
1751 size *= shape[currentDim + d];
1752 newShape.push_back(size);
1762 auto resultType = inferCollapsedType(
1763 llvm::cast<RankedTensorType>(src.
getType()),
1768 build(b, result, resultType, src, attrs);
1771 template <
typename TensorReshapeOp,
bool isExpansion = std::is_same<
1772 TensorReshapeOp, ExpandShapeOp>::value>
1774 RankedTensorType expandedType,
1775 RankedTensorType collapsedType) {
1780 auto maps = op.getReassociationMaps();
1781 RankedTensorType expectedType =
1782 CollapseShapeOp::inferCollapsedType(expandedType, maps);
1784 return op.
emitOpError(
"expected collapsed type to be ")
1785 << expectedType <<
", but got " << collapsedType;
1790 auto srcType = getSrcType();
1791 auto resultType = getResultType();
1793 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1794 return emitOpError(
"expected number of static shape dims to be equal to "
1795 "the output rank (")
1796 << resultType.getRank() <<
") but found "
1797 << getStaticOutputShape().size() <<
" inputs instead";
1799 if ((int64_t)getOutputShape().size() !=
1800 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1801 return emitOpError(
"mismatch in dynamic dims in output_shape and "
1802 "static_output_shape: static_output_shape has ")
1803 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1804 <<
" dynamic dims while output_shape has " << getOutputShape().size()
1817 template <
typename TensorReshapeOp>
1820 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1828 reshapeOp.getResultType(), attr.
getRawData());
1835 template <
typename TensorReshapeOp>
1840 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1842 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1843 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1847 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1854 template <
typename TensorReshapeOp>
1857 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1860 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1864 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1866 if (!shapedTy.hasStaticShape())
1870 fromElements.getElements());
1879 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1881 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1885 RankedTensorType srcType =
1886 llvm::cast<RankedTensorType>(castOp.getSource().getType());
1887 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1888 srcType, collapseShapeOp.getReassociationMaps());
1890 if (newResultType == collapseShapeOp.getResultType()) {
1892 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1895 auto newOp = rewriter.
create<CollapseShapeOp>(
1896 collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1897 collapseShapeOp.getReassociation());
1899 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1908 LogicalResult matchAndRewrite(DimOp dimOp,
1910 auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1915 std::optional<int64_t> dim = dimOp.getConstantIndex();
1916 if (!dim.has_value())
1920 RankedTensorType resultType = expandShapeOp.getResultType();
1921 if (!resultType.isDynamicDim(*dim))
1925 int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1931 for (int64_t d : grp) {
1933 assert(!resultType.isDynamicDim(d) &&
"expected static dim");
1934 product *= resultType.getDimSize(d);
1940 rewriter.
create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1944 dimOp, expr.floorDiv(
product), srcDimSz);
1952 LogicalResult matchAndRewrite(DimOp dimOp,
1954 auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1955 if (!collapseShapeOp)
1959 std::optional<int64_t> dim = dimOp.getConstantIndex();
1960 if (!dim.has_value())
1964 RankedTensorType resultType = collapseShapeOp.getResultType();
1965 if (!resultType.isDynamicDim(*dim))
1970 collapseShapeOp.getReassociationIndices()[*dim];
1977 srcDimSizes.push_back(rewriter.
create<DimOp>(
1978 dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1994 FoldReshapeWithConstant<ExpandShapeOp>,
1995 FoldReshapeWithSplat<ExpandShapeOp>,
1996 FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1997 FoldDimOfCollapseShape>(context);
2005 tensor::DimOp, RankedTensorType>,
2006 FoldReshapeWithConstant<CollapseShapeOp>,
2007 FoldReshapeWithSplat<CollapseShapeOp>,
2008 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2012 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2013 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*
this,
2014 adaptor.getOperands());
2017 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2018 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*
this,
2019 adaptor.getOperands());
2026 void ExtractSliceOp::getAsmResultNames(
2028 setNameFn(getResult(),
"extracted_slice");
2034 RankedTensorType ExtractSliceOp::inferResultType(
2040 assert(
static_cast<int64_t
>(staticSizes.size()) ==
2041 sourceTensorType.getRank() &&
2042 "unexpected staticSizes not equal to rank of source");
2044 sourceTensorType.getEncoding());
2047 RankedTensorType ExtractSliceOp::inferResultType(
2055 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2056 staticSizes, staticStrides);
2067 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2068 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2072 auto inferredType = llvm::cast<RankedTensorType>(
2073 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2074 int rankDiff = inferredType.getRank() - desiredResultRank;
2076 auto shape = inferredType.getShape();
2077 llvm::SmallBitVector dimsToProject =
2081 for (
unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2082 if (!dimsToProject.test(pos))
2083 projectedShape.push_back(shape[pos]);
2087 return inferredType;
2090 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2091 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2099 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2100 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2107 RankedTensorType resultType,
Value source,
2117 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.
getType());
2120 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2121 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2124 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2137 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2146 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2152 RankedTensorType resultType,
Value source,
2161 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2168 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2173 RankedTensorType expectedType) {
2178 return op->
emitError(
"expected rank to be smaller or equal to ")
2179 <<
"the other rank. ";
2181 return op->
emitError(
"expected type to be ")
2182 << expectedType <<
" or a rank-reduced version. (size mismatch) ";
2184 return op->
emitError(
"expected element type to be ")
2185 << expectedType.getElementType();
2187 llvm_unreachable(
"unexpected extract_slice op verification result");
2194 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2195 getSourceType(), getMixedOffsets(),
getMixedSizes(), getMixedStrides());
2207 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.
getType());
2208 assert(sourceTensorType &&
"not a ranked tensor type");
2209 auto sourceShape = sourceTensorType.getShape();
2210 if (sourceShape.equals(desiredShape))
2212 auto maybeRankReductionMask =
2214 if (!maybeRankReductionMask)
2223 reifiedReturnShapes.resize(1);
2224 reifiedReturnShapes[0].reserve(
getType().getRank());
2227 for (
const auto &size :
enumerate(mixedSizes)) {
2228 if (droppedDims.test(size.index()))
2230 reifiedReturnShapes[0].push_back(size.value());
2251 class ExtractSliceOpCastFolder final :
public OpRewritePattern<ExtractSliceOp> {
2255 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2258 if (llvm::any_of(sliceOp.getOperands(), [](
Value operand) {
2259 return matchPattern(operand, matchConstantIndex());
2263 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2272 Value newResult = rewriter.
create<ExtractSliceOp>(
2273 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2274 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2275 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2276 if (newResult.
getType() != sliceOp.getType())
2277 newResult = rewriter.
create<CastOp>(loc, sliceOp.getType(), newResult);
2286 template <
typename IterTy,
typename ElemTy>
2291 assert(offsets.size() == sizes.size());
2292 assert(offsets.size() == strides.size());
2293 if (offsets.empty())
2296 int64_t offset = offsets.front();
2297 int64_t size = sizes.front();
2298 int64_t stride = strides.front();
2299 if (offsets.size() == 1) {
2300 for (int64_t i = 0; i < size; ++i, offset += stride)
2301 outValues->push_back(*(values + offset));
2306 for (int64_t i = 0; i < size; ++i, offset += stride) {
2307 auto begin = values + offset * counts.front();
2308 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2309 offsets.drop_front(), sizes.drop_front(),
2310 strides.drop_front(), outValues);
2317 class ConstantOpExtractSliceFolder final
2322 ConstantOpExtractSliceFolder(
MLIRContext *context,
2325 controlFn(std::move(controlFn)) {}
2327 LogicalResult matchAndRewrite(ExtractSliceOp op,
2338 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2340 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2347 int64_t count = sourceType.getNumElements();
2352 auto offsets = op.getStaticOffsets();
2353 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2355 auto sizes = op.getStaticSizes();
2356 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2358 auto strides = op.getStaticStrides();
2359 if (llvm::is_contained(strides, ShapedType::kDynamic))
2365 counts.reserve(shape.size());
2366 for (int64_t v : shape) {
2368 counts.push_back(count);
2374 if (
auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2376 outValues.reserve(sourceType.getNumElements());
2377 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2378 elems.begin(), counts, offsets, sizes, strides, &outValues);
2380 }
else if (
auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2382 outValues.reserve(sourceType.getNumElements());
2383 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2384 elems.begin(), counts, offsets, sizes, strides, &outValues);
2407 patterns.
add<ConstantOpExtractSliceFolder>(patterns.
getContext(), controlFn);
2416 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2417 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2425 ExtractSliceOp newOp) {
2426 Value replacement = newOp.getResult();
2427 if (replacement.
getType() != op.getType())
2428 replacement = rewriter.
create<tensor::CastOp>(op.
getLoc(), op.getType(),
2439 ExtractSliceOpCastFolder>(context);
2443 static LogicalResult
2445 ShapedType shapedType) {
2452 auto shape = shapedType.getShape();
2453 for (
auto it : llvm::zip(op.getMixedSizes(), shape))
2467 auto insertOp = extractOp.getSource().
getDefiningOp<InsertSliceOp>();
2470 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2471 insertOp.isSameAs(extractOp, isSame))
2472 return insertOp.getSource();
2477 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2479 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2481 return reshapedSource;
2482 if (getSourceType() ==
getType() &&
2484 return this->getSource();
2493 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.
getType());
2494 unsigned rank = rankedTensorType.getRank();
2498 return b.
createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2499 offsets, sizes, strides);
2506 void InsertSliceOp::getAsmResultNames(
2508 setNameFn(getResult(),
"inserted_slice");
2523 build(b, result, dest.
getType(), source, dest, dynamicOffsets, dynamicSizes,
2535 build(b, result, source, dest, offsets, sizes, strides, attrs);
2548 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2554 RankedTensorType srcType, RankedTensorType dstType,
2559 RankedTensorType expected = ExtractSliceOp::inferResultType(
2560 dstType, staticOffsets, staticSizes, staticStrides);
2562 *expectedType = expected;
2568 RankedTensorType expectedType;
2571 getStaticSizes(), getStaticStrides(), &expectedType);
2593 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2596 if (!prevInsertOp ||
2597 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2598 !prevInsertOp.isSameAs(insertOp, isSame))
2601 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2613 auto extractOp = insertOp.getSource().
getDefiningOp<ExtractSliceOp>();
2616 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2617 !extractOp.isSameAs(insertOp, isSame))
2620 return extractOp.getSource();
2624 if (getSourceType().hasStaticShape() &&
getType().hasStaticShape() &&
2625 getSourceType() ==
getType() &&
2627 return this->getSource();
2649 template <
typename InsertOpTy>
2650 class InsertSliceOpConstantArgumentFolder final
2655 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2668 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2669 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2670 mixedOffsets, mixedSizes, mixedStrides);
2671 Value toInsert = insertSliceOp.getSource();
2672 if (sourceType != insertSliceOp.getSourceType()) {
2677 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2679 toInsert = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2680 sourceType, toInsert);
2683 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2684 mixedSizes, mixedStrides);
2709 template <
typename InsertOpTy>
2710 struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
2713 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2715 if (llvm::any_of(insertSliceOp.getOperands(), [](
Value operand) {
2716 return matchPattern(operand, matchConstantIndex());
2720 auto getSourceOfCastOp = [](
Value v) -> std::optional<Value> {
2721 auto castOp = v.getDefiningOp<tensor::CastOp>();
2723 return std::nullopt;
2724 return castOp.getSource();
2726 std::optional<Value> sourceCastSource =
2727 getSourceOfCastOp(insertSliceOp.getSource());
2728 std::optional<Value> destCastSource =
2729 getSourceOfCastOp(insertSliceOp.getDest());
2730 if (!sourceCastSource && !destCastSource)
2734 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2735 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2736 auto srcType = llvm::dyn_cast<RankedTensorType>(src.
getType());
2737 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2738 if (!srcType || !dstType)
2746 staticSizes, srcType.getShape(),
true);
2747 if (!rankReductionMask.has_value())
2755 int64_t rankReducedIdx = 0;
2756 for (
auto [idx, size] :
enumerate(staticSizes)) {
2757 if (!rankReductionMask.value().contains(idx) &&
2758 !srcType.isDynamicDim(rankReducedIdx)) {
2760 rewriter.
getContext(), srcType.getDimSize(rankReducedIdx));
2761 size = srcType.getDimSize(rankReducedIdx++);
2765 staticSizes, insertSliceOp.getStaticStrides()) !=
2770 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2771 mixedSizes, insertSliceOp.getMixedStrides());
2774 bool isParallelInsert =
2775 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2776 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2777 replacement = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2778 insertSliceOp.getDestType(),
2807 template <
typename InsertOpTy>
2808 struct InsertSliceOpSourceCastInserter final
2812 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2814 RankedTensorType srcType = insertSliceOp.getSourceType();
2815 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2818 for (int64_t i = 0; i < srcType.getRank(); ++i) {
2819 if (std::optional<int64_t> constInt =
2824 newSrcShape[i] = *constInt;
2831 newSrcShape, srcType.getElementType(), srcType.getEncoding());
2832 if (srcType == newSrcType ||
2834 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2846 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2849 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2851 insertSliceOp, cast, insertSliceOp.getDest(),
2852 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2853 insertSliceOp.getMixedStrides());
2865 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2866 InsertSliceOpCastFolder<InsertSliceOp>,
2867 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2874 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.
getType());
2875 unsigned rank = rankedTensorType.getRank();
2879 return b.
createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2888 setNameFn(getResult(),
"padded");
2894 Type typeToInfer,
Type typeToInferFrom) {}
2898 std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2899 Type &typeToInfer,
Type typeToInferFrom) {
2901 typeToInfer = typeToInferFrom;
2906 auto sourceType = llvm::cast<RankedTensorType>(getSource().
getType());
2907 auto resultType = llvm::cast<RankedTensorType>(getResult().
getType());
2909 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2910 if (!expectedType) {
2911 return emitError(
"failed to infer expectedType from sourceType ")
2912 << sourceType <<
", specified resultType is " << resultType;
2914 if (resultType.getRank() != expectedType.getRank()) {
2916 << resultType <<
" does not match the inferred type "
2919 for (
int i = 0, e = sourceType.getRank(); i < e; ++i) {
2920 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2922 if (expectedType.isDynamicDim(i))
2925 << resultType <<
" does not match the inferred type "
2932 LogicalResult PadOp::verifyRegions() {
2933 auto ®ion = getRegion();
2934 unsigned rank = llvm::cast<RankedTensorType>(getResult().
getType()).getRank();
2937 return emitError(
"expected the block to have ") << rank <<
" arguments";
2941 if (!en.value().isIndex())
2942 return emitOpError(
"expected block argument ")
2943 << (en.index() + 1) <<
" to be an index";
2948 if (yieldOp.getValue().getType() !=
2950 return emitOpError(
"expected yield type to match shape element type");
2955 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2959 unsigned rank = sourceType.getRank();
2960 if (staticLow.size() != rank)
2961 return RankedTensorType();
2962 if (staticHigh.size() != rank)
2963 return RankedTensorType();
2964 if (!resultShape.empty() && resultShape.size() != rank)
2965 return RankedTensorType();
2968 for (
auto i : llvm::seq<unsigned>(0, rank)) {
2969 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2970 staticHigh[i] == ShapedType::kDynamic) {
2971 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2974 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2975 assert((resultShape.empty() || size == resultShape[i] ||
2976 resultShape[i] == ShapedType::kDynamic) &&
2977 "mismatch between inferred shape and result shape");
2978 inferredShape.push_back(size);
2989 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
2991 resultType = inferResultType(sourceType, staticLow, staticHigh);
2993 build(b, result, resultType, source, low, high,
3001 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3002 unsigned rank = sourceType.getRank();
3004 build(b, result, resultType, source, staticVector, staticVector, low, high,
3012 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3022 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3024 assert(llvm::isa<RankedTensorType>(resultType));
3026 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3035 build(b, result, resultType, source, low, high, nofold, attrs);
3039 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
3046 b.
createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
3050 llvm::SmallBitVector PadOp::getPaddedDims() {
3051 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3053 for (
const auto &en :
enumerate(paddingWidths))
3055 paddedDims.set(en.index());
3057 extractPaddedDims(getMixedLowPad());
3058 extractPaddedDims(getMixedHighPad());
3068 LogicalResult matchAndRewrite(PadOp padTensorOp,
3070 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3072 if (padTensorOp.getNofold())
3075 padTensorOp, padTensorOp.getResult().getType(),
3076 padTensorOp.getSource());
3085 LogicalResult matchAndRewrite(PadOp padTensorOp,
3087 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3091 auto newResultType = PadOp::inferResultType(
3092 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3093 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3094 padTensorOp.getResultType().getShape());
3096 if (newResultType == padTensorOp.getResultType()) {
3098 padTensorOp.getSourceMutable().assign(castOp.getSource());
3101 auto newOp = rewriter.
create<PadOp>(
3102 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3103 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3104 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3107 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3110 padTensorOp, padTensorOp.getResultType(), newOp);
3121 LogicalResult matchAndRewrite(PadOp padTensorOp,
3123 if (!padTensorOp.getResult().hasOneUse())
3126 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3130 tensorCastOp.getDest().getType()))
3133 auto replacementOp = rewriter.
create<PadOp>(
3134 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3135 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3136 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3137 padTensorOp.getHigh(), padTensorOp.getNofold(),
3141 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
3142 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
3185 LogicalResult matchAndRewrite(PadOp padOp,
3187 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3190 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3191 if (!outerPadOp || outerPadOp.getNofold())
3193 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3198 int64_t rank = padOp.getSourceType().getRank();
3199 if (outerSliceOp.getSourceType().getRank() != rank) {
3201 "cannot fold rank-reducing chain");
3205 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3207 padOp,
"cannot fold non-unit stride ExtractSliceOps");
3211 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3213 "cannot fold PadOps with low padding");
3218 Value innerValue = padOp.getConstantPaddingValue();
3219 Value outerValue = outerPadOp.getConstantPaddingValue();
3220 if (!innerValue || !outerValue ||
3223 innerAttr != outerAttr) {
3225 padOp,
"cannot fold PadOps with different padding values");
3229 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3230 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3231 if (innerDims.anyCommon(outerDims)) {
3233 padOp,
"cannot fold PadOps with common padding dimensions");
3243 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3244 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3245 if (!innerDims.test(en.index()) &&
3247 en.value() = outerOffset;
3250 if (!outerDims.test(en.index()) &&
3252 en.value() = innerOffset;
3256 padOp,
"cannot find zero-offset and zero-padding pair");
3266 if (!outerDims.test(en.index()))
3268 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3269 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3270 assert(!ShapedType::isDynamic(sourceSize) &&
3271 "expected padded dimension to have a static size");
3274 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
3275 "match the size of the outer padding");
3277 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3283 if (innerDims.test(en.index()))
3284 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3285 if (outerDims.test(en.index()))
3286 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3291 auto newSliceOp = rewriter.
create<ExtractSliceOp>(
3292 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3293 innerSliceOp.getMixedStrides());
3294 auto newPadOp = rewriter.
create<PadOp>(
3295 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3296 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3299 newPadOp.getRegion().begin());
3300 rewriter.
replaceOp(padOp, newPadOp.getResult());
3308 LogicalResult matchAndRewrite(PadOp padTensorOp,
3310 Value input = padTensorOp.getSource();
3311 if (!llvm::isa<RankedTensorType>(input.
getType()))
3313 auto inputDims = llvm::cast<RankedTensorType>(input.
getType()).getShape();
3314 auto inputRank = inputDims.size();
3316 auto oldResultType =
3317 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3321 auto outputDims = oldResultType.getShape();
3326 for (
auto operand : padTensorOp.getLow()) {
3329 constOperandsLow.push_back(ShapedType::kDynamic);
3330 newLows.push_back(operand);
3333 constOperandsLow.push_back(intOp.getExtValue());
3337 for (
auto operand : padTensorOp.getHigh()) {
3340 constOperandsHigh.push_back(ShapedType::kDynamic);
3341 newHighs.push_back(operand);
3344 constOperandsHigh.push_back(intOp.getExtValue());
3351 if (inputDims.size() != outputDims.size() ||
3352 inputDims.size() != constLow.size() ||
3353 inputDims.size() != constHigh.size())
3358 for (
size_t i = 0; i < inputRank; i++) {
3359 if (constLow[i] == ShapedType::kDynamic)
3360 constLow[i] = constOperandsLow[lowCount++];
3361 if (constHigh[i] == ShapedType::kDynamic)
3362 constHigh[i] = constOperandsHigh[highCount++];
3370 for (
size_t i = 0; i < inputRank; i++) {
3371 if (outputDims[i] == ShapedType::kDynamic) {
3372 newOutDims.push_back(
3373 (staticLow[i] == ShapedType::kDynamic ||
3374 staticHigh[i] == ShapedType::kDynamic ||
3375 inputDims[i] == ShapedType::kDynamic
3376 ? ShapedType::kDynamic
3377 : inputDims[i] + staticLow[i] + staticHigh[i]));
3379 newOutDims.push_back(outputDims[i]);
3384 llvm::all_of(newOutDims,
3385 [&](int64_t x) {
return x == ShapedType::kDynamic; }))
3390 newOutDims, padTensorOp.getType().getElementType());
3391 auto newOp = rewriter.
create<PadOp>(
3392 padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3393 newLows, newHighs, padTensorOp.getNofold(),
3397 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3425 struct FoldConsecutiveConstantPadding :
public OpRewritePattern<tensor::PadOp> {
3428 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3430 if (padOp.getNofold()) {
3434 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3435 if (!producerPad || producerPad.getNofold()) {
3437 padOp,
"producer is not a foldable tensor.pad op");
3441 Value consumerPadValue = padOp.getConstantPaddingValue();
3442 Value producerPadValue = producerPad.getConstantPaddingValue();
3443 if (!consumerPadValue || !producerPadValue ||
3444 consumerPadValue != producerPadValue) {
3447 "cannot fold PadOps with different or non-constant padding values");
3458 for (
auto [consumerIndex, producerIndex] :
3459 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3461 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3467 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3469 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3471 auto newPadOp = rewriter.
create<tensor::PadOp>(
3472 padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3473 newLowPad, newHighPad, padOp.getNofold(),
3476 newPadOp.getRegion().begin());
3477 rewriter.
replaceOp(padOp, newPadOp.getResult());
3486 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3487 FoldOrthogonalPaddings, FoldStaticPadding,
3488 FoldConsecutiveConstantPadding>(context);
3500 Value PadOp::getConstantPaddingValue() {
3501 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3504 Value padValue = yieldOp.getValue();
3516 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3526 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3527 ParallelCombiningOpInterface parallelCombiningParent =
3528 getParallelCombiningParent();
3529 for (
const auto &it :
3532 if (&nextOp == getOperation())
3533 return parallelCombiningParent.getParentResult(it.index());
3535 llvm_unreachable(
"ParallelInsertSliceOp no tied OpResult found");
3551 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3564 build(b, result, source, dest, offsets, sizes, strides, attrs);
3578 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3582 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3583 return this->
emitError(
"expected ParallelCombiningOpInterface parent, got:")
3584 << *(getOperation()->getParentOp());
3586 RankedTensorType expectedType;
3589 getStaticSizes(), getStaticStrides(), &expectedType);
3593 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3595 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3596 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3597 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3608 void ScatterOp::getAsmResultNames(
3610 setNameFn(getResult(),
"scatter");
3614 int64_t destRank = getDestType().getRank();
3617 getIndicesType().
getShape(), destRank,
3618 "scatter",
"dest")))
3622 return emitOpError(
"requires 'unique' attribute to be set");
3629 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3630 getDestType(), getIndicesType(), scatterDims,
false);
3631 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3632 getDestType(), getIndicesType(), scatterDims,
true);
3633 if (getSourceType() != expectedSourceType &&
3634 getSourceType() != expectedRankReducedSourceType) {
3635 return emitOpError(
"source type "
3638 << expectedSourceType <<
" or its rank-reduced variant "
3639 << expectedRankReducedSourceType <<
" (got: " << getSourceType()
3652 build(builder, result, aggregateType, element, dynamicSizes);
3658 build(builder, result, aggregateType, element, dynamicSizes);
3666 build(builder, result, element, staticShape, dynamicSizes);
3669 void SplatOp::getAsmResultNames(
3671 setNameFn(getResult(),
"splat");
3675 if (
getType().getNumDynamicDims() !=
3677 return emitOpError(
"incorrect number of dynamic sizes, has ")
3679 <<
getType().getNumDynamicDims();
3688 for (int64_t i = 0; i <
getType().getRank(); ++i) {
3689 if (
getType().isDynamicDim(i)) {
3699 auto constOperand = adaptor.getInput();
3700 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3704 if (!
getType().hasStaticShape())
3716 template <
typename OpTy>
3717 static LogicalResult
3720 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3721 "applies to only pack or unpack operations");
3722 int64_t destRank = op.getDestRank();
3724 reifiedReturnShapes[0] =
3729 template <
typename OpTy>
3731 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3732 "applies to only pack or unpack operations");
3736 assert(tiles.size() == dimsToTile.size() &&
3737 "tiles must match indices of dimension to block");
3739 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3740 dimAndTileMapping[dimsToTile[i]] = tiles[i];
3741 return dimAndTileMapping;
3744 template <
typename OpTy>
3746 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3747 "applies to only pack or unpack operations");
3750 unsigned dynamicValIndex = 0;
3751 for (int64_t staticTile : op.getStaticInnerTiles()) {
3752 if (!ShapedType::isDynamic(staticTile))
3755 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3757 return mixedInnerTiles;
3760 template <
typename OpTy>
3762 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3763 "applies to only pack or unpack operations");
3776 size_t dimsPosSize = dimsPos.size();
3777 if (dimsPosSize > rank)
3780 for (int64_t dim : dimsPos)
3781 uniqued.insert(dim);
3782 if (dimsPosSize != uniqued.size())
3784 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3785 return dimPos < 0 || dimPos >=
static_cast<int64_t
>(rank);
3794 sourceShape.size() == limitShape.size() &&
3795 "expected source shape rank, and limit of the shape to have same rank");
3796 return llvm::all_of(
3797 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3798 int64_t sourceExtent = std::get<0>(it);
3799 int64_t limit = std::get<1>(it);
3800 return ShapedType::isDynamic(sourceExtent) ||
3801 ShapedType::isDynamic(limit) || sourceExtent <= limit;
3805 template <
typename OpTy>
3807 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3808 "applies to only pack or unpack operations");
3809 Operation *op = packOrUnPack.getOperation();
3813 return llvm::any_of(
3819 if (hasZeros(mixedTiles))
3820 return op->
emitError(
"invalid zero tile factor");
3823 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3824 ? packOrUnPack.getSourceType()
3825 : packOrUnPack.getDestType();
3826 size_t unpackedRank = unpackedType.getRank();
3830 return op->
emitError(
"invalid inner_dims_pos vector");
3832 return op->
emitError(
"invalid outer_dims_perm vector");
3833 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3834 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
3838 if (mixedTiles.size() > unpackedRank) {
3839 return op->
emitError(
"tiling factors must be less than or equal to the "
3840 "input rank for pack or output rank for unpack");
3842 if (mixedTiles.size() != innerDimsPos.size()) {
3844 "tiling factors must equal the number of dimensions to tile");
3847 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3848 ? packOrUnPack.getDestType()
3849 : packOrUnPack.getSourceType();
3850 size_t packedRank = packedType.getRank();
3852 if (unpackedRank + mixedTiles.size() != packedRank) {
3854 "packed rank must equal unpacked rank + tiling factors");
3860 RankedTensorType expectedPackedType = PackOp::inferPackedType(
3861 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3862 if (!
areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3863 return op->
emitError(
"the shape of output is not large enough to hold the "
3864 "packed data. Expected at least ")
3865 << expectedPackedType <<
", got " << packedType;
3868 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3870 [](std::tuple<int64_t, OpFoldResult> it) {
3871 std::optional<int64_t> constTileSize =
3872 getConstantIntValue(std::get<1>(it));
3873 int64_t shape = std::get<0>(it);
3874 if (!constTileSize) {
3877 return ShapedType::isDynamic(shape);
3879 if (ShapedType::isDynamic(shape)) {
3886 return shape == constTileSize.value();
3888 return op->
emitError(
"mismatch in inner tile sizes specified and shaped of "
3889 "tiled dimension in the packed type");
3901 struct PackOrUnPackTransposeResult {
3908 template <
typename OpTy>
3909 static PackOrUnPackTransposeResult
3913 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3914 "applies to only pack or unpack operations");
3915 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3916 "some permutation must be non-empty");
3917 PackOrUnPackTransposeResult metadata;
3918 metadata.innerDimsPos =
3920 metadata.innerTiles =
3922 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3923 ? packOrUnPackOp.getSourceRank()
3924 : packOrUnPackOp.getDestRank();
3925 metadata.outerDimsPerm =
3926 packOrUnPackOp.getOuterDimsPerm().empty()
3927 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3929 if (!innerPermutation.empty()) {
3930 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3932 "invalid inner permutation");
3936 if (!outerPermutation.empty()) {
3937 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3939 "invalid outer permutation");
3949 void PackOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
3950 setNameFn(getResult(),
"pack");
3956 std::optional<Value> paddingValue,
3958 assert(innerDimsPos.size() == innerTiles.size() &&
3959 "number of tile sizes specified must match the specified number of "
3960 "original dimensions to be tiled");
3964 build(builder, state, dest.
getType(), source, dest,
3965 paddingValue ? *paddingValue :
nullptr,
3966 outerDimsPerm.empty() ?
nullptr
3996 outputShape.take_front(inputShape.size()));
3997 if (!outerDimsPerm.empty()) {
3998 assert(outerDimsPerm.size() == outputTileSizes.size() &&
3999 "expected output and outer_dims_perm to have same size");
4003 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4004 if (ShapedType::isDynamic(inputShape[pos]))
4008 if (!constantTile) {
4009 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4010 (inputShape[pos] % outputTileSizes[pos] != 0))
4012 }
else if (inputShape[pos] % (*constantTile) != 0) {
4026 auto paddingValue = getPaddingValue();
4029 return emitOpError(
"expected padding_value has ")
4030 << getSourceType().getElementType()
4031 <<
" but got: " << paddingValue.getType();
4034 if (!paddingValue &&
4035 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
4036 getDestType().
getShape(), getOuterDimsPerm(),
4039 "invalid tile factor or output size provided. Only full tiles are "
4040 "supported when padding_value is not set");
4050 for (
auto o : ofrs) {
4052 if (llvm::dyn_cast_if_present<Value>(o))
4053 result.push_back(ShapedType::kDynamic);
4067 for (
auto tiledDim :
llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4068 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4070 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4071 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4074 resultShape[tiledDim.value()] = divideCeilSigned(
4075 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4079 if (!outerDimsPerm.empty())
4083 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4096 for (
auto tiledDim :
llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4098 builder, loc, ceilDivExpr,
4099 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4101 if (!outerDimsPerm.empty())
4103 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4108 innerDimsPos, outerDimsPerm);
4114 for (
unsigned i = 0; i < resultDims.size(); ++i) {
4115 if (!ShapedType::isDynamic(resultTypeShape[i]))
4126 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4131 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4148 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
4149 if (ShapedType::isDynamic(value))
4150 mixedSizes.push_back(b.
create<DimOp>(loc, source, index).
getResult());
4154 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4155 int64_t dimPos = std::get<0>(it);
4157 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4159 if (!outerDimsPerm.empty())
4160 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4162 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4163 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4164 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4171 *
this, innerPermutation, outerPermutation);
4172 Value transposedDest =
4173 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4174 metadata.innerDimsPos, metadata.outerDimsPerm);
4175 return b.
create<PackOp>(loc, getSource(), transposedDest,
4176 metadata.innerDimsPos, metadata.innerTiles,
4177 getPaddingValue(), metadata.outerDimsPerm);
4181 template <
typename OpTy>
4183 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4184 "applies to only pack or unpack operations");
4185 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4187 : op.getSourceType();
4189 for (
auto [dimDest,
tile] : llvm::zip(
4190 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4192 if (!constTileSize || ShapedType::isDynamic(dimDest))
4199 if (getPaddingValue())
4214 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4216 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4228 auto packTiles = packOp.getMixedTiles();
4229 auto unPackTiles = unPackOp.getMixedTiles();
4230 if (packTiles.size() != unPackTiles.size())
4232 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
4241 auto srcType = op.getSourceType();
4242 if (llvm::any_of(op.getInnerDimsPos(),
4243 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4245 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4247 return !PackOp::requirePaddingValue(
4248 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4249 op.getOuterDimsPerm(), op.getMixedTiles());
4256 bool changeNeeded =
false;
4257 srcShape.assign(packOp.getSourceType().getShape().begin(),
4258 packOp.getSourceType().getShape().end());
4259 destShape.assign(packOp.getDestType().getShape().begin(),
4260 packOp.getDestType().getShape().end());
4261 llvm::SmallSetVector<int64_t, 4> innerDims;
4262 innerDims.insert(packOp.getInnerDimsPos().begin(),
4263 packOp.getInnerDimsPos().end());
4265 if (!packOp.getOuterDimsPerm().empty())
4267 int srcRank = packOp.getSourceRank();
4268 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
4269 if (innerDims.contains(i))
4272 int64_t destPos = i;
4273 if (!inverseOuterDimsPerm.empty())
4274 destPos = inverseOuterDimsPerm[srcPos];
4275 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4276 ShapedType::isDynamic(destShape[destPos])) {
4279 int64_t size = srcShape[srcPos];
4280 if (ShapedType::isDynamic(size))
4281 size = destShape[destPos];
4282 srcShape[srcPos] = size;
4283 destShape[destPos] = size;
4284 changeNeeded =
true;
4286 return changeNeeded;
4289 LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
4291 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4292 if (unPackOp.getSourceType() != packOp.getDestType())
4294 if (packOp.getPaddingValue() ||
4298 rewriter.
replaceOp(packOp, unPackOp.getSource());
4305 packOp.getPaddingValueMutable().clear();
4314 Value source = packOp.getSource();
4315 if (srcShape != packOp.getSourceType().getShape()) {
4316 auto newSrcType = packOp.getSourceType().clone(srcShape);
4318 rewriter.
create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4320 Value dest = packOp.getDest();
4321 if (destShape != packOp.getDestType().getShape()) {
4322 auto newDestType = packOp.getDestType().clone(destShape);
4324 rewriter.
create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4327 loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4328 packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4330 packOp, packOp.getResult().getType(), newOp);
4337 template <
typename PackOrUnpackOp>
4339 RankedTensorType packedTensorType) {
4340 static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4341 std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4342 "Function meant for pack/unpack");
4347 int64_t numPackedDims = innerDimsPos.size();
4348 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4349 if (orderedDims != innerDimsPos) {
4355 int64_t packedRank = packedTensorType.getRank();
4365 return llvm::all_of(
4366 llvm::seq<int64_t>(0, packedRank - numPackedDims),
4367 [&packedShape](int64_t i) {
return packedShape[i] == 1; });
4370 bool PackOp::isLikePad() {
4371 auto packedTensorType =
4372 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4377 std::optional<Attribute> paddingValue;
4378 if (
auto pad = adaptor.getPaddingValue())
4381 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4382 getDestType(), paddingValue))
4383 return reshapedSource;
4391 void UnPackOp::getAsmResultNames(
4393 setNameFn(getResult(),
"unpack");
4430 assert(innerDimsPos.size() == innerTiles.size() &&
4431 "number of tile sizes specified must match the specified number of "
4432 "original dimensions to be tiled");
4436 build(builder, state, dest.
getType(), source, dest,
4437 outerDimsPerm.empty() ?
nullptr
4455 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
4457 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4458 if (srcType.isDynamicDim(i))
4461 mixedSizes.push_back(b.
getIndexAttr(srcType.getDimSize(i)));
4463 if (!outerDimsPerm.empty()) {
4464 applyPermutationToVector<OpFoldResult>(
4468 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4469 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4471 auto elemType = srcType.getElementType();
4472 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4476 Value transposedSource,
4480 *
this, innerPermutation, outerPermutation);
4481 return b.
create<UnPackOp>(loc, transposedSource, getDest(),
4482 metadata.innerDimsPos, metadata.innerTiles,
4483 metadata.outerDimsPerm);
4490 bool changeNeeded =
false;
4491 srcShape.assign(op.getSourceType().getShape().begin(),
4492 op.getSourceType().getShape().end());
4493 destShape.assign(op.getDestType().getShape().begin(),
4494 op.getDestType().getShape().end());
4495 llvm::SmallSetVector<int64_t, 4> innerDims;
4496 innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4498 if (!op.getOuterDimsPerm().empty())
4500 int destRank = op.getDestRank();
4501 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
4502 if (innerDims.contains(i))
4505 int64_t destPos = i;
4506 if (!inverseOuterDimsPerm.empty())
4507 srcPos = inverseOuterDimsPerm[destPos];
4508 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4509 ShapedType::isDynamic(destShape[destPos])) {
4512 int64_t size = srcShape[srcPos];
4513 if (ShapedType::isDynamic(size))
4514 size = destShape[destPos];
4515 srcShape[srcPos] = size;
4516 destShape[destPos] = size;
4517 changeNeeded =
true;
4519 return changeNeeded;
4522 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4525 if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4526 if (packOp.getSourceType() != unPackOp.getDestType())
4528 if (packOp.getPaddingValue() ||
4532 rewriter.
replaceOp(unPackOp, packOp.getSource());
4536 if (
auto dstStyleOp =
4537 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4538 auto destValue = cast<OpResult>(unPackOp.getDest());
4539 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4541 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4549 Value source = unPackOp.getSource();
4550 if (srcShape != unPackOp.getSourceType().getShape()) {
4551 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4552 source = rewriter.
create<tensor::CastOp>(loc, newSrcType,
4553 unPackOp.getSource());
4555 Value dest = unPackOp.getDest();
4556 if (destShape != unPackOp.getDestType().getShape()) {
4557 auto newDestType = unPackOp.getDestType().clone(destShape);
4559 rewriter.
create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4562 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4563 unPackOp.getOuterDimsPerm());
4565 unPackOp, unPackOp.getResult().getType(), newOp);
4572 bool UnPackOp::isLikeUnPad() {
4573 RankedTensorType packedTensorType = getSourceType();
4579 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4581 return reshapedSource;
4613 if (isa<InsertSliceOp>(op.getOperation()))
4618 if (isa<LoopLikeOpInterface>(op.getOperation()))
4622 bool hasTensorCastOperand =
4624 if (llvm::isa<BlockArgument>(opOperand.get()))
4626 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4627 return castOp && canFoldIntoConsumerOp(castOp);
4629 if (!hasTensorCastOperand)
4636 int64_t dpsInitIdx = 0;
4640 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.
get());
4641 if (op.isDpsInit(&opOperand) &&
4642 !llvm::isa<MemRefType>(newOperands.back().getType()))
4643 newResultTypes[dpsInitIdx++] = newOperands.back().getType();
4647 Operation *newOp =
clone(rewriter, op, newResultTypes, newOperands);
4650 for (
auto [oldResult, newResult] :
4652 if (newResult.
getType() != oldResult.getType()) {
4653 replacements.push_back(rewriter.
create<tensor::CastOp>(
4654 op->
getLoc(), oldResult.getType(), newResult));
4656 replacements.push_back(newResult);
4669 void TensorDialect::getCanonicalizationPatterns(
4678 #define GET_OP_CLASSES
4679 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimensions of sourceShape are smaller than the dimensions of the limitShape.
static int64_t getNumElements(ShapedType type)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result, std::optional< Attribute > cst=std::nullopt)
Try to remove a tensor operation if it would only reshape a constant.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)