42 #include "llvm/ADT/DenseMap.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/SetOperations.h"
45 #include "llvm/ADT/SmallSet.h"
46 #include "llvm/ADT/SmallVector.h"
47 #include "llvm/ADT/StringSet.h"
48 #include "llvm/ADT/TypeSwitch.h"
49 #include "llvm/Support/FormatVariadic.h"
50 #include "llvm/Support/InterleavedRange.h"
51 #include "llvm/Support/LogicalResult.h"
52 #include "llvm/Support/MathExtras.h"
53 #include "llvm/Support/raw_ostream.h"
63 auto type = cast<ShapedType>(v.
getType());
64 if (!type.isDynamicDim(dim))
69 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
70 return builder.create<tensor::DimOp>(loc, v, dim);
72 .Case<MemRefType>([&](MemRefType t) ->
Value {
73 return builder.create<memref::DimOp>(loc, v, dim);
84 .Case<RankedTensorType>([&](RankedTensorType t) ->
Operation * {
85 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
88 .Case<MemRefType>([&](MemRefType type) ->
Operation * {
89 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
101 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
103 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
105 llvm_unreachable(
"Expected MemRefType or TensorType");
110 auto shapedType = llvm::cast<ShapedType>(source.
getType());
111 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
134 for (
auto containers : {inputTypes, outputTypes}) {
135 for (
auto t : containers) {
147 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
151 regionBuilder(b, *body, attrs);
163 std::optional<TypeRange> resultTensorTypes,
170 if (!resultTensorTypes)
171 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
174 state.addOperands(inputs);
175 state.addOperands(outputs);
176 state.addTypes(derivedResultTypes);
178 state.addAttributes(attributes);
180 "operandSegmentSizes",
182 static_cast<int32_t>(outputs.size())}));
185 Region ®ion = *state.addRegion();
187 state.attributes.getAttrs(), regionBuilder);
191 std::optional<TypeRange> resultTensorTypes,
198 indexingMapsAttrVal = llvm::map_to_vector(
199 MatmulOp::getDefaultIndexingMaps(b.
getContext()),
201 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
203 attributes, regionBuilder);
207 std::optional<TypeRange> resultTensorTypes,
214 indexingMapsAttrVal =
218 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
220 attributes, regionBuilder);
229 bool addOperandSegmentSizes =
true) {
230 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
259 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
261 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
265 if (addOperandSegmentSizes) {
274 attrs.
append(
"operandSegmentSizes",
276 {static_cast<int32_t>(inputsOperands.size()),
277 static_cast<int32_t>(outputsOperands.size())}));
282 {static_cast<int32_t>(inputsOperands.size()),
283 static_cast<int32_t>(outputsOperands.size())}));
287 std::optional<RegisteredOperationName> info =
290 if (failed(info->verifyInherentAttrs(result.
attributes, [&]() {
291 return parser.emitError(attrsLoc)
292 <<
"'" << result.name.getStringRef() <<
"' op ";
303 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
304 if (!outputs.empty())
305 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
316 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
319 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
320 "region expects {0} args, got {1}",
321 numRegionArgs, inputTypes.size() + outputTypes.size()));
340 unsigned numRegionArgs,
356 result.
addTypes(outputTensorsTypes);
358 std::unique_ptr<Region> region = std::make_unique<Region>();
370 if (resultTypes.empty())
415 class RegionBuilderHelper {
418 : builder(builder), block(block) {}
422 if (!isFloatingPoint(arg))
423 llvm_unreachable(
"unsupported non numeric type");
425 builder.setInsertionPointToEnd(&block);
428 return builder.create<math::ExpOp>(arg.
getLoc(), arg);
430 return builder.create<math::LogOp>(arg.
getLoc(), arg);
432 return builder.create<math::AbsFOp>(arg.
getLoc(), arg);
434 return builder.create<math::CeilOp>(arg.
getLoc(), arg);
436 return builder.create<math::FloorOp>(arg.
getLoc(), arg);
438 return builder.create<arith::NegFOp>(arg.
getLoc(), arg);
439 case UnaryFn::reciprocal: {
441 auto one = builder.create<arith::ConstantOp>(arg.
getLoc(),
442 ::cast<TypedAttr>(oneAttr));
443 return builder.create<arith::DivFOp>(arg.
getLoc(), one, arg);
446 return builder.create<math::RoundOp>(arg.
getLoc(), arg);
448 return builder.create<math::SqrtOp>(arg.
getLoc(), arg);
450 return builder.create<math::RsqrtOp>(arg.
getLoc(), arg);
451 case UnaryFn::square:
452 return builder.create<arith::MulFOp>(arg.
getLoc(), arg, arg);
454 return builder.create<math::TanhOp>(arg.
getLoc(), arg);
456 return builder.create<math::ErfOp>(arg.
getLoc(), arg);
458 llvm_unreachable(
"unsupported unary function");
463 bool allComplex = isComplex(arg0) && isComplex(arg1);
464 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
465 bool allInteger = isInteger(arg0) && isInteger(arg1);
468 if (!allComplex && !allFloatingPoint && !allInteger)
469 llvm_unreachable(
"unsupported non numeric type");
471 builder.setInsertionPointToEnd(&block);
475 return builder.create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
476 if (allFloatingPoint)
477 return builder.create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
479 return builder.create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
480 return builder.create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
483 return builder.create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
484 if (allFloatingPoint)
485 return builder.create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
487 llvm_unreachable(
"unsupported operation: sub with bools");
488 return builder.create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
491 return builder.create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
492 if (allFloatingPoint)
493 return builder.create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
495 return builder.create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
496 return builder.create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
499 return builder.create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
500 if (allFloatingPoint)
501 return builder.create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
503 llvm_unreachable(
"unsupported operation: div with bools");
504 return builder.create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
505 case BinaryFn::div_unsigned:
506 if (!allInteger || allBool)
507 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
508 return builder.create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
509 case BinaryFn::max_signed:
511 if (allFloatingPoint)
512 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
513 return builder.create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
514 case BinaryFn::min_signed:
516 if (allFloatingPoint)
517 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
518 return builder.create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
519 case BinaryFn::max_unsigned:
521 if (allFloatingPoint)
522 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
523 return builder.create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
524 case BinaryFn::min_unsigned:
526 if (allFloatingPoint)
527 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
528 return builder.create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
530 assert(allFloatingPoint);
531 return builder.create<math::PowFOp>(arg0.
getLoc(), arg0, arg1);
533 llvm_unreachable(
"unsupported binary function");
541 bool tailFloatingPoint =
542 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
543 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
545 builder.setInsertionPointToEnd(&block);
547 case TernaryFn::select:
548 if (!headBool && !(tailFloatingPoint || tailInteger))
549 llvm_unreachable(
"unsupported non numeric type");
550 return builder.create<arith::SelectOp>(arg0.
getLoc(), arg0, arg1, arg2);
552 llvm_unreachable(
"unsupported ternary function");
558 case TypeFn::cast_signed:
559 return cast(toType, operand,
false);
560 case TypeFn::cast_unsigned:
561 return cast(toType, operand,
true);
563 llvm_unreachable(
"unsupported type conversion function");
568 builder.setInsertionPointToEnd(&block);
569 Location loc = builder.getUnknownLoc();
570 builder.create<YieldOp>(loc, values);
573 Value constant(
const std::string &value) {
575 builder.setInsertionPointToEnd(&block);
576 Location loc = builder.getUnknownLoc();
578 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
581 Value index(int64_t dim) {
583 builder.setInsertionPointToEnd(&block);
584 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
587 Type getIntegerType(
unsigned width) {
601 builder.setInsertionPointToEnd(&block);
602 auto loc = operand.
getLoc();
606 bool isComplex(
Value value) {
607 return llvm::isa<ComplexType>(value.
getType());
609 bool isFloatingPoint(
Value value) {
610 return llvm::isa<FloatType>(value.
getType());
612 bool isInteger(
Value value) {
613 return llvm::isa<IntegerType>(value.
getType());
630 LogicalResult matchAndRewrite(CopyOp copyOp,
632 if (copyOp.getInputs() != copyOp.getOutputs())
634 if (copyOp.hasPureBufferSemantics())
637 rewriter.
replaceOp(copyOp, copyOp.getInputs());
647 results.
add<EraseSelfCopy>(context);
660 template <
typename TensorReshapeOp>
663 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
665 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
670 TensorReshapeOp newInit;
671 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
673 newInit = rewriter.
create<TensorReshapeOp>(
674 loc, reshapeOp.getResultType(), oldFill.output(),
675 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
676 reshapeOp.getStaticOutputShape());
678 newInit = rewriter.
create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
680 reshapeOp.getReassociation());
693 LogicalResult matchAndRewrite(tensor::PadOp padOp,
695 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
701 Value padValue = padOp.getConstantPaddingValue();
702 if (!padValue || fillOp.value() != padValue)
708 padOp,
"failed to reify tensor.pad op result shape");
710 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
711 padOp.getLoc(), reifiedShape.front(),
712 padOp.getResultType().getElementType());
718 if (replacement.getType() != padOp.getResultType()) {
719 replacement = rewriter.
create<tensor::CastOp>(
720 fillOp.getLoc(), padOp.getResultType(), replacement);
730 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
733 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
735 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
739 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
744 Value firstDest = insertOp.getDest();
745 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
746 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
751 bool disjoint =
false;
752 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
755 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
756 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
757 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
761 int64_t prevStart = prevOp.getStaticOffset(i);
762 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
763 prevOp.getStaticStride(i);
764 int64_t nextStart = insertOp.getStaticOffset(i);
765 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
766 insertOp.getStaticStride(i);
767 if (prevEnd < nextStart || nextEnd < prevStart) {
775 firstDest = prevOp.getDest();
786 Value padValue = srcPadOp.getConstantPaddingValue();
787 if (!padValue || dstFillOp.value() != padValue)
803 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
805 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
808 RankedTensorType srcPadType = srcPadOp.getSourceType();
810 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
811 if (srcPadType.isDynamicDim(i)) {
813 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
816 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
821 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
822 newSizes, insertOp.getMixedStrides());
828 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
832 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
836 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
841 Value extractedScalar = fillOp.getInputs()[0];
844 rewriter.
replaceOp(extractOp, extractedScalar);
852 static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
853 linalg::PackOp packOp) {
854 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
858 if (
auto paddingValue = packOp.getPaddingValue())
862 Value packOpDest = packOp.getDest();
866 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
876 LogicalResult matchAndRewrite(linalg::PackOp packOp,
878 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
881 rewriter.
replaceOp(packOp, fillOp.value().result());
890 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
892 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
895 copyOp.getOutputs());
898 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
900 fillOp.getOutputs());
911 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
913 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
915 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
916 transposeOp.getDpsInitOperand(0)->get());
928 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
930 auto concatOperands = concatOp.getInputs();
931 if (concatOperands.empty()) {
935 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
944 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
946 auto isDefinedByCompatibleFillOp = [&](
Value v) ->
bool {
947 auto fillOp = v.getDefiningOp<linalg::FillOp>();
954 if (fillVal != firstFillVal)
957 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
960 if (!llvm::all_of(concatOperands.drop_front(),
961 isDefinedByCompatibleFillOp)) {
963 concatOp,
"not all operands are defined by a compatible fill op");
966 Value outsConcat = rewriter.
create<tensor::ConcatOp>(
967 concatOp.getLoc(), concatOp.getDim(), allOuts);
969 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
978 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
979 FoldFillWithPack, FoldFillWithPad,
980 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
981 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
982 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
995 for (
ValueRange container : {inputs, outputs}) {
996 for (
Value v : container) {
997 Type t = v.getType();
998 blockArgTypes.push_back(
1000 blockArgLocs.push_back(v.getLoc());
1006 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1010 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
1012 for (
Value v : getRegionInputArgs())
1014 for (
Value v : getRegionOutputArgs())
1015 setNameFn(v,
"out");
1018 void GenericOp::build(
1021 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1024 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1025 iteratorTypes, doc, libraryCall);
1029 inputs, outputs, bodyBuild);
1032 void GenericOp::build(
1036 StringRef libraryCall,
1039 build(builder, result, resultTensorTypes, inputs, outputs,
1044 return IteratorTypeAttr::get(builder.getContext(), iter);
1047 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1048 bodyBuild, attributes);
1051 void GenericOp::build(
1055 StringRef libraryCall,
1058 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
1059 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1062 void GenericOp::build(
1068 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1070 "", bodyBuild, attributes);
1073 void GenericOp::build(
1079 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1082 "", bodyBuild, attributes);
1089 auto genericAttrNames = linalgTraitAttrNames();
1092 genericAttrNamesSet.insert_range(genericAttrNames);
1094 for (
auto attr : (*this)->getAttrs()) {
1095 if (attr.getName() == getIteratorTypesAttrName()) {
1096 auto iteratorTypes =
1097 llvm::cast<ArrayAttr>(attr.getValue())
1098 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1104 llvm::to_vector(llvm::map_range(
1105 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
1109 genericAttrs.emplace_back(
1110 getIteratorTypesAttrName(),
1112 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1113 genericAttrs.push_back(attr);
1116 if (!genericAttrs.empty()) {
1118 p << genericDictAttr;
1124 genericAttrNames.push_back(
"operandSegmentSizes");
1125 genericAttrNamesSet.insert(genericAttrNames.back());
1127 bool hasExtraAttrs =
false;
1129 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1132 if (hasExtraAttrs) {
1139 if (!getRegion().empty()) {
1149 DictionaryAttr dictAttr;
1158 dictAttr.getValue().end());
1164 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1166 if (!iteratorTypes) {
1167 return parser.
emitError(attributeLocation)
1168 <<
"expected " << getIteratorTypesAttrName(result.
name)
1169 <<
" array attribute";
1174 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1175 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1176 if (!maybeIteratorType.has_value())
1178 <<
"unexpected iterator_type (" << s <<
")";
1180 iteratorTypeAttrs.push_back(
1197 std::unique_ptr<Region> region = std::make_unique<Region>();
1209 result.
addTypes(outputTensorsTypes);
1217 LinalgOp linalgOp) {
1218 for (
auto [index, operand] :
llvm::enumerate(linalgOp.getDpsInputs())) {
1219 if (!llvm::isa<MemRefType>(operand.
getType()))
1221 effects.emplace_back(
1226 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1227 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1229 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1240 void GenericOp::getEffects(
1250 if (!linalgOp.hasPureTensorSemantics())
1269 template <
typename OpTy>
1273 LogicalResult matchAndRewrite(OpTy linalgOp,
1276 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1281 Block &body = linalgOp->getRegion(0).
front();
1282 if (!llvm::hasSingleElement(body))
1284 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1289 if (linalgOp.hasPureBufferSemantics()) {
1290 if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1291 linalgOp.getDpsInputOperand(0)->get() ==
1292 linalgOp.getDpsInitOperand(0)->get()) {
1300 if (!linalgOp.hasPureTensorSemantics())
1307 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1308 if (!yieldArg || yieldArg.getOwner() != &body)
1310 unsigned argumentNumber = yieldArg.getArgNumber();
1311 Value returnedArg = linalgOp->getOperand(argumentNumber);
1312 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1316 if (returnType != resultType) {
1321 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1322 linalgOp.getLoc(), resultType, returnedArg);
1324 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1327 returnedArg = rewriter.
create<tensor::CastOp>(
1328 linalgOp.getLoc(), resultType, returnedArg);
1331 returnedArgs.push_back(returnedArg);
1334 if (returnedArgs.size() != linalgOp->getNumResults())
1336 rewriter.
replaceOp(linalgOp, returnedArgs);
1345 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1367 for (
Type outputType : outputTypes) {
1368 if (llvm::isa<RankedTensorType>(outputType))
1373 if (parseAttrsFn && failed(parseAttrsFn(parser, result.
attributes)))
1382 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1384 for (
Value v : getRegionInputArgs())
1389 if (!getResults().empty())
1390 setNameFn(getResults().front(),
"mapped");
1397 build(builder, result,
TypeRange{}, inputs, init);
1402 if (llvm::isa<RankedTensorType>(initType))
1407 inputs, {}, bodyBuild);
1414 bool initFirst =
false) {
1419 for (
auto &operand : operands) {
1421 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1428 payloadOpOperands.push_back(block.
getArguments().back());
1429 for (
const auto &arg : block.
getArguments().drop_back())
1430 payloadOpOperands.push_back(arg);
1439 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1446 std::optional<OperationName> payloadOpName;
1450 if (failed(operationName))
1454 payloadOpName = operationName.value();
1462 if (payloadOpName.has_value()) {
1500 for (
const auto &[operand, bbArg] :
1502 if (bbArg != operand)
1506 for (
const auto &[operand, bbArg] :
1508 if (bbArg != operand)
1517 std::string attrToElide;
1519 for (
const auto &attr : payloadOp->
getAttrs()) {
1521 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1522 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1523 attrToElide = attr.getName().str();
1524 elidedAttrs.push_back(attrToElide);
1533 Block *mapper = getBody();
1548 [&](
auto arg) { p.printRegionArgument(arg); });
1557 auto *bodyBlock = getBody();
1558 auto blockArgs = bodyBlock->getArguments();
1561 if (getInputs().size() != blockArgs.size())
1562 return emitOpError() <<
"expects number of operands to match the arity of "
1564 << getInputs().size() <<
" and " << blockArgs.size();
1567 for (
const auto &[bbArgType, inputArg] :
1568 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1569 auto inputElemType =
1570 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1571 if (bbArgType != inputElemType) {
1572 return emitOpError() <<
"expected element type of input " << inputElemType
1573 <<
" to match bbArg type " << bbArgType;
1578 auto outputShape = getInit().getType().getShape();
1580 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1581 if (inputElemShape != outputShape) {
1582 return emitOpError() <<
"expected shape of input (" << inputElemShape
1583 <<
") to match shape of output (" << outputShape
1592 int64_t rank = getInit().getType().getRank();
1596 ArrayAttr MapOp::getIndexingMaps() {
1598 int64_t rank = getInit().getType().getRank();
1599 int64_t numIndexingMaps = getOperands().size();
1604 void MapOp::getEffects(
1618 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1620 for (
Value v : getRegionInputArgs())
1622 for (
Value v : getRegionOutputArgs())
1623 setNameFn(v,
"init");
1626 void ReduceOp::getAsmResultNames(
1628 if (!getResults().empty())
1629 setNameFn(getResults().front(),
"reduced");
1632 void ReduceOp::build(
1637 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1641 for (
Value init : inits) {
1643 if (llvm::isa<RankedTensorType>(initType))
1649 inputs, inits, bodyBuild);
1654 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1656 utils::IteratorType::parallel);
1657 for (int64_t reductionDim : getDimensions())
1658 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1659 return iteratorTypes;
1662 ArrayAttr ReduceOp::getIndexingMaps() {
1664 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1671 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1672 affineMaps.push_back(resultMap);
1676 void ReduceOp::getEffects(
1688 StringRef attributeName) {
1697 std::optional<OperationName> payloadOpName;
1701 if (failed(operationName))
1705 payloadOpName = operationName.value();
1716 if (payloadOpName.has_value()) {
1736 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1740 Block *mapper = getBody();
1755 [&](
auto arg) { p.printRegionArgument(arg); });
1766 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1767 if (llvm::cast<ShapedType>(getInputs()[i].
getType()).getShape() !=
1768 llvm::cast<ShapedType>(getInputs()[0].
getType()).getShape()) {
1769 return emitOpError() <<
"expects all inputs to have the same shapes. "
1770 "Shape at input-index "
1772 <<
" is not equal to the shape at input-index 0.";
1775 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1776 if (llvm::cast<ShapedType>(getInits()[i].
getType()).getShape() !=
1777 llvm::cast<ShapedType>(getInits()[0].
getType()).getShape()) {
1778 return emitOpError() <<
"expects all outputs to have the same shapes. "
1779 "Shape at output-index "
1781 <<
" is not equal to the shape at output-index 0.";
1784 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1785 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1788 for (int64_t dimension : dimensionsRef) {
1789 if (dimension < 0 || dimension >= inputType.getRank()) {
1790 return emitOpError()
1791 <<
"dimensions for reduction should be in the range [0, "
1792 << inputType.getRank() - 1 <<
"].";
1794 dimensionsToReduce.insert(dimension);
1797 auto inputDims = inputType.getShape();
1798 auto initDims = initType.getShape();
1803 if (!dimensionsToReduce.count(en.index()))
1804 reducedInputDims.push_back(en.value());
1807 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1808 return emitOpError() <<
"number of dimensions after reduction "
1809 << reducedInputDims.size()
1810 <<
" doesn't match the init rank "
1811 << initType.getRank();
1814 if (reducedInputDims != initDims)
1815 return emitOpError() <<
"init dimensions [" << initDims
1816 <<
"] doesn't match input dimensions after reduction ["
1817 << reducedInputDims <<
"]";
1819 Block *block = getBody();
1821 return emitOpError()
1822 <<
"mismatching number of operands and block arguments";
1825 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1826 Type inputElementType =
1827 llvm::cast<ShapedType>(input.getType()).getElementType();
1828 if (inputElementType != bbArg.getType())
1829 return emitOpError()
1830 <<
"input element type " << inputElementType
1831 <<
" does not match corresponding block argument type "
1836 for (
auto [output, bbArg] : llvm::zip(
1837 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1838 auto outputElementType =
1839 llvm::cast<ShapedType>(output.getType()).getElementType();
1840 if (outputElementType != bbArg.getType())
1841 return emitOpError()
1842 <<
"output element type " << outputElementType
1843 <<
" does not match corresponding block argument type "
1859 b.
create<linalg::YieldOp>(loc, args[0]);
1874 if (llvm::isa<RankedTensorType>(initType))
1903 void TransposeOp::getAsmResultNames(
1905 if (!getResults().empty())
1906 setNameFn(getResults().front(),
"transposed");
1919 return emitOpError(
"permutation is not valid");
1921 auto inputType = getInput().getType();
1922 auto initType = getInit().getType();
1924 int64_t rank = inputType.getRank();
1926 if (rank != initType.getRank())
1927 return emitOpError() <<
"input rank " << rank
1928 <<
" does not match init rank " << initType.getRank();
1930 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1931 return emitOpError() <<
"size of permutation " << permutationRef.size()
1932 <<
" does not match the argument rank " << rank;
1934 auto inputDims = inputType.getShape();
1935 auto initDims = initType.getShape();
1937 for (int64_t i = 0; i < rank; ++i) {
1938 int64_t inputDim = inputDims[permutationRef[i]];
1939 int64_t initDim = initDims[i];
1941 if (inputDim != initDim) {
1942 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1943 <<
" doesn't match dim(input, permutation[" << i
1944 <<
"]) = " << inputDim;
1952 int64_t rank = getInit().getType().getRank();
1956 ArrayAttr TransposeOp::getIndexingMaps() {
1958 int64_t rank = getInit().getType().getRank();
1961 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
1965 void TransposeOp::getEffects(
1975 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1978 if (!isa<TensorType>(getInput().
getType()))
1982 if (getPermutation().size() == 0) {
1983 result.push_back(getInput());
1988 result.push_back(getInput());
2001 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2002 if (!defTransposeOp)
2007 foldedPerms.reserve(perms.size());
2008 for (int64_t perm : perms)
2009 foldedPerms.push_back(defPerms[perm]);
2012 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2026 Value input = transposeOp.getInput();
2027 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2038 unsigned dimensionSize = dimensions.size();
2039 for (
unsigned i = 0; i < dimensionSize; ++i)
2040 resultDimensions.push_back(invertPerm[dimensions[i]]);
2043 Value broadcastInput = broadcastOp.getInput();
2044 Location loc = transposeOp.getLoc();
2047 auto broadcastInputTy =
2048 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2049 unsigned inputRank = broadcastInputTy.getRank();
2050 for (
unsigned i = 0; i < inputRank; ++i) {
2051 if (broadcastInputTy.isDynamicDim(i)) {
2052 dims.push_back(rewriter.
create<tensor::DimOp>(loc, broadcastInput, i)
2056 broadcastInputTy.getDimSize(i)));
2061 Value transposeInit = rewriter.
create<tensor::EmptyOp>(
2062 transposeOp.getLoc(), transposeResultShapes,
2063 broadcastInputTy.getElementType());
2066 Value transposeResult =
2068 .
create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2072 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2097 if (llvm::isa<RankedTensorType>(initType))
2126 void BroadcastOp::getAsmResultNames(
2128 if (!getResults().empty())
2129 setNameFn(getResults().front(),
"broadcasted");
2141 auto inputType = getInput().getType();
2142 auto initType = getInit().getType();
2144 int64_t inputRank = inputType.getRank();
2145 int64_t initRank = initType.getRank();
2147 auto inputShape = inputType.getShape();
2148 auto initShape = initType.getShape();
2150 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2151 return emitOpError() <<
"input rank plus added dimensions does not "
2152 "match init rank. input rank: "
2154 <<
", dimensions size: " << dimensionsRef.size()
2155 <<
", init rank: " << initRank;
2158 if (dim < 0 || dim >= initRank)
2159 return emitOpError() <<
"dimension " << idx
2160 <<
" is out of range. expected range: [0, "
2161 << initRank - 1 <<
"], got: " << dim;
2166 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2167 if (!llvm::is_contained(dimensionsRef, dim))
2168 dimMap.push_back(dim);
2171 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
2174 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2175 return emitOpError() <<
"input dim " << inputDimIdx
2176 <<
" should match init dim " << initDimIdx
2177 <<
". input: " << inputShape[inputDimIdx]
2178 <<
", init: " << initShape[initDimIdx];
2185 int64_t rank = getInit().getType().getRank();
2189 ArrayAttr BroadcastOp::getIndexingMaps() {
2191 int64_t rank = getInit().getType().getRank();
2197 void BroadcastOp::getEffects(
2209 results.
add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2217 if (getNumOperands() > 0)
2218 p <<
' ' << getOperands();
2220 if (getNumOperands() > 0)
2221 p <<
" : " << getOperandTypes();
2236 static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2237 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2238 return op.emitOpError(
"expected number of yield values (")
2239 << op.getNumOperands()
2240 <<
") to match the number of inits / outs operands of the enclosing "
2241 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2243 for (
OpOperand &opOperand : op->getOpOperands()) {
2245 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2247 if (isa<MemRefType, RankedTensorType>(elementType))
2249 if (opOperand.get().getType() != elementType)
2250 return op.emitOpError(
"type of yield operand ")
2251 << (opOperand.getOperandNumber() + 1) <<
" ("
2252 << opOperand.get().getType() <<
") doesn't match "
2253 <<
"the element type of the enclosing linalg.generic op ("
2254 << elementType <<
")";
2260 auto *parentOp = (*this)->getParentOp();
2261 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2262 return emitOpError(
"expected single non-empty parent region");
2264 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2267 return emitOpError(
"expected parent op with LinalgOp interface");
2275 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2277 return emitOpError(
"expected parent op with LinalgOp interface");
2278 if (linalgOp.getNumLoops() <= getDim())
2279 return emitOpError(
"expected dim (")
2280 << getDim() <<
") to be lower than the number of loops ("
2281 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2286 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2295 uint64_t dim = getDim();
2296 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2297 if (loopBounds[dim] == 1)
2305 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2307 #define GET_OP_CLASSES
2308 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2310 #define GET_OP_CLASSES
2311 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2312 #define GET_OP_CLASSES
2313 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2330 for (
unsigned i = 0; i < num; ++i)
2337 auto rangeA = llvm::make_range(a.begin(), a.end());
2338 auto rangeB = llvm::make_range(b.begin(), b.end());
2339 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2340 return llvm::to_vector<4>(concatRanges);
2344 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2346 for (
auto size : memref.getShape())
2353 if (
auto as = memref.getMemorySpace()) {
2354 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2355 ss <<
"as" << attr.getInt();
2361 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2364 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2377 assert(isa<LinalgOp>(op));
2379 std::string fun =
"";
2381 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2382 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2383 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2384 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2388 std::replace(name.begin(), name.end(),
'.',
'_');
2389 llvm::raw_string_ostream ss(name);
2393 return std::string();
2408 LogicalResult matchAndRewrite(LinalgOp op,
2410 for (
OpOperand &opOperand : op->getOpOperands()) {
2414 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2417 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2428 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2431 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2436 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2443 if (castOp->getBlock() != linalgOp->getBlock())
2450 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2453 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2459 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2461 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2464 linalgOp.getDpsInits().end());
2465 outputOperands[resultNumber] = newOperand;
2466 newOperands.append(outputOperands.begin(), outputOperands.end());
2469 linalgOp->result_type_end());
2470 resultTypes[resultNumber] = resultType;
2471 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2478 results[resultNumber] = castBack;
2490 if (linalgOp.isScalar(&opOperand))
2492 Value src = opOperand.get();
2493 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2494 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2502 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2503 Value castSource = castOp.getSource();
2504 auto castSourceType =
2505 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2506 if (castSourceType && castSourceType.hasStaticShape())
2507 sourceShape = castSourceType.getShape();
2513 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2514 if (sourceType.isDynamicDim(i))
2516 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2517 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2527 static void createNewOperandWithStaticSizes(
2531 bool &changeNeeded) {
2533 newOperands.push_back(src);
2534 if (linalgOp.isScalar(opOperand))
2536 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2537 Type resultType = sourceType;
2538 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2539 resultTypes.push_back(resultType);
2543 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2547 bool newOperandNeeded =
false;
2548 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2549 int64_t dimShape = sourceShape[i];
2551 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2552 newShape.push_back(dimShape);
2558 newShape.push_back(affineExprToSize[dimExpr]);
2559 newOperandNeeded =
true;
2562 sourceType.getEncoding());
2563 if (newOperandNeeded) {
2564 changeNeeded =
true;
2567 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2569 newOperands[index] = newOperand;
2571 if (linalgOp.isDpsInit(opOperand))
2572 resultTypes.push_back(resultType);
2581 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2583 if (!linalgOp.hasPureTensorSemantics())
2587 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2588 return !map.isProjectedPermutation();
2598 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2605 bool changeNeeded =
false;
2606 newOperands.reserve(linalgOp->getNumOperands());
2607 resultTypes.reserve(linalgOp.getNumDpsInits());
2610 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2611 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2612 affineExprToSize, linalgOp, newOperands,
2613 resultTypes, changeNeeded);
2622 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2625 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2626 Value newResult = std::get<1>(it);
2627 Value oldResult = std::get<0>(it);
2630 replacements.push_back(
2631 (newType != oldType)
2632 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2635 rewriter.
replaceOp(linalgOp, replacements);
2650 ShapedType inputType = getInputOperandType();
2651 ShapedType outputType = getOutputOperandType();
2656 return emitOpError(
"incompatible output shape");
2658 int64_t inputRank = getInputOperandRank();
2659 int64_t dimension = getDimension();
2660 if ((dimension < 0) || (dimension >= inputRank))
2661 return emitOpError(
"incorrect dimension specified");
2667 int64_t operandRank = getInputOperandRank();
2670 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2671 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2672 Value source = getInput();
2673 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2674 loopBounds[dim].offset = zero;
2675 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2676 loopBounds[dim].stride = one;
2683 utils::IteratorType::parallel);
2684 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2685 return iteratorTypes;
2688 FailureOr<TilingResult>
2692 int64_t rank = getInputOperandRank();
2697 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2699 return emitOpError(
"failed to compute input slice");
2701 tiledOperands.emplace_back(inputSlice->
getResult(0));
2703 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2705 return emitOpError(
"failed to compute output slice");
2707 tiledOperands.emplace_back(outputSlice->
getResult(0));
2710 if (hasPureTensorSemantics())
2711 resultTypes.push_back(tiledOperands[1].
getType());
2713 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2725 if (resultNumber == 0) {
2726 resultOffsets.assign(offsets.begin(), offsets.end());
2727 resultSizes.assign(sizes.begin(), sizes.end());
2742 Location loc = getOperation()->getLoc();
2744 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2745 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2746 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2747 if (!outputShapedType.isDynamicDim(dim)) {
2749 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2756 reifiedReturnShapes.emplace_back(std::move(shapes));
2760 void SoftmaxOp::getEffects(
2764 if (!llvm::isa<MemRefType>(operand.
getType()))
2767 &getOperation()->getOpOperand(index), 0,
2772 for (
OpOperand &operand : getDpsInitsMutable()) {
2773 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2806 int64_t dim,
bool allParallel =
false) {
2808 utils::IteratorType::parallel);
2810 iteratorTypes[dim] = utils::IteratorType::reduction;
2814 for (
int i = 0; i < inputRank; i++) {
2821 return std::make_tuple(iteratorTypes, indexingMaps);
2826 template <
typename T>
2829 auto inputType = cast<ShapedType>(input.
getType());
2831 int64_t inputRank = inputShape.size();
2832 auto [iteratorTypes, indexingMaps] =
2834 assert(indexingMaps.size() == 2 &&
2835 "We should have two maps: 1 for the input, 1 for the output");
2836 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2838 auto genericOp = builder.
create<linalg::GenericOp>(
2839 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2841 Value result = b.create<T>(loc, args[0], args[1]);
2842 b.create<linalg::YieldOp>(loc, result);
2852 auto inputType = cast<ShapedType>(input.
getType());
2854 int64_t inputRank = inputShape.size();
2856 builder, inputRank, dim,
true);
2857 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2858 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2860 indexingMaps.push_back(indexingMaps[0]);
2861 auto genericOp = builder.
create<linalg::GenericOp>(
2864 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2865 Value result = b.create<math::ExpOp>(loc, diff);
2866 b.create<linalg::YieldOp>(loc, result);
2877 Value denominator,
Value output, int64_t dim) {
2878 auto inputType = cast<ShapedType>(numerator.
getType());
2880 int64_t inputRank = inputShape.size();
2882 builder, inputRank, dim,
true);
2883 assert(indexingMaps.size() == 2 &&
2884 "We should have one map for each input (2)");
2885 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2887 indexingMaps.push_back(indexingMaps[0]);
2888 auto genericOp = builder.
create<linalg::GenericOp>(
2890 indexingMaps, iteratorTypes,
2892 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2893 b.create<linalg::YieldOp>(loc, result);
2917 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(
OpBuilder &b) {
2921 Value input = getInput();
2922 ShapedType inputType = getInputOperandType();
2923 Type elementType = inputType.getElementType();
2924 int64_t reductionDim = getDimension();
2926 Value output = getOutput();
2927 dims.erase(dims.begin() + reductionDim);
2929 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2931 elementType, b, loc,
2933 Value neutralForMaxFInit =
2934 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2937 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2946 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2948 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2952 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2961 auto filterType = cast<ShapedType>(getFilter().
getType());
2963 int64_t filterH = filterShape[getFilterHDim()];
2964 int64_t filterW = filterShape[getFilterWDim()];
2968 if (filterH != r && filterH != 1)
2969 return emitOpError(
"expect filter height either equals to r or 1");
2970 if (filterW != r && filterW != 1)
2971 return emitOpError(
"expect filter width either equals to r or 1");
2972 if (filterH == 1 && filterW == 1)
2973 return emitOpError(
"expect either filter height or width equals to r");
2976 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2977 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2978 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2979 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2981 auto outputType = cast<ShapedType>(getOutput().
getType());
2984 return emitOpError(
"the output shape is not expected");
2990 WinogradFilterTransformOp::getIterationDomain(
OpBuilder &builder) {
2994 Value filter = getFilter();
2995 int64_t filterRank = getFilterOperandRank();
2997 for (
unsigned dim = 0; dim < filterRank; ++dim) {
2998 loopBounds[dim].offset = zeroAttr;
2999 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3000 loopBounds[dim].stride = oneAttr;
3006 WinogradFilterTransformOp::getLoopIteratorTypes() {
3007 int64_t filterRank = getFilterOperandRank();
3009 utils::IteratorType::parallel);
3010 return iteratorTypes;
3018 ShapedType filterType = getFilterOperandType();
3020 int64_t filterH = filterShape[getFilterHDim()];
3021 int64_t filterW = filterShape[getFilterWDim()];
3024 int64_t alpha = m + r - 1;
3025 int64_t alphaH = filterH != 1 ? alpha : 1;
3026 int64_t alphaW = filterW != 1 ? alpha : 1;
3030 resultOffsets.append(
3031 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3033 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3049 ShapedType filterType = getFilterOperandType();
3051 int64_t filterH = filterShape[getFilterHDim()];
3052 int64_t filterW = filterShape[getFilterWDim()];
3058 sliceOffsets.append(
3059 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3060 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3061 sizes[getFilterCDim()]});
3062 int64_t filterRank = getFilterOperandRank();
3065 auto filterSlice = builder.
create<tensor::ExtractSliceOp>(
3066 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3067 tiledOperands.emplace_back(filterSlice);
3074 int64_t outputRank = getOutputOperandRank();
3076 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3077 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3078 tiledOperands.emplace_back(outputSlice);
3081 resultTypes.push_back(tiledOperands[1].
getType());
3083 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3096 auto inputType = cast<ShapedType>(getInput().
getType());
3098 int64_t inputH = inputShape[getInputHDim()];
3099 int64_t inputW = inputShape[getInputWDim()];
3102 int64_t tileSize = m + r - 1;
3104 auto outputType = cast<ShapedType>(getOutput().
getType());
3106 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3107 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3110 if (ShapedType::isDynamic(inputH)) {
3111 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3112 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3114 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3115 expectedOutputShape[getOutputTileHDim()] =
3116 leftTransform ? (inputH - (r - 1)) / m : inputH;
3118 if (ShapedType::isDynamic(inputW)) {
3119 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3120 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3122 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3123 expectedOutputShape[getOutputTileWDim()] =
3124 rightTransform ? (inputW - (r - 1)) / m : inputW;
3126 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3127 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3130 return emitOpError(
"the output shape is not expected");
3136 WinogradInputTransformOp::getIterationDomain(
OpBuilder &builder) {
3140 Value output = getOutput();
3141 int64_t outputRank = getOutputOperandRank();
3143 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3144 loopBounds[dim].offset = zeroAttr;
3146 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3147 loopBounds[dim].stride = oneAttr;
3153 WinogradInputTransformOp::getLoopIteratorTypes() {
3154 int64_t outputRank = getOutputOperandRank();
3156 utils::IteratorType::parallel);
3157 return iteratorTypes;
3165 ShapedType outputType = getOutputOperandType();
3167 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3168 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3172 int64_t alpha = m + r - 1;
3173 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3174 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3179 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3180 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3181 offsets[getOutputCDim()]});
3182 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3183 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3184 sizes[getOutputCDim()]});
3195 FailureOr<TilingResult>
3203 ShapedType outputType = getOutputOperandType();
3205 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3206 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3210 auto identityAffineMap =
3212 auto offsetAffineMap =
3215 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3216 offsets[getOutputTileHDim()]);
3218 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3219 offsets[getOutputTileWDim()]);
3223 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3225 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3232 sliceOffsets.append(
3233 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3239 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3240 int64_t inputRank = getInputOperandRank();
3242 auto inputSlice = builder.
create<tensor::ExtractSliceOp>(
3243 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3244 tiledOperands.emplace_back(inputSlice);
3251 int64_t outputRank = getOutputOperandRank();
3253 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3254 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3255 tiledOperands.emplace_back(outputSlice);
3258 resultTypes.push_back(tiledOperands[1].
getType());
3260 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3273 auto valueType = cast<ShapedType>(getValue().
getType());
3275 int64_t valueH = valueShape[getValueAlphaHDim()];
3276 int64_t valueW = valueShape[getValueAlphaWDim()];
3277 int64_t valueTileH = valueShape[getValueTileHDim()];
3278 int64_t valueTileW = valueShape[getValueTileWDim()];
3281 bool leftTransform = valueH != 1;
3282 bool rightTransform = valueW != 1;
3284 int64_t outputRank = getOutputOperandRank();
3286 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3287 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3289 if (valueH != (leftTransform ? m + r - 1 : 1))
3290 return emitOpError(
"expect input height equals to input tile size");
3291 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3293 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3294 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3296 if (valueW != (rightTransform ? m + r - 1 : 1))
3297 return emitOpError(
"expect input width equals to input tile size");
3298 expectedOutputShape[getOutputWDim()] =
3299 (rightTransform ? m : 1) * valueTileW;
3301 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3302 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3304 auto outputType = cast<ShapedType>(getOutput().
getType());
3307 return emitOpError(
"the output shape is not expected");
3313 WinogradOutputTransformOp::getIterationDomain(
OpBuilder &builder) {
3317 Value value = getValue();
3318 int64_t valueRank = getValueOperandRank();
3320 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3321 loopBounds[dim].offset = zeroAttr;
3323 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3324 loopBounds[dim].stride = oneAttr;
3330 WinogradOutputTransformOp::getLoopIteratorTypes() {
3331 int64_t valueRank = getValueOperandRank();
3333 utils::IteratorType::parallel);
3334 return iteratorTypes;
3345 auto identityAffineMap =
3350 ShapedType valueType = getValueOperandType();
3352 int64_t valueH = valueShape[0];
3353 int64_t valueW = valueShape[1];
3355 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3356 offsets[getValueTileHDim()]);
3358 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3359 offsets[getValueTileWDim()]);
3361 builder, loc, affineMap, sizes[getValueTileHDim()]);
3363 builder, loc, affineMap, sizes[getValueTileWDim()]);
3373 resultOffsets.append(
3374 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3376 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3395 ShapedType valueType = getValueOperandType();
3397 int64_t alphaH = valueShape[getValueAlphaHDim()];
3398 int64_t alphaW = valueShape[getValueAlphaWDim()];
3402 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3403 offsets[getValueTileWDim()], offsets[getValueNDim()],
3404 offsets[getValueFDim()]});
3405 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3406 sizes[getValueTileWDim()], sizes[getValueNDim()],
3407 sizes[getValueFDim()]});
3408 int64_t valueRank = getValueOperandRank();
3410 auto valueSlice = builder.
create<tensor::ExtractSliceOp>(
3411 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3412 tiledOperands.emplace_back(valueSlice);
3419 int64_t outputRank = getOutputOperandRank();
3421 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3422 loc, getOutput(), resultOffsets, resultSizes, strides);
3423 tiledOperands.emplace_back(outputSlice);
3426 resultTypes.push_back(tiledOperands[1].
getType());
3428 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3447 llvm::set_union(explicitSet, defaultSet);
3448 return explicitSet == defaultSet;
3468 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3470 auto opIndexingMap = opIndexingMaps[opIndex];
3471 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3474 return matmulOp->emitOpError()
3475 <<
"Unexpected dim expression in map result.";
3478 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3479 return matmulOp->emitOpError()
3480 <<
"Invalid broadcast requested, should be (d2).";
3490 AffineMap defaultIndexingMap,
bool isLHS) {
3493 return batchMatmulOp->emitOpError()
3494 <<
"Unexpected result dim expression (outside the set of default "
3499 return batchMatmulOp->emitOpError()
3500 <<
"no. of result dim expressions exceeds 3.";
3502 auto hasValidBatchDim = [](
AffineMap map) {
3509 if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3510 return batchMatmulOp->emitOpError() <<
"Invalid broadcast requested.";
3511 }
else if (!hasValidBatchDim(opIndexingMap)) {
3512 return batchMatmulOp->emitOpError()
3513 <<
"Invalid batch dimension expression.";
3524 return batchMatmulOp->emitOpError()
3525 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3528 auto areValidOutputResultDim = [](
AffineMap outputMap) {
3529 return outputMap.getResult(0).isFunctionOfDim(0) &&
3530 outputMap.getResult(1).isFunctionOfDim(1) &&
3531 outputMap.getResult(2).isFunctionOfDim(2);
3534 if (!areValidOutputResultDim(opIndexingMap))
3535 return batchMatmulOp->emitOpError()
3536 <<
"Invalid output map result dimension.";
3543 static LogicalResult
3547 batchMatmulOp.getIndexingMapsArray();
3549 batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3551 if (opIndexingMaps.size() != 3)
3552 return batchMatmulOp->emitOpError()
3553 <<
"Indexing_map attribute must have 3 affine maps.";
3555 auto opIndexingMap = opIndexingMaps[opIndex];
3556 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3558 if (opIndex == 2 && failed(
verifyOutputMap(batchMatmulOp, opIndexingMap)))
3561 if (failed(
verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3583 return indexingMaps;
3588 utils::IteratorType::parallel,
3589 utils::IteratorType::reduction};
3592 unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3594 std::string MatmulOp::getLibraryCallName() {
3598 bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3602 bool MatmulOp::hasUserDefinedMaps() {
3606 return defaultMaps != explicitMaps;
3614 "MatmulOp regionBuilder expects 3 (>=0) args");
3615 RegionBuilderHelper helper(b, block);
3618 TypeFn castVal = TypeFn::cast_signed;
3619 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3620 return attr.
getName() ==
"cast";
3622 if (castIter != attrs.end()) {
3623 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3631 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3633 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), value3);
3634 yields.push_back(value4);
3635 helper.yieldOutputs(yields);
3639 bool MatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap) {
3640 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3651 ArrayAttr arrayAttr;
3655 if (llvm::any_of(arrayAttr,
3656 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3658 <<
"element of indexing_maps array is not an affine_map";
3665 if (failed(indexingMapsAttr))
3668 if (*indexingMapsAttr ==
nullptr) {
3669 auto indexingMapAttrs = llvm::map_to_vector(
3670 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3675 result.
addAttribute(
"indexing_maps", *indexingMapsAttr);
3677 MatmulOp::getRegionBuilder());
3682 MatmulOp::getDefaultIndexingMaps(
getContext()),
3684 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3685 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3687 std::array<StringRef, 3> elidedAttrs = {
3688 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3696 if (!hasUserDefinedMaps())
3699 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3710 void MatmulOp::getEffects(
3713 if (hasPureTensorSemantics())
3727 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3737 for (
auto result : outAffineMap.
getResults()) {
3738 auto dimExpr = dyn_cast<AffineDimExpr>(result);
3739 assert(dimExpr &&
"affine_map is a projected permutation");
3740 dimsInOutput[dimExpr.getPosition()] =
true;
3744 for (
auto dimOccursInOutput : dimsInOutput)
3745 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3746 : utils::IteratorType::reduction);
3748 return iteratorTypes;
3751 unsigned ContractOp::getNumRegionArgs() {
return 3; }
3757 "ContractOp regionBuilder expects 3 args");
3758 RegionBuilderHelper helper(b, block);
3760 TypeFn castSignedness = TypeFn::cast_signed;
3761 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3762 return attr.
getName() ==
"cast";
3764 if (castIter != attrs.end()) {
3765 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3771 Value lhsAtOutType =
3772 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
3773 Value rhsAtOutType =
3774 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
3775 Value productAtOutType =
3776 helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3779 helper.yieldOutputs({result});
3784 if (failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
3786 "expected 'indexing_maps' attribute");
3787 result.
addAttribute(
"indexing_maps", *indexingMapsAttr);
3794 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3796 p, getOperation(), getInputs(), getOutputs(),
3797 {
"indexing_maps",
"operandSegmentSizes"});
3801 int iterationSpaceDims = -1;
3810 auto checkAffineMapAndType = [&](
AffineMap affineMap,
Type operandType,
3811 bool isInput) -> LogicalResult {
3814 return emitError(
"provided affine_map is not a projected permutation");
3817 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
3819 return emitError(
"ranks of shaped operand and results of corresponding "
3820 "affine_map differ");
3822 return emitError(
"affine_map specifies shaped access while operand has "
3827 if (iterationSpaceDims == -1) {
3831 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
3832 return emitError(
"iteration spaces of provided affine_maps differ");
3837 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3839 llvm_unreachable(
"affine_map is a projected permutation");
3842 inOccurrences[affineDimExpr.getPosition()] += 1;
3844 outOccurrences[affineDimExpr.getPosition()] += 1;
3850 for (
auto &&[affineMap, operandType, isInput] :
3851 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3853 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3857 bool hasContractingDim =
false;
3858 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3859 size_t inOccCount = inOccurrences[dimIndex];
3860 size_t outOccCount = outOccurrences[dimIndex];
3863 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3865 if (inOccCount == 0 && outOccCount == 0)
3866 return emitError() <<
"iteration space dim at index " << dimIndex
3867 <<
" not used to access any operand";
3878 if (inOccCount == 1 && outOccCount != 1)
3880 <<
"iteration space dim at index " << dimIndex
3881 <<
" is neither a contracting dim nor of parallel iteration type";
3884 if (!hasContractingDim)
3885 return emitError(
"'indexing_maps' do not specify a contracting dimension");
3894 void ContractOp::getEffects(
3897 if (hasPureTensorSemantics())
3910 BatchMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
3914 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
3915 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
3916 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
3917 return indexingMaps;
3922 utils::IteratorType::parallel, utils::IteratorType::parallel,
3923 utils::IteratorType::parallel, utils::IteratorType::reduction};
3926 unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
3928 std::string BatchMatmulOp::getLibraryCallName() {
3934 bool BatchMatmulOp::hasUserDefinedMaps() {
3938 return defaultMaps != explicitMaps;
3942 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
bool isLHS) {
3944 "Expected less than 3 result dim expr.");
3945 bool isValid =
false;
3946 enum Indices { batchPos, mPos, nPos, kPos };
3963 "BatchMatmulOp regionBuilder expects 3 (>=0) args");
3964 RegionBuilderHelper helper(b, block);
3967 TypeFn castVal = TypeFn::cast_signed;
3968 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3969 return attr.
getName() ==
"cast";
3971 if (castIter != attrs.end()) {
3972 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3977 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
3978 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
3979 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3981 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
3982 yields.push_back(addVal);
3983 helper.yieldOutputs(yields);
3999 if (!isa<AffineMapAttr>(mapAttr)) {
4001 "expected affine map attribute");
4003 indexingMapsAttr.push_back(mapAttr);
4013 if (indexingMapsAttr.empty()) {
4014 indexingMapsAttr = llvm::map_to_vector(
4015 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4022 BatchMatmulOp::getNumRegionArgs(),
4023 BatchMatmulOp::getRegionBuilder());
4028 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4030 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4031 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4033 std::array<StringRef, 3> elidedAttrs = {
4034 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4043 if (!hasUserDefinedMaps())
4046 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4053 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4058 void BatchMatmulOp::getEffects(
4061 if (hasPureTensorSemantics())
4075 struct ArityGroupAndKind {
4087 unsigned getArityGroupAsUInt(ElementwiseArityGroup
arityGroup) {
4093 constexpr
int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4094 constexpr
int lastBinary =
4095 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4096 constexpr
int lastTernary =
4097 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4099 int val =
static_cast<int>(
kind);
4100 ArityGroupAndKind result;
4102 if (val < lastUnary) {
4103 result.arityGroup = ElementwiseArityGroup::Unary;
4104 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4107 if (val < lastBinary) {
4108 result.arityGroup = ElementwiseArityGroup::Binary;
4109 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4112 if (val >= lastTernary) {
4113 llvm_unreachable(
"unhandled ElementwiseFn");
4115 result.arityGroup = ElementwiseArityGroup::Ternary;
4116 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4121 auto rank = getResultRank();
4126 ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4135 mlir::linalg::ElementwiseKind elemwiseKindVal;
4140 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4141 if (!elemwiseKindAttr)
4143 "expected ElementwiseKind attribute");
4144 elemwiseKindVal = elemwiseKindAttr.getValue();
4147 "expected operation 'kind' attribute");
4163 if (!isa<AffineMapAttr>(mapAttr))
4165 "expected affine map attribute");
4166 indexingMapsAttr.push_back(mapAttr);
4177 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4179 ElementwiseOp::getRegionBuilder())) {
4181 "unable to parse elemwise op");
4185 if (indexingMapsAttr.empty()) {
4189 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4192 "return type needs to be shaped type");
4193 auto numDims = shapedType.getRank();
4194 indexingMapsAttr = llvm::map_to_vector(
4195 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4212 unsigned numDims = getResultRank();
4215 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4219 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4220 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4237 ElementwiseKind elemwiseKind;
4238 for (
auto attr : attrs) {
4240 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4241 assert(kindAttr &&
"op kind attribute incorrectly set");
4242 elemwiseKind = kindAttr.getValue();
4249 auto kind = groupAndKind.kind;
4252 &&
"Elementwise regionBuilder number of block args mismatch");
4254 RegionBuilderHelper helper(b, block);
4258 if (
arityGroup == ElementwiseArityGroup::Unary) {
4261 }
else if (
arityGroup == ElementwiseArityGroup::Binary) {
4265 }
else if (
arityGroup == ElementwiseArityGroup::Ternary) {
4270 assert(
false &&
"found unhandled category in elemwise");
4272 yields.push_back(result);
4273 helper.yieldOutputs(yields);
4276 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4281 void ElementwiseOp::getEffects(
4284 if (hasPureTensorSemantics())
4307 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4309 .take_back(mixedTiles.size()),
4311 int64_t shape = std::get<0>(it);
4312 if (shape == ShapedType::kDynamic) {
4313 newMixedTileSizes.push_back(std::get<1>(it));
4320 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
4322 newMixedTileSizes.push_back(
tile);
4325 "tile size and dim size don't match!");
4326 newMixedTileSizes.push_back(
4331 return newMixedTileSizes;
4334 template <
typename OpTy>
4335 static LogicalResult
4338 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4339 "applies to only pack or unpack operations");
4340 int64_t destRank = op.getDestRank();
4342 reifiedReturnShapes[0] =
4347 template <
typename OpTy>
4349 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4350 "applies to only pack or unpack operations");
4354 assert(tiles.size() == dimsToTile.size() &&
4355 "tiles must match indices of dimension to block");
4357 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4358 dimAndTileMapping[dimsToTile[i]] = tiles[i];
4359 return dimAndTileMapping;
4362 template <
typename OpTy>
4364 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4365 "applies to only pack or unpack operations");
4368 unsigned dynamicValIndex = 0;
4369 for (int64_t staticTile : op.getStaticInnerTiles()) {
4370 if (!ShapedType::isDynamic(staticTile))
4373 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4375 return mixedInnerTiles;
4378 template <
typename OpTy>
4380 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4381 "applies to only pack or unpack operations");
4394 size_t dimsPosSize = dimsPos.size();
4395 if (dimsPosSize > rank)
4398 if (dimsPosSize != uniqued.size())
4400 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4401 return dimPos < 0 || dimPos >=
static_cast<int64_t
>(rank);
4410 sourceShape.size() == limitShape.size() &&
4411 "expected source shape rank, and limit of the shape to have same rank");
4412 return llvm::all_of(
4413 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4414 int64_t sourceExtent = std::get<0>(it);
4415 int64_t limit = std::get<1>(it);
4416 return ShapedType::isDynamic(sourceExtent) ||
4417 ShapedType::isDynamic(limit) || sourceExtent <= limit;
4421 template <
typename OpTy>
4423 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4424 "applies to only pack or unpack operations");
4425 Operation *op = packOrUnPack.getOperation();
4429 return llvm::any_of(
4435 if (hasZeros(mixedTiles))
4436 return op->
emitError(
"invalid zero tile factor");
4439 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4440 ? packOrUnPack.getSourceType()
4441 : packOrUnPack.getDestType();
4442 size_t unpackedRank = unpackedType.getRank();
4446 return op->
emitError(
"invalid inner_dims_pos vector");
4448 return op->
emitError(
"invalid outer_dims_perm vector");
4449 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4450 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
4454 if (mixedTiles.size() > unpackedRank) {
4455 return op->
emitError(
"tiling factors must be less than or equal to the "
4456 "input rank for pack or output rank for unpack");
4460 "tiling factors must equal the number of dimensions to tile");
4463 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4464 ? packOrUnPack.getDestType()
4465 : packOrUnPack.getSourceType();
4466 size_t packedRank = packedType.getRank();
4468 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4469 if (expectedPackedRank != packedRank) {
4471 "packed rank != (unpacked rank + num tiling factors), got ")
4472 << packedRank <<
" != " << expectedPackedRank;
4478 RankedTensorType expectedPackedType = PackOp::inferPackedType(
4479 unpackedType, packOrUnPack.getStaticTiles(),
innerDimsPos, outerDimPerm);
4480 if (!
areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4481 return op->
emitError(
"the shape of output is not large enough to hold the "
4482 "packed data. Expected at least ")
4483 << expectedPackedType <<
", got " << packedType;
4486 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4488 [](std::tuple<int64_t, OpFoldResult> it) {
4489 int64_t shape = std::get<0>(it);
4490 if (Attribute attr =
4491 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4492 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4493 int64_t staticTileSize = intAttr.getValue().getSExtValue();
4494 return shape == staticTileSize;
4496 return ShapedType::isDynamic(shape);
4498 return op->emitError(
"mismatch in inner tile sizes specified and shaped of "
4499 "tiled dimension in the packed type");
4511 struct PackOrUnPackTransposeResult {
4518 template <
typename OpTy>
4519 static PackOrUnPackTransposeResult
4523 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4524 "applies to only pack or unpack operations");
4525 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4526 "some permutation must be non-empty");
4527 PackOrUnPackTransposeResult metadata;
4528 metadata.innerDimsPos =
4530 metadata.innerTiles =
4532 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4533 ? packOrUnPackOp.getSourceRank()
4534 : packOrUnPackOp.getDestRank();
4535 metadata.outerDimsPerm =
4536 packOrUnPackOp.getOuterDimsPerm().empty()
4537 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4539 if (!innerPermutation.empty()) {
4540 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4542 "invalid inner permutation");
4546 if (!outerPermutation.empty()) {
4547 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4549 "invalid outer permutation");
4559 void PackOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
4560 setNameFn(getResult(),
"pack");
4566 std::optional<Value> paddingValue,
4569 "number of tile sizes specified must match the specified number of "
4570 "original dimensions to be tiled");
4574 build(builder, state, dest.
getType(), source, dest,
4575 paddingValue ? *paddingValue :
nullptr,
4601 ShapedType inputType = getSourceType();
4602 int64_t inputRank = inputType.getRank();
4603 return getDestType().getShape().take_front(inputRank);
4608 auto packedShape = getDestType().getShape();
4612 res.push_back(packedShape[index]);
4623 outputShape.take_front(inputShape.size()));
4626 "expected output and outer_dims_perm to have same size");
4631 if (ShapedType::isDynamic(inputShape[pos]))
4635 if (!constantTile) {
4636 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4637 (inputShape[pos] % outputTileSizes[pos] != 0))
4639 }
else if (inputShape[pos] % (*constantTile) != 0) {
4653 auto paddingValue = getPaddingValue();
4656 return emitOpError(
"expected padding_value has ")
4657 << getSourceType().getElementType()
4658 <<
" but got: " << paddingValue.getType();
4661 if (!paddingValue &&
4662 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
4663 getDestType().
getShape(), getOuterDimsPerm(),
4666 "invalid tile factor or output size provided. Only full tiles are "
4667 "supported when padding_value is not set");
4677 for (
auto o : ofrs) {
4679 if (llvm::dyn_cast_if_present<Value>(o))
4680 result.push_back(ShapedType::kDynamic);
4695 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4697 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4698 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4701 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4702 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4710 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4725 builder, loc, ceilDivExpr,
4726 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4730 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4741 for (
unsigned i = 0; i < resultDims.size(); ++i) {
4742 if (!ShapedType::isDynamic(resultTypeShape[i]))
4753 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4775 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
4776 if (ShapedType::isDynamic(value))
4777 mixedSizes.push_back(
4782 for (
auto it : llvm::zip(
innerDimsPos, innerTileSizes)) {
4783 int64_t dimPos = std::get<0>(it);
4785 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4788 applyPermutationToVector<OpFoldResult>(mixedSizes,
outerDimsPerm);
4790 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4791 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4792 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4799 *
this, innerPermutation, outerPermutation);
4800 Value transposedDest =
4801 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4802 metadata.innerDimsPos, metadata.outerDimsPerm);
4803 return b.
create<PackOp>(loc, getSource(), transposedDest,
4804 metadata.innerDimsPos, metadata.innerTiles,
4805 getPaddingValue(), metadata.outerDimsPerm);
4809 template <
typename OpTy>
4811 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4812 "applies to only pack or unpack operations");
4813 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4815 : op.getSourceType();
4817 for (
auto [dimDest,
tile] : llvm::zip(
4818 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4820 if (!constTileSize || ShapedType::isDynamic(dimDest))
4827 if (getPaddingValue())
4842 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4844 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4856 auto packTiles = packOp.getMixedTiles();
4857 auto unPackTiles = unPackOp.getMixedTiles();
4858 if (packTiles.size() != unPackTiles.size())
4860 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
4869 auto srcType = op.getSourceType();
4870 if (llvm::any_of(op.getInnerDimsPos(),
4871 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4873 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4875 return !PackOp::requirePaddingValue(
4876 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4877 op.getOuterDimsPerm(), op.getMixedTiles());
4884 bool changeNeeded =
false;
4885 srcShape.assign(packOp.getSourceType().getShape().begin(),
4886 packOp.getSourceType().getShape().end());
4887 destShape.assign(packOp.getDestType().getShape().begin(),
4888 packOp.getDestType().getShape().end());
4889 llvm::SmallSetVector<int64_t, 4> innerDims;
4890 innerDims.insert_range(packOp.getInnerDimsPos());
4892 if (!packOp.getOuterDimsPerm().empty())
4894 int srcRank = packOp.getSourceRank();
4895 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
4896 if (innerDims.contains(i))
4899 int64_t destPos = i;
4900 if (!inverseOuterDimsPerm.empty())
4901 destPos = inverseOuterDimsPerm[srcPos];
4902 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4903 ShapedType::isDynamic(destShape[destPos])) {
4906 int64_t size = srcShape[srcPos];
4907 if (ShapedType::isDynamic(size))
4908 size = destShape[destPos];
4909 srcShape[srcPos] = size;
4910 destShape[destPos] = size;
4911 changeNeeded =
true;
4913 return changeNeeded;
4916 LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
4918 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4919 if (unPackOp.getSourceType() != packOp.getDestType())
4921 if (packOp.getPaddingValue() ||
4925 rewriter.
replaceOp(packOp, unPackOp.getSource());
4932 packOp.getPaddingValueMutable().clear();
4941 Value source = packOp.getSource();
4942 if (srcShape != packOp.getSourceType().getShape()) {
4943 auto newSrcType = packOp.getSourceType().clone(srcShape);
4945 rewriter.
create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4947 Value dest = packOp.getDest();
4948 RankedTensorType originalResultType = packOp.getDestType();
4949 bool needUpdateDestType = (destShape != originalResultType.getShape());
4950 if (needUpdateDestType) {
4951 auto newDestType = packOp.getDestType().clone(destShape);
4953 rewriter.
create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4956 packOp.getSourceMutable().assign(source);
4957 packOp.getDestMutable().assign(dest);
4958 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
4961 if (needUpdateDestType) {
4964 rewriter.
create<tensor::CastOp>(loc, originalResultType, packOp);
4973 template <
typename PackOrUnpackOp>
4975 RankedTensorType packedTensorType) {
4976 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
4977 std::is_same<PackOrUnpackOp, UnPackOp>::value,
4978 "Function meant for pack/unpack");
4984 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4991 int64_t packedRank = packedTensorType.getRank();
5001 return llvm::all_of(
5002 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5003 [&packedShape](int64_t i) {
return packedShape[i] == 1; });
5006 bool PackOp::isLikePad() {
5007 auto packedTensorType =
5008 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5013 std::optional<Attribute> paddingValue;
5014 if (
auto pad = adaptor.getPaddingValue())
5016 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5017 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5018 getDestType(), paddingValue))
5019 return reshapedSource;
5057 PackOp newOp = rewriter.
create<PackOp>(
5058 op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5059 newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5063 Value oldResult = op.getResult();
5064 Value newResult = newOp.getResult();
5066 ? rewriter.
create<tensor::CastOp>(
5067 op->getLoc(), oldResult.
getType(), newResult)
5080 void UnPackOp::getAsmResultNames(
5082 setNameFn(getResult(),
"unpack");
5104 ShapedType destType = getDestType();
5105 int64_t destRank = destType.getRank();
5106 return getSourceType().getShape().take_front(destRank);
5111 auto packedShape = getSourceType().getShape();
5115 res.push_back(packedShape[index]);
5137 "number of tile sizes specified must match the specified number of "
5138 "original dimensions to be tiled");
5142 build(builder, state, dest.
getType(), source, dest,
5161 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
5163 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5164 if (srcType.isDynamicDim(i))
5165 mixedSizes.push_back(b.
create<tensor::DimOp>(loc, source, i).
getResult());
5167 mixedSizes.push_back(b.
getIndexAttr(srcType.getDimSize(i)));
5170 applyPermutationToVector<OpFoldResult>(
5174 for (
auto [dimPos, tileSize] : llvm::zip_equal(
innerDimsPos, innerTileSizes))
5175 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5177 auto elemType = srcType.getElementType();
5178 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5182 Value transposedSource,
5186 *
this, innerPermutation, outerPermutation);
5187 return b.
create<UnPackOp>(loc, transposedSource, getDest(),
5188 metadata.innerDimsPos, metadata.innerTiles,
5189 metadata.outerDimsPerm);
5196 bool changeNeeded =
false;
5197 srcShape.assign(op.getSourceType().getShape().begin(),
5198 op.getSourceType().getShape().end());
5199 destShape.assign(op.getDestType().getShape().begin(),
5200 op.getDestType().getShape().end());
5201 llvm::SmallSetVector<int64_t, 4> innerDims;
5202 innerDims.insert_range(op.getInnerDimsPos());
5204 if (!op.getOuterDimsPerm().empty())
5206 int destRank = op.getDestRank();
5207 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
5208 if (innerDims.contains(i))
5211 int64_t destPos = i;
5212 if (!inverseOuterDimsPerm.empty())
5213 srcPos = inverseOuterDimsPerm[destPos];
5214 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5215 ShapedType::isDynamic(destShape[destPos])) {
5218 int64_t size = srcShape[srcPos];
5219 if (ShapedType::isDynamic(size))
5220 size = destShape[destPos];
5221 srcShape[srcPos] = size;
5222 destShape[destPos] = size;
5223 changeNeeded =
true;
5225 return changeNeeded;
5228 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5231 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5232 if (packOp.getSourceType() != unPackOp.getDestType())
5234 if (packOp.getPaddingValue() ||
5238 rewriter.
replaceOp(unPackOp, packOp.getSource());
5242 if (
auto dstStyleOp =
5243 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5244 auto destValue = cast<OpResult>(unPackOp.getDest());
5245 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5247 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5251 if (unPackOp->hasOneUse()) {
5252 auto extractSliceUser =
5253 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5254 if (extractSliceUser &&
5257 extractSliceUser.getSourceType().getRank() ==
5258 extractSliceUser.getResultType().getRank()) {
5261 auto newDest = rewriter.
create<tensor::ExtractSliceOp>(
5262 unPackOp->getLoc(), unPackOp.getDest(),
5263 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5264 extractSliceUser.getMixedStrides());
5266 unPackOp.setDpsInitOperand(0, newDest);
5267 unPackOp.getResult().setType(newDest.
getType());
5269 rewriter.
replaceOp(extractSliceUser, unPackOp);
5278 Value source = unPackOp.getSource();
5279 if (srcShape != unPackOp.getSourceType().getShape()) {
5280 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5281 source = rewriter.
create<tensor::CastOp>(loc, newSrcType,
5282 unPackOp.getSource());
5284 Value dest = unPackOp.getDest();
5285 if (destShape != unPackOp.getDestType().getShape()) {
5286 auto newDestType = unPackOp.getDestType().clone(destShape);
5288 rewriter.
create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5291 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5292 unPackOp.getOuterDimsPerm());
5294 unPackOp, unPackOp.getResult().getType(), newOp);
5301 bool UnPackOp::isLikeUnPad() {
5302 RankedTensorType packedTensorType = getSourceType();
5307 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5308 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5310 return reshapedSource;
5339 Value sourceTensor = newOperands[0];
5343 rewriter, sourceTensor.
getType(), op.getMixedTiles());
5349 UnPackOp newOp = rewriter.
create<UnPackOp>(
5350 op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5351 newMixedTileSizes, op.getOuterDimsPerm());
5355 Value oldResult = op.getResult();
5356 Value newResult = newOp.getResult();
5358 ? rewriter.
create<tensor::CastOp>(
5359 op->getLoc(), oldResult.
getType(), newResult)
5375 void LinalgDialect::getCanonicalizationPatterns(
5384 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ElementwiseArityGroup arityGroup
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
union mlir::linalg::@1197::ArityGroupAndKind::Kind kind
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the mangled library call name to disambiguate between different overloads at the C level...
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Kind
An enumeration of the kinds of predicates.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override