38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/SmallSet.h"
40 #include "llvm/ADT/StringSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
53 auto type = cast<ShapedType>(v.
getType());
54 if (!type.isDynamicDim(dim))
59 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
60 return builder.create<tensor::DimOp>(loc, v, dim);
62 .Case<MemRefType>([&](MemRefType t) ->
Value {
63 return builder.create<memref::DimOp>(loc, v, dim);
74 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
75 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
78 .Case<MemRefType>([&](MemRefType type) ->
Value {
79 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
82 .Default([&](
Type t) {
return nullptr; });
91 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
93 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
95 llvm_unreachable(
"Expected MemRefType or TensorType");
100 auto shapedType = llvm::cast<ShapedType>(source.
getType());
101 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
123 assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
127 for (
auto containers : {inputTypes, outputTypes}) {
128 for (
auto t : containers) {
140 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
144 regionBuilder(b, *body, attrs);
156 std::optional<TypeRange> resultTensorTypes,
163 if (!resultTensorTypes)
164 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
165 llvm::IsaPred<RankedTensorType>);
167 state.addOperands(inputs);
168 state.addOperands(outputs);
169 state.addTypes(derivedResultTypes);
170 state.addAttributes(attributes);
172 "operandSegmentSizes",
174 static_cast<int32_t>(outputs.size())}));
177 Region ®ion = *state.addRegion();
179 state.attributes.getAttrs(), regionBuilder);
188 bool addOperandSegmentSizes =
true) {
189 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
218 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
220 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
224 if (addOperandSegmentSizes) {
233 attrs.
append(
"operandSegmentSizes",
235 {static_cast<int32_t>(inputsOperands.size()),
236 static_cast<int32_t>(outputsOperands.size())}));
241 {static_cast<int32_t>(inputsOperands.size()),
242 static_cast<int32_t>(outputsOperands.size())}));
246 std::optional<RegisteredOperationName> info =
249 if (failed(info->verifyInherentAttrs(result.
attributes, [&]() {
250 return parser.emitError(attrsLoc)
251 <<
"'" << result.name.getStringRef() <<
"' op ";
262 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
263 if (!outputs.empty())
264 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
275 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
278 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
279 "region expects {0} args, got {1}",
280 numRegionArgs, inputTypes.size() + outputTypes.size()));
299 unsigned numRegionArgs,
311 result.
addTypes(outputTensorsTypes);
313 std::unique_ptr<Region> region = std::make_unique<Region>();
325 if (resultTypes.empty())
334 {
"operandSegmentSizes",
337 "linalg.memoized_indexing_maps"});
374 class RegionBuilderHelper {
377 : builder(builder), block(block) {}
380 Value buildUnaryFn(UnaryFn unaryFn,
Value arg) {
381 if (!isFloatingPoint(arg))
382 llvm_unreachable(
"unsupported non numeric type");
384 builder.setInsertionPointToEnd(&block);
387 return builder.create<math::ExpOp>(arg.
getLoc(), arg);
389 return builder.create<math::LogOp>(arg.
getLoc(), arg);
391 return builder.create<math::AbsFOp>(arg.
getLoc(), arg);
393 return builder.create<math::CeilOp>(arg.
getLoc(), arg);
395 return builder.create<math::FloorOp>(arg.
getLoc(), arg);
397 return builder.create<arith::NegFOp>(arg.
getLoc(), arg);
398 case UnaryFn::reciprocal: {
400 auto one = builder.create<arith::ConstantOp>(arg.
getLoc(),
401 ::cast<TypedAttr>(oneAttr));
402 return builder.create<arith::DivFOp>(arg.
getLoc(), one, arg);
405 return builder.create<math::RoundOp>(arg.
getLoc(), arg);
407 return builder.create<math::SqrtOp>(arg.
getLoc(), arg);
409 return builder.create<math::RsqrtOp>(arg.
getLoc(), arg);
410 case UnaryFn::square:
411 return builder.create<arith::MulFOp>(arg.
getLoc(), arg, arg);
413 return builder.create<math::TanhOp>(arg.
getLoc(), arg);
415 return builder.create<math::ErfOp>(arg.
getLoc(), arg);
417 llvm_unreachable(
"unsupported unary function");
422 bool allComplex = isComplex(arg0) && isComplex(arg1);
423 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
424 bool allInteger = isInteger(arg0) && isInteger(arg1);
427 if (!allComplex && !allFloatingPoint && !allInteger)
428 llvm_unreachable(
"unsupported non numeric type");
430 builder.setInsertionPointToEnd(&block);
434 return builder.create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
435 if (allFloatingPoint)
436 return builder.create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
438 return builder.create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
439 return builder.create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
442 return builder.create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
443 if (allFloatingPoint)
444 return builder.create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
446 llvm_unreachable(
"unsupported operation: sub with bools");
447 return builder.create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
450 return builder.create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
451 if (allFloatingPoint)
452 return builder.create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
454 return builder.create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
455 return builder.create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
458 return builder.create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
459 if (allFloatingPoint)
460 return builder.create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
462 llvm_unreachable(
"unsupported operation: div with bools");
463 return builder.create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
464 case BinaryFn::div_unsigned:
465 if (!allInteger || allBool)
466 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
467 return builder.create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
468 case BinaryFn::max_signed:
470 if (allFloatingPoint)
471 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
472 return builder.create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
473 case BinaryFn::min_signed:
475 if (allFloatingPoint)
476 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
477 return builder.create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
478 case BinaryFn::max_unsigned:
480 if (allFloatingPoint)
481 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
482 return builder.create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
483 case BinaryFn::min_unsigned:
485 if (allFloatingPoint)
486 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
487 return builder.create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
489 assert(allFloatingPoint);
490 return builder.create<math::PowFOp>(arg0.
getLoc(), arg0, arg1);
492 llvm_unreachable(
"unsupported binary function");
500 bool tailFloatingPoint =
501 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
502 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
504 builder.setInsertionPointToEnd(&block);
506 case TernaryFn::select:
507 if (!headBool && !(tailFloatingPoint || tailInteger))
508 llvm_unreachable(
"unsupported non numeric type");
509 return builder.create<arith::SelectOp>(arg0.
getLoc(), arg0, arg1, arg2);
511 llvm_unreachable(
"unsupported ternary function");
517 case TypeFn::cast_signed:
518 return cast(toType, operand,
false);
519 case TypeFn::cast_unsigned:
520 return cast(toType, operand,
true);
522 llvm_unreachable(
"unsupported type conversion function");
527 builder.setInsertionPointToEnd(&block);
528 Location loc = builder.getUnknownLoc();
529 builder.create<YieldOp>(loc, values);
532 Value constant(
const std::string &value) {
534 builder.setInsertionPointToEnd(&block);
535 Location loc = builder.getUnknownLoc();
537 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
540 Value index(int64_t dim) {
542 builder.setInsertionPointToEnd(&block);
543 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
546 Type getIntegerType(
unsigned width) {
560 builder.setInsertionPointToEnd(&block);
561 auto loc = operand.
getLoc();
565 bool isComplex(
Value value) {
566 return llvm::isa<ComplexType>(value.
getType());
568 bool isFloatingPoint(
Value value) {
569 return llvm::isa<FloatType>(value.
getType());
571 bool isInteger(
Value value) {
572 return llvm::isa<IntegerType>(value.
getType());
589 LogicalResult matchAndRewrite(CopyOp copyOp,
591 if (copyOp.getInputs() != copyOp.getOutputs())
593 if (copyOp.hasPureBufferSemantics())
596 rewriter.
replaceOp(copyOp, copyOp.getInputs());
606 results.
add<EraseSelfCopy>(context);
619 template <
typename TensorReshapeOp>
622 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
624 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
629 TensorReshapeOp newInit;
630 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
632 newInit = rewriter.
create<TensorReshapeOp>(
633 loc, reshapeOp.getResultType(), oldFill.output(),
634 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
635 reshapeOp.getStaticOutputShape());
637 newInit = rewriter.
create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
639 reshapeOp.getReassociation());
652 LogicalResult matchAndRewrite(tensor::PadOp padOp,
654 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
660 Value padValue = padOp.getConstantPaddingValue();
661 if (!padValue || fillOp.value() != padValue)
667 padOp,
"failed to reify tensor.pad op result shape");
669 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
670 padOp.getLoc(), reifiedShape.front(),
671 padOp.getResultType().getElementType());
677 if (replacement.getType() != padOp.getResultType()) {
678 replacement = rewriter.
create<tensor::CastOp>(
679 fillOp.getLoc(), padOp.getResultType(), replacement);
689 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
692 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
694 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
698 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
703 Value firstDest = insertOp.getDest();
704 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
705 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
710 bool disjoint =
false;
711 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
714 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
715 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
716 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
720 int64_t prevStart = prevOp.getStaticOffset(i);
721 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
722 prevOp.getStaticStride(i);
723 int64_t nextStart = insertOp.getStaticOffset(i);
724 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
725 insertOp.getStaticStride(i);
726 if (prevEnd < nextStart || nextEnd < prevStart) {
734 firstDest = prevOp.getDest();
745 Value padValue = srcPadOp.getConstantPaddingValue();
746 if (!padValue || dstFillOp.value() != padValue)
762 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
764 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
767 RankedTensorType srcPadType = srcPadOp.getSourceType();
769 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
770 if (srcPadType.isDynamicDim(i)) {
772 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
775 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
780 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
781 newSizes, insertOp.getMixedStrides());
787 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
791 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
795 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
800 Value extractedScalar = fillOp.getInputs()[0];
803 rewriter.
replaceOp(extractOp, extractedScalar);
811 static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
812 tensor::PackOp packOp) {
813 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
817 if (
auto paddingValue = packOp.getPaddingValue())
821 Value packOpDest = packOp.getDest();
825 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
835 LogicalResult matchAndRewrite(tensor::PackOp packOp,
837 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
840 rewriter.
replaceOp(packOp, fillOp.value().result());
849 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
851 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
854 copyOp.getOutputs());
857 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
859 fillOp.getOutputs());
870 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
872 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
874 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
875 transposeOp.getDpsInitOperand(0)->get());
887 .
add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
888 FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
889 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
890 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
903 for (
ValueRange container : {inputs, outputs}) {
904 for (
Value v : container) {
905 Type t = v.getType();
906 blockArgTypes.push_back(
908 blockArgLocs.push_back(v.getLoc());
914 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
918 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
920 for (
Value v : getRegionInputArgs())
922 for (
Value v : getRegionOutputArgs())
926 void GenericOp::build(
929 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
932 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
933 iteratorTypes, doc, libraryCall);
937 inputs, outputs, bodyBuild);
940 void GenericOp::build(
944 StringRef libraryCall,
947 build(builder, result, resultTensorTypes, inputs, outputs,
952 return IteratorTypeAttr::get(builder.getContext(), iter);
955 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
956 bodyBuild, attributes);
959 void GenericOp::build(
963 StringRef libraryCall,
966 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
967 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
970 void GenericOp::build(
976 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
978 "", bodyBuild, attributes);
981 void GenericOp::build(
987 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
990 "", bodyBuild, attributes);
997 auto genericAttrNames = linalgTraitAttrNames();
1000 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1002 for (
auto attr : (*this)->getAttrs()) {
1003 if (attr.getName() == getIteratorTypesAttrName()) {
1004 auto iteratorTypes =
1005 llvm::cast<ArrayAttr>(attr.getValue())
1006 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1012 llvm::to_vector(llvm::map_range(
1013 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
1017 genericAttrs.emplace_back(
1018 getIteratorTypesAttrName(),
1020 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1021 genericAttrs.push_back(attr);
1024 if (!genericAttrs.empty()) {
1026 p << genericDictAttr;
1032 genericAttrNames.push_back(
"operandSegmentSizes");
1033 genericAttrNamesSet.insert(genericAttrNames.back());
1035 bool hasExtraAttrs =
false;
1037 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1040 if (hasExtraAttrs) {
1047 if (!getRegion().empty()) {
1057 DictionaryAttr dictAttr;
1066 dictAttr.getValue().end());
1072 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1074 if (!iteratorTypes) {
1075 return parser.
emitError(attributeLocation)
1076 <<
"expected " << getIteratorTypesAttrName(result.
name)
1077 <<
" array attribute";
1082 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1083 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1084 if (!maybeIteratorType.has_value())
1086 <<
"unexpected iterator_type (" << s <<
")";
1088 iteratorTypeAttrs.push_back(
1105 std::unique_ptr<Region> region = std::make_unique<Region>();
1117 result.
addTypes(outputTensorsTypes);
1125 LinalgOp linalgOp) {
1126 for (
auto [index, operand] :
llvm::enumerate(linalgOp.getDpsInputs())) {
1127 if (!llvm::isa<MemRefType>(operand.
getType()))
1129 effects.emplace_back(
1134 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1135 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1137 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1148 void GenericOp::getEffects(
1163 template <
typename OpTy>
1167 LogicalResult matchAndRewrite(OpTy linalgOp,
1170 if (llvm::any_of(linalgOp.getIndexingMapsArray(),
1171 [](
AffineMap map) { return !map.isIdentity(); }))
1176 Block &body = linalgOp->getRegion(0).
front();
1177 if (!llvm::hasSingleElement(body))
1179 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1184 if (linalgOp.hasPureBufferSemantics()) {
1185 if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1186 linalgOp.getDpsInputOperand(0)->get() ==
1187 linalgOp.getDpsInitOperand(0)->get()) {
1195 if (!linalgOp.hasPureTensorSemantics())
1202 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1203 if (!yieldArg || yieldArg.getOwner() != &body)
1205 unsigned argumentNumber = yieldArg.getArgNumber();
1206 Value returnedArg = linalgOp->getOperand(argumentNumber);
1207 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1211 if (returnType != resultType) {
1216 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1217 linalgOp.getLoc(), resultType, returnedArg);
1219 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1222 returnedArg = rewriter.
create<tensor::CastOp>(
1223 linalgOp.getLoc(), resultType, returnedArg);
1226 returnedArgs.push_back(returnedArg);
1229 if (returnedArgs.size() != linalgOp->getNumResults())
1231 rewriter.
replaceOp(linalgOp, returnedArgs);
1240 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1262 for (
Type outputType : outputTypes) {
1263 if (llvm::isa<RankedTensorType>(outputType))
1268 if (parseAttrsFn && failed(parseAttrsFn(parser, result.
attributes)))
1277 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1279 for (
Value v : getRegionInputArgs())
1284 if (!getResults().empty())
1285 setNameFn(getResults().front(),
"mapped");
1292 build(builder, result,
TypeRange{}, inputs, init);
1297 if (llvm::isa<RankedTensorType>(initType))
1302 inputs, {}, bodyBuild);
1309 bool initFirst =
false) {
1315 for (
auto &operand : operands) {
1317 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1324 payloadOpOperands.push_back(block.
getArguments().back());
1325 for (
const auto &arg : block.
getArguments().drop_back())
1326 payloadOpOperands.push_back(arg);
1335 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1342 std::optional<OperationName> payloadOpName;
1346 if (failed(operationName))
1350 payloadOpName = operationName.value();
1358 if (payloadOpName.has_value()) {
1396 for (
const auto &[operand, bbArg] :
1398 if (bbArg != operand)
1402 for (
const auto &[operand, bbArg] :
1404 if (bbArg != operand)
1413 std::string attrToElide;
1415 for (
const auto &attr : payloadOp->
getAttrs()) {
1417 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1418 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1419 attrToElide = attr.getName().str();
1420 elidedAttrs.push_back(attrToElide);
1429 Block *mapper = getBody();
1444 [&](
auto arg) { p.printRegionArgument(arg); });
1453 auto *bodyBlock = getBody();
1454 auto blockArgs = bodyBlock->getArguments();
1457 if (getInputs().size() != blockArgs.size())
1458 return emitOpError() <<
"expects number of operands to match the arity of "
1460 << getInputs().size() <<
" and " << blockArgs.size();
1463 for (
const auto &[bbArgType, inputArg] :
1464 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1465 auto inputElemType =
1466 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1467 if (bbArgType != inputElemType) {
1468 return emitOpError() <<
"expected element type of input " << inputElemType
1469 <<
" to match bbArg type " << bbArgType;
1474 auto outputShape = getInit().getType().getShape();
1476 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1477 if (inputElemShape != outputShape) {
1478 return emitOpError() <<
"expected shape of input (" << inputElemShape
1479 <<
") to match shape of output (" << outputShape
1488 int64_t rank = getInit().getType().getRank();
1492 ArrayAttr MapOp::getIndexingMaps() {
1494 int64_t rank = getInit().getType().getRank();
1495 int64_t numIndexingMaps = getOperands().size();
1500 void MapOp::getEffects(
1510 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1512 for (
Value v : getRegionInputArgs())
1514 for (
Value v : getRegionOutputArgs())
1515 setNameFn(v,
"init");
1518 void ReduceOp::getAsmResultNames(
1520 if (!getResults().empty())
1521 setNameFn(getResults().front(),
"reduced");
1524 void ReduceOp::build(
1529 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1533 for (
Value init : inits) {
1535 if (llvm::isa<RankedTensorType>(initType))
1541 inputs, inits, bodyBuild);
1546 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1548 utils::IteratorType::parallel);
1549 for (int64_t reductionDim : getDimensions())
1550 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1551 return iteratorTypes;
1554 ArrayAttr ReduceOp::getIndexingMaps() {
1556 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1563 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1564 affineMaps.push_back(resultMap);
1568 void ReduceOp::getEffects(
1576 StringRef attributeName) {
1585 std::optional<OperationName> payloadOpName;
1589 if (failed(operationName))
1593 payloadOpName = operationName.value();
1604 if (payloadOpName.has_value()) {
1624 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1628 Block *mapper = getBody();
1643 [&](
auto arg) { p.printRegionArgument(arg); });
1654 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1655 if (llvm::cast<ShapedType>(getInputs()[i].
getType()).getShape() !=
1656 llvm::cast<ShapedType>(getInputs()[0].
getType()).getShape()) {
1657 return emitOpError() <<
"expects all inputs to have the same shapes. "
1658 "Shape at input-index "
1660 <<
" is not equal to the shape at input-index 0.";
1663 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1664 if (llvm::cast<ShapedType>(getInits()[i].
getType()).getShape() !=
1665 llvm::cast<ShapedType>(getInits()[0].
getType()).getShape()) {
1666 return emitOpError() <<
"expects all outputs to have the same shapes. "
1667 "Shape at output-index "
1669 <<
" is not equal to the shape at output-index 0.";
1672 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1673 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1676 for (int64_t dimension : dimensionsRef) {
1677 if (dimension < 0 || dimension >= inputType.getRank()) {
1678 return emitOpError()
1679 <<
"dimensions for reduction should be in the range [0, "
1680 << inputType.getRank() - 1 <<
"].";
1682 dimensionsToReduce.insert(dimension);
1685 auto inputDims = inputType.getShape();
1686 auto initDims = initType.getShape();
1691 if (!dimensionsToReduce.count(en.index()))
1692 reducedInputDims.push_back(en.value());
1695 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1696 return emitOpError() <<
"number of dimensions after reduction "
1697 << reducedInputDims.size()
1698 <<
" doesn't match the init rank "
1699 << initType.getRank();
1702 if (reducedInputDims != initDims)
1703 return emitOpError() <<
"init dimensions [" << initDims
1704 <<
"] doesn't match input dimensions after reduction ["
1705 << reducedInputDims <<
"]";
1707 Block *block = getBody();
1709 return emitOpError()
1710 <<
"mismatching number of operands and block arguments";
1713 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1714 Type inputElementType =
1715 llvm::cast<ShapedType>(input.getType()).getElementType();
1716 if (inputElementType != bbArg.getType())
1717 return emitOpError()
1718 <<
"input element type " << inputElementType
1719 <<
" does not match corresponding block argument type "
1724 for (
auto [output, bbArg] : llvm::zip(
1725 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1726 auto outputElementType =
1727 llvm::cast<ShapedType>(output.getType()).getElementType();
1728 if (outputElementType != bbArg.getType())
1729 return emitOpError()
1730 <<
"output element type " << outputElementType
1731 <<
" does not match corresponding block argument type "
1747 b.
create<linalg::YieldOp>(loc, args[0]);
1762 if (llvm::isa<RankedTensorType>(initType))
1791 void TransposeOp::getAsmResultNames(
1793 if (!getResults().empty())
1794 setNameFn(getResults().front(),
"transposed");
1807 return emitOpError(
"permutation is not valid");
1809 auto inputType = getInput().getType();
1810 auto initType = getInit().getType();
1812 int64_t rank = inputType.getRank();
1814 if (rank != initType.getRank())
1815 return emitOpError() <<
"input rank " << rank
1816 <<
" does not match init rank " << initType.getRank();
1818 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1819 return emitOpError() <<
"size of permutation " << permutationRef.size()
1820 <<
" does not match the argument rank " << rank;
1822 auto inputDims = inputType.getShape();
1823 auto initDims = initType.getShape();
1825 for (int64_t i = 0; i < rank; ++i) {
1826 int64_t inputDim = inputDims[permutationRef[i]];
1827 int64_t initDim = initDims[i];
1829 if (inputDim != initDim) {
1830 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1831 <<
" doesn't match dim(input, permutation[" << i
1832 <<
"]) = " << inputDim;
1840 int64_t rank = getInit().getType().getRank();
1844 ArrayAttr TransposeOp::getIndexingMaps() {
1846 int64_t rank = getInit().getType().getRank();
1849 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
1853 void TransposeOp::getEffects(
1859 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1862 if (getPermutation().size() == 0) {
1863 result.push_back(getInput());
1868 result.push_back(getInput());
1881 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1882 if (!defTransposeOp)
1887 foldedPerms.reserve(perms.size());
1888 for (int64_t perm : perms)
1889 foldedPerms.push_back(defPerms[perm]);
1892 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1906 Value input = transposeOp.getInput();
1907 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
1918 unsigned dimensionSize = dimensions.size();
1919 for (
unsigned i = 0; i < dimensionSize; ++i)
1920 resultDimensions.push_back(invertPerm[dimensions[i]]);
1923 Value broadcastInput = broadcastOp.getInput();
1924 Location loc = transposeOp.getLoc();
1927 auto broadcastInputTy =
1928 mlir::cast<RankedTensorType>(broadcastInput.
getType());
1929 unsigned inputRank = broadcastInputTy.getRank();
1930 for (
unsigned i = 0; i < inputRank; ++i) {
1931 if (broadcastInputTy.isDynamicDim(i)) {
1932 dims.push_back(rewriter.
create<tensor::DimOp>(loc, broadcastInput, i)
1936 broadcastInputTy.getDimSize(i)));
1941 Value transposeInit = rewriter.
create<tensor::EmptyOp>(
1942 transposeOp.getLoc(), transposeResultShapes,
1943 broadcastInputTy.getElementType());
1946 Value transposeResult =
1948 .
create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
1952 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
1977 if (llvm::isa<RankedTensorType>(initType))
2006 void BroadcastOp::getAsmResultNames(
2008 if (!getResults().empty())
2009 setNameFn(getResults().front(),
"broadcasted");
2021 auto inputType = getInput().getType();
2022 auto initType = getInit().getType();
2024 int64_t inputRank = inputType.getRank();
2025 int64_t initRank = initType.getRank();
2027 auto inputShape = inputType.getShape();
2028 auto initShape = initType.getShape();
2030 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2031 return emitOpError() <<
"input rank plus added dimensions does not "
2032 "match init rank. input rank: "
2034 <<
", dimensions size: " << dimensionsRef.size()
2035 <<
", init rank: " << initRank;
2038 if (dim < 0 || dim >= initRank)
2039 return emitOpError() <<
"dimension " << idx
2040 <<
" is out of range. expected range: [0, "
2041 << initRank - 1 <<
"], got: " << dim;
2046 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2047 if (!llvm::is_contained(dimensionsRef, dim))
2048 dimMap.push_back(dim);
2051 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
2054 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2055 return emitOpError() <<
"input dim " << inputDimIdx
2056 <<
" should match init dim " << initDimIdx
2057 <<
". input: " << inputShape[inputDimIdx]
2058 <<
", init: " << initShape[initDimIdx];
2065 int64_t rank = getInit().getType().getRank();
2069 ArrayAttr BroadcastOp::getIndexingMaps() {
2071 int64_t rank = getInit().getType().getRank();
2077 void BroadcastOp::getEffects(
2085 results.
add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2093 if (getNumOperands() > 0)
2094 p <<
' ' << getOperands();
2096 if (getNumOperands() > 0)
2097 p <<
" : " << getOperandTypes();
2112 static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2114 return op.
emitOpError(
"expected number of yield values (")
2116 <<
") to match the number of inits / outs operands of the enclosing "
2117 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2121 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2123 if (isa<MemRefType, RankedTensorType>(elementType))
2125 if (opOperand.get().getType() != elementType)
2127 << (opOperand.getOperandNumber() + 1) <<
" ("
2128 << opOperand.get().getType() <<
") doesn't match "
2129 <<
"the element type of the enclosing linalg.generic op ("
2130 << elementType <<
")";
2136 auto *parentOp = (*this)->getParentOp();
2137 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2138 return emitOpError(
"expected single non-empty parent region");
2140 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2143 return emitOpError(
"expected parent op with LinalgOp interface");
2151 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2153 return emitOpError(
"expected parent op with LinalgOp interface");
2154 if (linalgOp.getNumLoops() <= getDim())
2155 return emitOpError(
"expected dim (")
2156 << getDim() <<
") to be lower than the number of loops ("
2157 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2163 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2165 #define GET_OP_CLASSES
2166 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2168 #define GET_OP_CLASSES
2169 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2186 for (
unsigned i = 0; i < num; ++i)
2193 auto rangeA = llvm::make_range(a.begin(), a.end());
2194 auto rangeB = llvm::make_range(b.begin(), b.end());
2195 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2196 return llvm::to_vector<4>(concatRanges);
2200 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2202 for (
auto size : memref.getShape())
2209 if (
auto as = memref.getMemorySpace()) {
2210 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2211 ss <<
"as" << attr.getInt();
2217 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2220 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2233 assert(isa<LinalgOp>(op));
2235 std::string fun =
"";
2237 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2238 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2239 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2240 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2244 std::replace(name.begin(), name.end(),
'.',
'_');
2245 llvm::raw_string_ostream ss(name);
2249 return std::string();
2252 std::string res = ss.str();
2265 LogicalResult matchAndRewrite(LinalgOp op,
2271 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2274 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2285 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2288 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2293 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2300 if (castOp->getBlock() != linalgOp->getBlock())
2307 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2310 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2316 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2318 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2321 linalgOp.getDpsInits().end());
2322 outputOperands[resultNumber] = newOperand;
2323 newOperands.append(outputOperands.begin(), outputOperands.end());
2326 linalgOp->result_type_end());
2327 resultTypes[resultNumber] = resultType;
2328 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2335 results[resultNumber] = castBack;
2347 if (linalgOp.isScalar(&opOperand))
2349 Value src = opOperand.get();
2350 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2351 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2359 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2360 Value castSource = castOp.getSource();
2361 auto castSourceType =
2362 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2363 if (castSourceType && castSourceType.hasStaticShape())
2364 sourceShape = castSourceType.getShape();
2370 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2371 if (sourceType.isDynamicDim(i))
2373 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2374 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2384 static void createNewOperandWithStaticSizes(
2388 bool &changeNeeded) {
2390 newOperands.push_back(src);
2391 if (linalgOp.isScalar(opOperand))
2393 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2394 Type resultType = sourceType;
2395 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2396 resultTypes.push_back(resultType);
2400 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2404 bool newOperandNeeded =
false;
2405 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2406 int64_t dimShape = sourceShape[i];
2408 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2409 newShape.push_back(dimShape);
2415 newShape.push_back(affineExprToSize[dimExpr]);
2416 newOperandNeeded =
true;
2419 if (newOperandNeeded) {
2420 changeNeeded =
true;
2423 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2425 newOperands[index] = newOperand;
2427 if (linalgOp.isDpsInit(opOperand))
2428 resultTypes.push_back(resultType);
2437 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2439 if (!linalgOp.hasPureTensorSemantics())
2443 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2444 return !map.isProjectedPermutation();
2454 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2461 bool changeNeeded =
false;
2462 newOperands.reserve(linalgOp->getNumOperands());
2463 resultTypes.reserve(linalgOp.getNumDpsInits());
2466 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2467 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2468 affineExprToSize, linalgOp, newOperands,
2469 resultTypes, changeNeeded);
2478 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2481 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2482 Value newResult = std::get<1>(it);
2483 Value oldResult = std::get<0>(it);
2486 replacements.push_back(
2487 (newType != oldType)
2488 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2491 rewriter.
replaceOp(linalgOp, replacements);
2506 ShapedType inputType = getInputOperandType();
2507 ShapedType outputType = getOutputOperandType();
2512 return emitOpError(
"incompatible output shape");
2514 int64_t inputRank = getInputOperandRank();
2515 int64_t dimension = getDimension();
2516 if ((dimension < 0) || (dimension >= inputRank))
2517 return emitOpError(
"incorrect dimension specified");
2523 int64_t operandRank = getInputOperandRank();
2526 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2527 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2528 Value source = getInput();
2529 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2530 loopBounds[dim].offset = zero;
2531 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2532 loopBounds[dim].stride = one;
2539 utils::IteratorType::parallel);
2540 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2541 return iteratorTypes;
2544 FailureOr<TilingResult>
2545 SoftmaxOp::getTiledImplementation(
OpBuilder &builder,
2548 int64_t rank = getInputOperandRank();
2552 tiledOperands.emplace_back(
2553 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2554 tiledOperands.emplace_back(
2555 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2558 if (hasPureTensorSemantics())
2559 resultTypes.push_back(tiledOperands[1].
getType());
2561 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2566 LogicalResult SoftmaxOp::getResultTilePosition(
2570 if (resultNumber == 0) {
2571 resultOffsets.assign(offsets.begin(), offsets.end());
2572 resultSizes.assign(sizes.begin(), sizes.end());
2587 Location loc = getOperation()->getLoc();
2589 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2590 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2591 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2592 if (!outputShapedType.isDynamicDim(dim)) {
2594 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2601 reifiedReturnShapes.emplace_back(std::move(shapes));
2605 void SoftmaxOp::getEffects(
2609 if (!llvm::isa<MemRefType>(operand.
getType()))
2612 &getOperation()->getOpOperand(index), 0,
2617 for (
OpOperand &operand : getDpsInitsMutable()) {
2618 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2651 int64_t dim,
bool allParallel =
false) {
2653 utils::IteratorType::parallel);
2655 iteratorTypes[dim] = utils::IteratorType::reduction;
2659 for (
int i = 0; i < inputRank; i++) {
2666 return std::make_tuple(iteratorTypes, indexingMaps);
2671 template <
typename T>
2674 auto inputType = cast<ShapedType>(input.
getType());
2676 int64_t inputRank = inputShape.size();
2677 auto [iteratorTypes, indexingMaps] =
2679 assert(indexingMaps.size() == 2 &&
2680 "We should have two maps: 1 for the input, 1 for the output");
2681 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2683 auto genericOp = builder.
create<linalg::GenericOp>(
2684 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2686 Value result = b.create<T>(loc, args[0], args[1]);
2687 b.create<linalg::YieldOp>(loc, result);
2697 auto inputType = cast<ShapedType>(input.
getType());
2699 int64_t inputRank = inputShape.size();
2701 builder, inputRank, dim,
true);
2702 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2703 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2705 indexingMaps.push_back(indexingMaps[0]);
2706 auto genericOp = builder.
create<linalg::GenericOp>(
2709 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2710 Value result = b.create<math::ExpOp>(loc, diff);
2711 b.create<linalg::YieldOp>(loc, result);
2722 Value denominator,
Value output, int64_t dim) {
2723 auto inputType = cast<ShapedType>(numerator.
getType());
2725 int64_t inputRank = inputShape.size();
2727 builder, inputRank, dim,
true);
2728 assert(indexingMaps.size() == 2 &&
2729 "We should have one map for each input (2)");
2730 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2732 indexingMaps.push_back(indexingMaps[0]);
2733 auto genericOp = builder.
create<linalg::GenericOp>(
2735 indexingMaps, iteratorTypes,
2737 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2738 b.create<linalg::YieldOp>(loc, result);
2762 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(
OpBuilder &b) {
2766 Value input = getInput();
2767 ShapedType inputType = getInputOperandType();
2768 Type elementType = inputType.getElementType();
2769 int64_t reductionDim = getDimension();
2771 Value output = getOutput();
2772 dims.erase(dims.begin() + reductionDim);
2774 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2776 elementType, b, loc,
2778 Value neutralForMaxFInit =
2779 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2782 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2791 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2793 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2797 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2806 auto filterType = cast<ShapedType>(getFilter().
getType());
2808 int64_t filterH = filterShape[1];
2809 int64_t filterW = filterShape[2];
2813 if (filterH != r && filterH != 1)
2814 return emitOpError(
"expect filter height either equals to r or 1");
2815 if (filterW != r && filterW != 1)
2816 return emitOpError(
"expect filter width either equals to r or 1");
2817 if (filterH == 1 && filterW == 1)
2818 return emitOpError(
"expect either filter height or width equals to r");
2821 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2822 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2823 expectedOutputShape.push_back(filterShape[3]);
2824 expectedOutputShape.push_back(filterShape[0]);
2826 auto outputType = cast<ShapedType>(getOutput().
getType());
2829 return emitOpError(
"the output shape is not expected");
2839 auto inputType = cast<ShapedType>(getInput().
getType());
2841 int64_t inputH = inputShape[1];
2842 int64_t inputW = inputShape[2];
2845 int64_t tileSize = m + r - 1;
2846 bool leftTransform = inputH != 1;
2847 bool rightTransform = inputW != 1;
2850 if (ShapedType::isDynamic(inputH)) {
2851 expectedOutputShape[0] = tileSize;
2852 expectedOutputShape[2] = ShapedType::kDynamic;
2854 expectedOutputShape[0] = leftTransform ? tileSize : 1;
2855 expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
2857 if (ShapedType::isDynamic(inputW)) {
2858 expectedOutputShape[1] = tileSize;
2859 expectedOutputShape[3] = ShapedType::kDynamic;
2861 expectedOutputShape[1] = rightTransform ? tileSize : 1;
2862 expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
2864 expectedOutputShape[4] = inputShape[0];
2865 expectedOutputShape[5] = inputShape[3];
2867 auto outputType = cast<ShapedType>(getOutput().
getType());
2870 return emitOpError(
"the output shape is not expected");
2880 auto valueType = cast<ShapedType>(getValue().
getType());
2882 int64_t valueH = valueShape[0];
2883 int64_t valueW = valueShape[1];
2884 int64_t valueTileH = valueShape[2];
2885 int64_t valueTileW = valueShape[3];
2888 bool leftTransform = valueH != 1;
2889 bool rightTransform = valueW != 1;
2892 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
2893 expectedOutputShape[1] = ShapedType::kDynamic;
2895 if (valueH != (leftTransform ? m + r - 1 : 1))
2896 return emitOpError(
"expect input height equals to input tile size");
2897 expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
2899 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
2900 expectedOutputShape[2] = ShapedType::kDynamic;
2902 if (valueW != (rightTransform ? m + r - 1 : 1))
2903 return emitOpError(
"expect input width equals to input tile size");
2904 expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
2906 expectedOutputShape[0] = valueShape[4];
2907 expectedOutputShape[3] = valueShape[5];
2909 auto outputType = cast<ShapedType>(getOutput().
getType());
2912 return emitOpError(
"the output shape is not expected");
2921 void LinalgDialect::getCanonicalizationPatterns(
2923 results.
add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2930 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.