39 #include "llvm/ADT/DenseMap.h"
40 #include "llvm/ADT/SmallSet.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include "llvm/Support/MathExtras.h"
45 #include "llvm/Support/raw_ostream.h"
54 auto type = cast<ShapedType>(v.
getType());
55 if (!type.isDynamicDim(dim))
60 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
61 return builder.create<tensor::DimOp>(loc, v, dim);
63 .Case<MemRefType>([&](MemRefType t) ->
Value {
64 return builder.create<memref::DimOp>(loc, v, dim);
75 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
76 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
79 .Case<MemRefType>([&](MemRefType type) ->
Value {
80 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
83 .Default([&](
Type t) {
return nullptr; });
92 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
94 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
96 llvm_unreachable(
"Expected MemRefType or TensorType");
101 auto shapedType = llvm::cast<ShapedType>(source.
getType());
102 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
124 assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
128 for (
auto containers : {inputTypes, outputTypes}) {
129 for (
auto t : containers) {
141 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
145 regionBuilder(b, *body, attrs);
157 std::optional<TypeRange> resultTensorTypes,
164 if (!resultTensorTypes)
165 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
166 llvm::IsaPred<RankedTensorType>);
168 state.addOperands(inputs);
169 state.addOperands(outputs);
170 state.addTypes(derivedResultTypes);
171 state.addAttributes(attributes);
173 "operandSegmentSizes",
175 static_cast<int32_t>(outputs.size())}));
178 Region ®ion = *state.addRegion();
180 state.attributes.getAttrs(), regionBuilder);
189 bool addOperandSegmentSizes =
true) {
190 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
219 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
221 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
225 if (addOperandSegmentSizes) {
234 attrs.
append(
"operandSegmentSizes",
236 {static_cast<int32_t>(inputsOperands.size()),
237 static_cast<int32_t>(outputsOperands.size())}));
242 {static_cast<int32_t>(inputsOperands.size()),
243 static_cast<int32_t>(outputsOperands.size())}));
247 std::optional<RegisteredOperationName> info =
250 if (failed(info->verifyInherentAttrs(result.
attributes, [&]() {
251 return parser.emitError(attrsLoc)
252 <<
"'" << result.name.getStringRef() <<
"' op ";
263 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
264 if (!outputs.empty())
265 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
276 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
279 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
280 "region expects {0} args, got {1}",
281 numRegionArgs, inputTypes.size() + outputTypes.size()));
300 unsigned numRegionArgs,
312 result.
addTypes(outputTensorsTypes);
314 std::unique_ptr<Region> region = std::make_unique<Region>();
326 if (resultTypes.empty())
335 {
"operandSegmentSizes",
338 "linalg.memoized_indexing_maps"});
375 class RegionBuilderHelper {
378 : builder(builder), block(block) {}
381 Value buildUnaryFn(UnaryFn unaryFn,
Value arg) {
382 if (!isFloatingPoint(arg))
383 llvm_unreachable(
"unsupported non numeric type");
385 builder.setInsertionPointToEnd(&block);
388 return builder.create<math::ExpOp>(arg.
getLoc(), arg);
390 return builder.create<math::LogOp>(arg.
getLoc(), arg);
392 return builder.create<math::AbsFOp>(arg.
getLoc(), arg);
394 return builder.create<math::CeilOp>(arg.
getLoc(), arg);
396 return builder.create<math::FloorOp>(arg.
getLoc(), arg);
398 return builder.create<arith::NegFOp>(arg.
getLoc(), arg);
399 case UnaryFn::reciprocal: {
401 auto one = builder.create<arith::ConstantOp>(arg.
getLoc(),
402 ::cast<TypedAttr>(oneAttr));
403 return builder.create<arith::DivFOp>(arg.
getLoc(), one, arg);
406 return builder.create<math::RoundOp>(arg.
getLoc(), arg);
408 return builder.create<math::SqrtOp>(arg.
getLoc(), arg);
410 return builder.create<math::RsqrtOp>(arg.
getLoc(), arg);
411 case UnaryFn::square:
412 return builder.create<arith::MulFOp>(arg.
getLoc(), arg, arg);
414 return builder.create<math::TanhOp>(arg.
getLoc(), arg);
416 return builder.create<math::ErfOp>(arg.
getLoc(), arg);
418 llvm_unreachable(
"unsupported unary function");
423 bool allComplex = isComplex(arg0) && isComplex(arg1);
424 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
425 bool allInteger = isInteger(arg0) && isInteger(arg1);
428 if (!allComplex && !allFloatingPoint && !allInteger)
429 llvm_unreachable(
"unsupported non numeric type");
431 builder.setInsertionPointToEnd(&block);
435 return builder.create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
436 if (allFloatingPoint)
437 return builder.create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
439 return builder.create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
440 return builder.create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
443 return builder.create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
444 if (allFloatingPoint)
445 return builder.create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
447 llvm_unreachable(
"unsupported operation: sub with bools");
448 return builder.create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
451 return builder.create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
452 if (allFloatingPoint)
453 return builder.create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
455 return builder.create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
456 return builder.create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
459 return builder.create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
460 if (allFloatingPoint)
461 return builder.create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
463 llvm_unreachable(
"unsupported operation: div with bools");
464 return builder.create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
465 case BinaryFn::div_unsigned:
466 if (!allInteger || allBool)
467 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
468 return builder.create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
469 case BinaryFn::max_signed:
471 if (allFloatingPoint)
472 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
473 return builder.create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
474 case BinaryFn::min_signed:
476 if (allFloatingPoint)
477 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
478 return builder.create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
479 case BinaryFn::max_unsigned:
481 if (allFloatingPoint)
482 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
483 return builder.create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
484 case BinaryFn::min_unsigned:
486 if (allFloatingPoint)
487 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
488 return builder.create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
490 assert(allFloatingPoint);
491 return builder.create<math::PowFOp>(arg0.
getLoc(), arg0, arg1);
493 llvm_unreachable(
"unsupported binary function");
501 bool tailFloatingPoint =
502 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
503 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
505 builder.setInsertionPointToEnd(&block);
507 case TernaryFn::select:
508 if (!headBool && !(tailFloatingPoint || tailInteger))
509 llvm_unreachable(
"unsupported non numeric type");
510 return builder.create<arith::SelectOp>(arg0.
getLoc(), arg0, arg1, arg2);
512 llvm_unreachable(
"unsupported ternary function");
518 case TypeFn::cast_signed:
519 return cast(toType, operand,
false);
520 case TypeFn::cast_unsigned:
521 return cast(toType, operand,
true);
523 llvm_unreachable(
"unsupported type conversion function");
528 builder.setInsertionPointToEnd(&block);
529 Location loc = builder.getUnknownLoc();
530 builder.create<YieldOp>(loc, values);
533 Value constant(
const std::string &value) {
535 builder.setInsertionPointToEnd(&block);
536 Location loc = builder.getUnknownLoc();
538 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
541 Value index(int64_t dim) {
543 builder.setInsertionPointToEnd(&block);
544 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
547 Type getIntegerType(
unsigned width) {
561 builder.setInsertionPointToEnd(&block);
562 auto loc = operand.
getLoc();
566 bool isComplex(
Value value) {
567 return llvm::isa<ComplexType>(value.
getType());
569 bool isFloatingPoint(
Value value) {
570 return llvm::isa<FloatType>(value.
getType());
572 bool isInteger(
Value value) {
573 return llvm::isa<IntegerType>(value.
getType());
590 LogicalResult matchAndRewrite(CopyOp copyOp,
592 if (copyOp.getInputs() != copyOp.getOutputs())
594 if (copyOp.hasPureBufferSemantics())
597 rewriter.
replaceOp(copyOp, copyOp.getInputs());
607 results.
add<EraseSelfCopy>(context);
620 template <
typename TensorReshapeOp>
623 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
625 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
630 TensorReshapeOp newInit;
631 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
633 newInit = rewriter.
create<TensorReshapeOp>(
634 loc, reshapeOp.getResultType(), oldFill.output(),
635 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
636 reshapeOp.getStaticOutputShape());
638 newInit = rewriter.
create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
640 reshapeOp.getReassociation());
653 LogicalResult matchAndRewrite(tensor::PadOp padOp,
655 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
661 Value padValue = padOp.getConstantPaddingValue();
662 if (!padValue || fillOp.value() != padValue)
668 padOp,
"failed to reify tensor.pad op result shape");
670 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
671 padOp.getLoc(), reifiedShape.front(),
672 padOp.getResultType().getElementType());
678 if (replacement.getType() != padOp.getResultType()) {
679 replacement = rewriter.
create<tensor::CastOp>(
680 fillOp.getLoc(), padOp.getResultType(), replacement);
690 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
693 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
695 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
699 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
704 Value firstDest = insertOp.getDest();
705 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
706 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
711 bool disjoint =
false;
712 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
715 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
716 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
717 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
721 int64_t prevStart = prevOp.getStaticOffset(i);
722 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
723 prevOp.getStaticStride(i);
724 int64_t nextStart = insertOp.getStaticOffset(i);
725 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
726 insertOp.getStaticStride(i);
727 if (prevEnd < nextStart || nextEnd < prevStart) {
735 firstDest = prevOp.getDest();
746 Value padValue = srcPadOp.getConstantPaddingValue();
747 if (!padValue || dstFillOp.value() != padValue)
763 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
765 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
768 RankedTensorType srcPadType = srcPadOp.getSourceType();
770 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
771 if (srcPadType.isDynamicDim(i)) {
773 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
776 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
781 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
782 newSizes, insertOp.getMixedStrides());
788 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
792 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
796 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
801 Value extractedScalar = fillOp.getInputs()[0];
804 rewriter.
replaceOp(extractOp, extractedScalar);
812 static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
813 tensor::PackOp packOp) {
814 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
818 if (
auto paddingValue = packOp.getPaddingValue())
822 Value packOpDest = packOp.getDest();
826 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
836 LogicalResult matchAndRewrite(tensor::PackOp packOp,
838 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
841 rewriter.
replaceOp(packOp, fillOp.value().result());
850 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
852 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
855 copyOp.getOutputs());
858 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
860 fillOp.getOutputs());
871 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
873 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
875 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
876 transposeOp.getDpsInitOperand(0)->get());
888 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
890 auto concatOperands = concatOp.getInputs();
891 if (concatOperands.empty()) {
895 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
904 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
906 auto isDefinedByCompatibleFillOp = [&](
Value v) ->
bool {
907 auto fillOp = v.getDefiningOp<linalg::FillOp>();
914 if (fillVal != firstFillVal)
917 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
920 if (!llvm::all_of(concatOperands.drop_front(),
921 isDefinedByCompatibleFillOp)) {
923 concatOp,
"not all operands are defined by a compatible fill op");
926 Value outsConcat = rewriter.
create<tensor::ConcatOp>(
927 concatOp.getLoc(), concatOp.getDim(), allOuts);
929 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
938 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
939 FoldFillWithPack, FoldFillWithPad,
940 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
941 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
942 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
955 for (
ValueRange container : {inputs, outputs}) {
956 for (
Value v : container) {
957 Type t = v.getType();
958 blockArgTypes.push_back(
960 blockArgLocs.push_back(v.getLoc());
966 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
970 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
972 for (
Value v : getRegionInputArgs())
974 for (
Value v : getRegionOutputArgs())
978 void GenericOp::build(
981 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
984 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
985 iteratorTypes, doc, libraryCall);
989 inputs, outputs, bodyBuild);
992 void GenericOp::build(
996 StringRef libraryCall,
999 build(builder, result, resultTensorTypes, inputs, outputs,
1004 return IteratorTypeAttr::get(builder.getContext(), iter);
1007 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1008 bodyBuild, attributes);
1011 void GenericOp::build(
1015 StringRef libraryCall,
1018 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
1019 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1022 void GenericOp::build(
1028 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1030 "", bodyBuild, attributes);
1033 void GenericOp::build(
1039 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1042 "", bodyBuild, attributes);
1049 auto genericAttrNames = linalgTraitAttrNames();
1052 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1054 for (
auto attr : (*this)->getAttrs()) {
1055 if (attr.getName() == getIteratorTypesAttrName()) {
1056 auto iteratorTypes =
1057 llvm::cast<ArrayAttr>(attr.getValue())
1058 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1064 llvm::to_vector(llvm::map_range(
1065 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
1069 genericAttrs.emplace_back(
1070 getIteratorTypesAttrName(),
1072 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1073 genericAttrs.push_back(attr);
1076 if (!genericAttrs.empty()) {
1078 p << genericDictAttr;
1084 genericAttrNames.push_back(
"operandSegmentSizes");
1085 genericAttrNamesSet.insert(genericAttrNames.back());
1087 bool hasExtraAttrs =
false;
1089 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1092 if (hasExtraAttrs) {
1099 if (!getRegion().empty()) {
1109 DictionaryAttr dictAttr;
1118 dictAttr.getValue().end());
1124 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1126 if (!iteratorTypes) {
1127 return parser.
emitError(attributeLocation)
1128 <<
"expected " << getIteratorTypesAttrName(result.
name)
1129 <<
" array attribute";
1134 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1135 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1136 if (!maybeIteratorType.has_value())
1138 <<
"unexpected iterator_type (" << s <<
")";
1140 iteratorTypeAttrs.push_back(
1157 std::unique_ptr<Region> region = std::make_unique<Region>();
1169 result.
addTypes(outputTensorsTypes);
1177 LinalgOp linalgOp) {
1178 for (
auto [index, operand] :
llvm::enumerate(linalgOp.getDpsInputs())) {
1179 if (!llvm::isa<MemRefType>(operand.
getType()))
1181 effects.emplace_back(
1186 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1187 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1189 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1200 void GenericOp::getEffects(
1210 if (!linalgOp.hasPureTensorSemantics())
1229 template <
typename OpTy>
1233 LogicalResult matchAndRewrite(OpTy linalgOp,
1236 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1241 Block &body = linalgOp->getRegion(0).
front();
1242 if (!llvm::hasSingleElement(body))
1244 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1249 if (linalgOp.hasPureBufferSemantics()) {
1250 if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1251 linalgOp.getDpsInputOperand(0)->get() ==
1252 linalgOp.getDpsInitOperand(0)->get()) {
1260 if (!linalgOp.hasPureTensorSemantics())
1267 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1268 if (!yieldArg || yieldArg.getOwner() != &body)
1270 unsigned argumentNumber = yieldArg.getArgNumber();
1271 Value returnedArg = linalgOp->getOperand(argumentNumber);
1272 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1276 if (returnType != resultType) {
1281 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1282 linalgOp.getLoc(), resultType, returnedArg);
1284 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1287 returnedArg = rewriter.
create<tensor::CastOp>(
1288 linalgOp.getLoc(), resultType, returnedArg);
1291 returnedArgs.push_back(returnedArg);
1294 if (returnedArgs.size() != linalgOp->getNumResults())
1296 rewriter.
replaceOp(linalgOp, returnedArgs);
1305 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1327 for (
Type outputType : outputTypes) {
1328 if (llvm::isa<RankedTensorType>(outputType))
1333 if (parseAttrsFn && failed(parseAttrsFn(parser, result.
attributes)))
1342 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1344 for (
Value v : getRegionInputArgs())
1349 if (!getResults().empty())
1350 setNameFn(getResults().front(),
"mapped");
1357 build(builder, result,
TypeRange{}, inputs, init);
1362 if (llvm::isa<RankedTensorType>(initType))
1367 inputs, {}, bodyBuild);
1374 bool initFirst =
false) {
1380 for (
auto &operand : operands) {
1382 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1389 payloadOpOperands.push_back(block.
getArguments().back());
1390 for (
const auto &arg : block.
getArguments().drop_back())
1391 payloadOpOperands.push_back(arg);
1400 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1407 std::optional<OperationName> payloadOpName;
1411 if (failed(operationName))
1415 payloadOpName = operationName.value();
1423 if (payloadOpName.has_value()) {
1461 for (
const auto &[operand, bbArg] :
1463 if (bbArg != operand)
1467 for (
const auto &[operand, bbArg] :
1469 if (bbArg != operand)
1478 std::string attrToElide;
1480 for (
const auto &attr : payloadOp->
getAttrs()) {
1482 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1483 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1484 attrToElide = attr.getName().str();
1485 elidedAttrs.push_back(attrToElide);
1494 Block *mapper = getBody();
1509 [&](
auto arg) { p.printRegionArgument(arg); });
1518 auto *bodyBlock = getBody();
1519 auto blockArgs = bodyBlock->getArguments();
1522 if (getInputs().size() != blockArgs.size())
1523 return emitOpError() <<
"expects number of operands to match the arity of "
1525 << getInputs().size() <<
" and " << blockArgs.size();
1528 for (
const auto &[bbArgType, inputArg] :
1529 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1530 auto inputElemType =
1531 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1532 if (bbArgType != inputElemType) {
1533 return emitOpError() <<
"expected element type of input " << inputElemType
1534 <<
" to match bbArg type " << bbArgType;
1539 auto outputShape = getInit().getType().getShape();
1541 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1542 if (inputElemShape != outputShape) {
1543 return emitOpError() <<
"expected shape of input (" << inputElemShape
1544 <<
") to match shape of output (" << outputShape
1553 int64_t rank = getInit().getType().getRank();
1557 ArrayAttr MapOp::getIndexingMaps() {
1559 int64_t rank = getInit().getType().getRank();
1560 int64_t numIndexingMaps = getOperands().size();
1565 void MapOp::getEffects(
1579 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1581 for (
Value v : getRegionInputArgs())
1583 for (
Value v : getRegionOutputArgs())
1584 setNameFn(v,
"init");
1587 void ReduceOp::getAsmResultNames(
1589 if (!getResults().empty())
1590 setNameFn(getResults().front(),
"reduced");
1593 void ReduceOp::build(
1598 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1602 for (
Value init : inits) {
1604 if (llvm::isa<RankedTensorType>(initType))
1610 inputs, inits, bodyBuild);
1615 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1617 utils::IteratorType::parallel);
1618 for (int64_t reductionDim : getDimensions())
1619 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1620 return iteratorTypes;
1623 ArrayAttr ReduceOp::getIndexingMaps() {
1625 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1632 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1633 affineMaps.push_back(resultMap);
1637 void ReduceOp::getEffects(
1649 StringRef attributeName) {
1658 std::optional<OperationName> payloadOpName;
1662 if (failed(operationName))
1666 payloadOpName = operationName.value();
1677 if (payloadOpName.has_value()) {
1697 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1701 Block *mapper = getBody();
1716 [&](
auto arg) { p.printRegionArgument(arg); });
1727 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1728 if (llvm::cast<ShapedType>(getInputs()[i].
getType()).getShape() !=
1729 llvm::cast<ShapedType>(getInputs()[0].
getType()).getShape()) {
1730 return emitOpError() <<
"expects all inputs to have the same shapes. "
1731 "Shape at input-index "
1733 <<
" is not equal to the shape at input-index 0.";
1736 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1737 if (llvm::cast<ShapedType>(getInits()[i].
getType()).getShape() !=
1738 llvm::cast<ShapedType>(getInits()[0].
getType()).getShape()) {
1739 return emitOpError() <<
"expects all outputs to have the same shapes. "
1740 "Shape at output-index "
1742 <<
" is not equal to the shape at output-index 0.";
1745 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1746 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1749 for (int64_t dimension : dimensionsRef) {
1750 if (dimension < 0 || dimension >= inputType.getRank()) {
1751 return emitOpError()
1752 <<
"dimensions for reduction should be in the range [0, "
1753 << inputType.getRank() - 1 <<
"].";
1755 dimensionsToReduce.insert(dimension);
1758 auto inputDims = inputType.getShape();
1759 auto initDims = initType.getShape();
1764 if (!dimensionsToReduce.count(en.index()))
1765 reducedInputDims.push_back(en.value());
1768 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1769 return emitOpError() <<
"number of dimensions after reduction "
1770 << reducedInputDims.size()
1771 <<
" doesn't match the init rank "
1772 << initType.getRank();
1775 if (reducedInputDims != initDims)
1776 return emitOpError() <<
"init dimensions [" << initDims
1777 <<
"] doesn't match input dimensions after reduction ["
1778 << reducedInputDims <<
"]";
1780 Block *block = getBody();
1782 return emitOpError()
1783 <<
"mismatching number of operands and block arguments";
1786 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1787 Type inputElementType =
1788 llvm::cast<ShapedType>(input.getType()).getElementType();
1789 if (inputElementType != bbArg.getType())
1790 return emitOpError()
1791 <<
"input element type " << inputElementType
1792 <<
" does not match corresponding block argument type "
1797 for (
auto [output, bbArg] : llvm::zip(
1798 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1799 auto outputElementType =
1800 llvm::cast<ShapedType>(output.getType()).getElementType();
1801 if (outputElementType != bbArg.getType())
1802 return emitOpError()
1803 <<
"output element type " << outputElementType
1804 <<
" does not match corresponding block argument type "
1820 b.
create<linalg::YieldOp>(loc, args[0]);
1835 if (llvm::isa<RankedTensorType>(initType))
1864 void TransposeOp::getAsmResultNames(
1866 if (!getResults().empty())
1867 setNameFn(getResults().front(),
"transposed");
1880 return emitOpError(
"permutation is not valid");
1882 auto inputType = getInput().getType();
1883 auto initType = getInit().getType();
1885 int64_t rank = inputType.getRank();
1887 if (rank != initType.getRank())
1888 return emitOpError() <<
"input rank " << rank
1889 <<
" does not match init rank " << initType.getRank();
1891 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1892 return emitOpError() <<
"size of permutation " << permutationRef.size()
1893 <<
" does not match the argument rank " << rank;
1895 auto inputDims = inputType.getShape();
1896 auto initDims = initType.getShape();
1898 for (int64_t i = 0; i < rank; ++i) {
1899 int64_t inputDim = inputDims[permutationRef[i]];
1900 int64_t initDim = initDims[i];
1902 if (inputDim != initDim) {
1903 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1904 <<
" doesn't match dim(input, permutation[" << i
1905 <<
"]) = " << inputDim;
1913 int64_t rank = getInit().getType().getRank();
1917 ArrayAttr TransposeOp::getIndexingMaps() {
1919 int64_t rank = getInit().getType().getRank();
1922 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
1926 void TransposeOp::getEffects(
1936 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1939 if (!isa<TensorType>(getInput().
getType()))
1943 if (getPermutation().size() == 0) {
1944 result.push_back(getInput());
1949 result.push_back(getInput());
1962 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1963 if (!defTransposeOp)
1968 foldedPerms.reserve(perms.size());
1969 for (int64_t perm : perms)
1970 foldedPerms.push_back(defPerms[perm]);
1973 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1987 Value input = transposeOp.getInput();
1988 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
1999 unsigned dimensionSize = dimensions.size();
2000 for (
unsigned i = 0; i < dimensionSize; ++i)
2001 resultDimensions.push_back(invertPerm[dimensions[i]]);
2004 Value broadcastInput = broadcastOp.getInput();
2005 Location loc = transposeOp.getLoc();
2008 auto broadcastInputTy =
2009 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2010 unsigned inputRank = broadcastInputTy.getRank();
2011 for (
unsigned i = 0; i < inputRank; ++i) {
2012 if (broadcastInputTy.isDynamicDim(i)) {
2013 dims.push_back(rewriter.
create<tensor::DimOp>(loc, broadcastInput, i)
2017 broadcastInputTy.getDimSize(i)));
2022 Value transposeInit = rewriter.
create<tensor::EmptyOp>(
2023 transposeOp.getLoc(), transposeResultShapes,
2024 broadcastInputTy.getElementType());
2027 Value transposeResult =
2029 .
create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2033 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2058 if (llvm::isa<RankedTensorType>(initType))
2087 void BroadcastOp::getAsmResultNames(
2089 if (!getResults().empty())
2090 setNameFn(getResults().front(),
"broadcasted");
2102 auto inputType = getInput().getType();
2103 auto initType = getInit().getType();
2105 int64_t inputRank = inputType.getRank();
2106 int64_t initRank = initType.getRank();
2108 auto inputShape = inputType.getShape();
2109 auto initShape = initType.getShape();
2111 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2112 return emitOpError() <<
"input rank plus added dimensions does not "
2113 "match init rank. input rank: "
2115 <<
", dimensions size: " << dimensionsRef.size()
2116 <<
", init rank: " << initRank;
2119 if (dim < 0 || dim >= initRank)
2120 return emitOpError() <<
"dimension " << idx
2121 <<
" is out of range. expected range: [0, "
2122 << initRank - 1 <<
"], got: " << dim;
2127 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2128 if (!llvm::is_contained(dimensionsRef, dim))
2129 dimMap.push_back(dim);
2132 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
2135 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2136 return emitOpError() <<
"input dim " << inputDimIdx
2137 <<
" should match init dim " << initDimIdx
2138 <<
". input: " << inputShape[inputDimIdx]
2139 <<
", init: " << initShape[initDimIdx];
2146 int64_t rank = getInit().getType().getRank();
2150 ArrayAttr BroadcastOp::getIndexingMaps() {
2152 int64_t rank = getInit().getType().getRank();
2158 void BroadcastOp::getEffects(
2170 results.
add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2178 if (getNumOperands() > 0)
2179 p <<
' ' << getOperands();
2181 if (getNumOperands() > 0)
2182 p <<
" : " << getOperandTypes();
2197 static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2199 return op.
emitOpError(
"expected number of yield values (")
2201 <<
") to match the number of inits / outs operands of the enclosing "
2202 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2206 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2208 if (isa<MemRefType, RankedTensorType>(elementType))
2210 if (opOperand.get().getType() != elementType)
2212 << (opOperand.getOperandNumber() + 1) <<
" ("
2213 << opOperand.get().getType() <<
") doesn't match "
2214 <<
"the element type of the enclosing linalg.generic op ("
2215 << elementType <<
")";
2221 auto *parentOp = (*this)->getParentOp();
2222 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2223 return emitOpError(
"expected single non-empty parent region");
2225 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2228 return emitOpError(
"expected parent op with LinalgOp interface");
2236 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2238 return emitOpError(
"expected parent op with LinalgOp interface");
2239 if (linalgOp.getNumLoops() <= getDim())
2240 return emitOpError(
"expected dim (")
2241 << getDim() <<
") to be lower than the number of loops ("
2242 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2248 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2250 #define GET_OP_CLASSES
2251 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2253 #define GET_OP_CLASSES
2254 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2271 for (
unsigned i = 0; i < num; ++i)
2278 auto rangeA = llvm::make_range(a.begin(), a.end());
2279 auto rangeB = llvm::make_range(b.begin(), b.end());
2280 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2281 return llvm::to_vector<4>(concatRanges);
2285 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2287 for (
auto size : memref.getShape())
2294 if (
auto as = memref.getMemorySpace()) {
2295 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2296 ss <<
"as" << attr.getInt();
2302 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2305 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2318 assert(isa<LinalgOp>(op));
2320 std::string fun =
"";
2322 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2323 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2324 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2325 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2329 std::replace(name.begin(), name.end(),
'.',
'_');
2330 llvm::raw_string_ostream ss(name);
2334 return std::string();
2337 std::string res = ss.str();
2350 LogicalResult matchAndRewrite(LinalgOp op,
2356 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2359 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2370 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2373 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2378 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2385 if (castOp->getBlock() != linalgOp->getBlock())
2392 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2395 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2401 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2403 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2406 linalgOp.getDpsInits().end());
2407 outputOperands[resultNumber] = newOperand;
2408 newOperands.append(outputOperands.begin(), outputOperands.end());
2411 linalgOp->result_type_end());
2412 resultTypes[resultNumber] = resultType;
2413 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2420 results[resultNumber] = castBack;
2432 if (linalgOp.isScalar(&opOperand))
2434 Value src = opOperand.get();
2435 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2436 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2444 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2445 Value castSource = castOp.getSource();
2446 auto castSourceType =
2447 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2448 if (castSourceType && castSourceType.hasStaticShape())
2449 sourceShape = castSourceType.getShape();
2455 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2456 if (sourceType.isDynamicDim(i))
2458 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2459 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2469 static void createNewOperandWithStaticSizes(
2473 bool &changeNeeded) {
2475 newOperands.push_back(src);
2476 if (linalgOp.isScalar(opOperand))
2478 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2479 Type resultType = sourceType;
2480 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2481 resultTypes.push_back(resultType);
2485 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2489 bool newOperandNeeded =
false;
2490 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2491 int64_t dimShape = sourceShape[i];
2493 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2494 newShape.push_back(dimShape);
2500 newShape.push_back(affineExprToSize[dimExpr]);
2501 newOperandNeeded =
true;
2504 if (newOperandNeeded) {
2505 changeNeeded =
true;
2508 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2510 newOperands[index] = newOperand;
2512 if (linalgOp.isDpsInit(opOperand))
2513 resultTypes.push_back(resultType);
2522 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2524 if (!linalgOp.hasPureTensorSemantics())
2528 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2529 return !map.isProjectedPermutation();
2539 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2546 bool changeNeeded =
false;
2547 newOperands.reserve(linalgOp->getNumOperands());
2548 resultTypes.reserve(linalgOp.getNumDpsInits());
2551 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2552 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2553 affineExprToSize, linalgOp, newOperands,
2554 resultTypes, changeNeeded);
2563 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2566 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2567 Value newResult = std::get<1>(it);
2568 Value oldResult = std::get<0>(it);
2571 replacements.push_back(
2572 (newType != oldType)
2573 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2576 rewriter.
replaceOp(linalgOp, replacements);
2591 ShapedType inputType = getInputOperandType();
2592 ShapedType outputType = getOutputOperandType();
2597 return emitOpError(
"incompatible output shape");
2599 int64_t inputRank = getInputOperandRank();
2600 int64_t dimension = getDimension();
2601 if ((dimension < 0) || (dimension >= inputRank))
2602 return emitOpError(
"incorrect dimension specified");
2608 int64_t operandRank = getInputOperandRank();
2611 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2612 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2613 Value source = getInput();
2614 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2615 loopBounds[dim].offset = zero;
2616 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2617 loopBounds[dim].stride = one;
2624 utils::IteratorType::parallel);
2625 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2626 return iteratorTypes;
2629 FailureOr<TilingResult>
2630 SoftmaxOp::getTiledImplementation(
OpBuilder &builder,
2633 int64_t rank = getInputOperandRank();
2637 tiledOperands.emplace_back(
2638 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2639 tiledOperands.emplace_back(
2640 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2643 if (hasPureTensorSemantics())
2644 resultTypes.push_back(tiledOperands[1].
getType());
2646 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2651 LogicalResult SoftmaxOp::getResultTilePosition(
2655 if (resultNumber == 0) {
2656 resultOffsets.assign(offsets.begin(), offsets.end());
2657 resultSizes.assign(sizes.begin(), sizes.end());
2672 Location loc = getOperation()->getLoc();
2674 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2675 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2676 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2677 if (!outputShapedType.isDynamicDim(dim)) {
2679 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2686 reifiedReturnShapes.emplace_back(std::move(shapes));
2690 void SoftmaxOp::getEffects(
2694 if (!llvm::isa<MemRefType>(operand.
getType()))
2697 &getOperation()->getOpOperand(index), 0,
2702 for (
OpOperand &operand : getDpsInitsMutable()) {
2703 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2736 int64_t dim,
bool allParallel =
false) {
2738 utils::IteratorType::parallel);
2740 iteratorTypes[dim] = utils::IteratorType::reduction;
2744 for (
int i = 0; i < inputRank; i++) {
2751 return std::make_tuple(iteratorTypes, indexingMaps);
2756 template <
typename T>
2759 auto inputType = cast<ShapedType>(input.
getType());
2761 int64_t inputRank = inputShape.size();
2762 auto [iteratorTypes, indexingMaps] =
2764 assert(indexingMaps.size() == 2 &&
2765 "We should have two maps: 1 for the input, 1 for the output");
2766 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2768 auto genericOp = builder.
create<linalg::GenericOp>(
2769 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2771 Value result = b.create<T>(loc, args[0], args[1]);
2772 b.create<linalg::YieldOp>(loc, result);
2782 auto inputType = cast<ShapedType>(input.
getType());
2784 int64_t inputRank = inputShape.size();
2786 builder, inputRank, dim,
true);
2787 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2788 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2790 indexingMaps.push_back(indexingMaps[0]);
2791 auto genericOp = builder.
create<linalg::GenericOp>(
2794 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2795 Value result = b.create<math::ExpOp>(loc, diff);
2796 b.create<linalg::YieldOp>(loc, result);
2807 Value denominator,
Value output, int64_t dim) {
2808 auto inputType = cast<ShapedType>(numerator.
getType());
2810 int64_t inputRank = inputShape.size();
2812 builder, inputRank, dim,
true);
2813 assert(indexingMaps.size() == 2 &&
2814 "We should have one map for each input (2)");
2815 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2817 indexingMaps.push_back(indexingMaps[0]);
2818 auto genericOp = builder.
create<linalg::GenericOp>(
2820 indexingMaps, iteratorTypes,
2822 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2823 b.create<linalg::YieldOp>(loc, result);
2847 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(
OpBuilder &b) {
2851 Value input = getInput();
2852 ShapedType inputType = getInputOperandType();
2853 Type elementType = inputType.getElementType();
2854 int64_t reductionDim = getDimension();
2856 Value output = getOutput();
2857 dims.erase(dims.begin() + reductionDim);
2859 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2861 elementType, b, loc,
2863 Value neutralForMaxFInit =
2864 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2867 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2876 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2878 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2882 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2891 auto filterType = cast<ShapedType>(getFilter().
getType());
2893 int64_t filterH = filterShape[getFilterHDim()];
2894 int64_t filterW = filterShape[getFilterWDim()];
2898 if (filterH != r && filterH != 1)
2899 return emitOpError(
"expect filter height either equals to r or 1");
2900 if (filterW != r && filterW != 1)
2901 return emitOpError(
"expect filter width either equals to r or 1");
2902 if (filterH == 1 && filterW == 1)
2903 return emitOpError(
"expect either filter height or width equals to r");
2906 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2907 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2908 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2909 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2911 auto outputType = cast<ShapedType>(getOutput().
getType());
2914 return emitOpError(
"the output shape is not expected");
2920 WinogradFilterTransformOp::getIterationDomain(
OpBuilder &builder) {
2924 Value filter = getFilter();
2925 int64_t filterRank = getFilterOperandRank();
2927 for (
unsigned dim = 0; dim < filterRank; ++dim) {
2928 loopBounds[dim].offset = zeroAttr;
2929 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
2930 loopBounds[dim].stride = oneAttr;
2936 WinogradFilterTransformOp::getLoopIteratorTypes() {
2937 int64_t filterRank = getFilterOperandRank();
2939 utils::IteratorType::parallel);
2940 return iteratorTypes;
2943 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2948 ShapedType filterType = getFilterOperandType();
2950 int64_t filterH = filterShape[getFilterHDim()];
2951 int64_t filterW = filterShape[getFilterWDim()];
2954 int64_t alpha = m + r - 1;
2955 int64_t alphaH = filterH != 1 ? alpha : 1;
2956 int64_t alphaW = filterW != 1 ? alpha : 1;
2960 resultOffsets.append(
2961 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
2963 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
2974 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
2979 ShapedType filterType = getFilterOperandType();
2981 int64_t filterH = filterShape[getFilterHDim()];
2982 int64_t filterW = filterShape[getFilterWDim()];
2988 sliceOffsets.append(
2989 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
2990 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
2991 sizes[getFilterCDim()]});
2992 int64_t filterRank = getFilterOperandRank();
2995 tiledOperands.emplace_back(builder.
create<tensor::ExtractSliceOp>(
2996 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
2999 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3003 int64_t outputRank = getOutputOperandRank();
3005 tiledOperands.emplace_back(builder.
create<tensor::ExtractSliceOp>(
3006 loc, getOutput(), resultOffsets, resultSizes, outputStrides));
3009 resultTypes.push_back(tiledOperands[1].
getType());
3011 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3021 auto inputType = cast<ShapedType>(getInput().
getType());
3023 int64_t inputH = inputShape[getInputHDim()];
3024 int64_t inputW = inputShape[getInputWDim()];
3027 int64_t tileSize = m + r - 1;
3028 bool leftTransform = inputH != 1;
3029 bool rightTransform = inputW != 1;
3032 if (ShapedType::isDynamic(inputH)) {
3033 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3034 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3036 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3037 expectedOutputShape[getOutputTileHDim()] =
3038 leftTransform ? (inputH - (r - 1)) / m : 1;
3040 if (ShapedType::isDynamic(inputW)) {
3041 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3042 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3044 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3045 expectedOutputShape[getOutputTileWDim()] =
3046 rightTransform ? (inputW - (r - 1)) / m : 1;
3048 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3049 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3051 auto outputType = cast<ShapedType>(getOutput().
getType());
3054 return emitOpError(
"the output shape is not expected");
3060 WinogradInputTransformOp::getIterationDomain(
OpBuilder &builder) {
3064 Value output = getOutput();
3065 int64_t outputRank = getOutputOperandRank();
3067 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3068 loopBounds[dim].offset = zeroAttr;
3070 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3071 loopBounds[dim].stride = oneAttr;
3077 WinogradInputTransformOp::getLoopIteratorTypes() {
3078 int64_t outputRank = getOutputOperandRank();
3080 utils::IteratorType::parallel);
3081 return iteratorTypes;
3084 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3089 ShapedType inputType = getInputOperandType();
3091 int64_t inputH = inputShape[getInputHDim()];
3092 int64_t inputW = inputShape[getInputWDim()];
3095 int64_t alpha = m + r - 1;
3096 int64_t alphaH = inputH != 1 ? alpha : 1;
3097 int64_t alphaW = inputW != 1 ? alpha : 1;
3101 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3102 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3103 offsets[getOutputCDim()]});
3104 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3105 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3106 sizes[getOutputCDim()]});
3117 FailureOr<TilingResult>
3118 WinogradInputTransformOp::getTiledImplementation(
OpBuilder &builder,
3123 ShapedType inputType = getInputOperandType();
3125 int64_t inputH = inputShape[getInputHDim()];
3126 int64_t inputW = inputShape[getInputWDim()];
3132 auto offsetAffineMap =
3135 builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
3137 builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
3141 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3143 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3152 sliceOffsets.append(
3153 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3159 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3160 int64_t inputRank = getInputOperandRank();
3162 tiledOperands.emplace_back(builder.
create<tensor::ExtractSliceOp>(
3163 loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
3166 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3170 int64_t outputRank = getOutputOperandRank();
3172 tiledOperands.emplace_back(builder.
create<tensor::ExtractSliceOp>(
3173 loc, getOutput(), resultOffsets, resultSizes, outputStrides));
3176 resultTypes.push_back(tiledOperands[1].
getType());
3178 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3188 auto valueType = cast<ShapedType>(getValue().
getType());
3190 int64_t valueH = valueShape[getValueAlphaHDim()];
3191 int64_t valueW = valueShape[getValueAlphaWDim()];
3192 int64_t valueTileH = valueShape[getValueTileHDim()];
3193 int64_t valueTileW = valueShape[getValueTileWDim()];
3196 bool leftTransform = valueH != 1;
3197 bool rightTransform = valueW != 1;
3199 int64_t outputRank = getOutputOperandRank();
3201 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3202 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3204 if (valueH != (leftTransform ? m + r - 1 : 1))
3205 return emitOpError(
"expect input height equals to input tile size");
3206 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3208 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3209 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3211 if (valueW != (rightTransform ? m + r - 1 : 1))
3212 return emitOpError(
"expect input width equals to input tile size");
3213 expectedOutputShape[getOutputWDim()] =
3214 (rightTransform ? m : 1) * valueTileW;
3216 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3217 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3219 auto outputType = cast<ShapedType>(getOutput().
getType());
3222 return emitOpError(
"the output shape is not expected");
3228 WinogradOutputTransformOp::getIterationDomain(
OpBuilder &builder) {
3232 Value value = getValue();
3233 int64_t valueRank = getValueOperandRank();
3235 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3236 loopBounds[dim].offset = zeroAttr;
3238 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3239 loopBounds[dim].stride = oneAttr;
3245 WinogradOutputTransformOp::getLoopIteratorTypes() {
3246 int64_t valueRank = getValueOperandRank();
3248 utils::IteratorType::parallel);
3249 return iteratorTypes;
3252 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3264 builder, loc, affineMap, offsets[getValueTileHDim()]);
3266 builder, loc, affineMap, offsets[getValueTileWDim()]);
3268 builder, loc, affineMap, sizes[getValueTileHDim()]);
3270 builder, loc, affineMap, sizes[getValueTileWDim()]);
3272 ShapedType valueType = getValueOperandType();
3274 int64_t valueH = valueShape[0];
3275 int64_t valueW = valueShape[1];
3287 resultOffsets.append(
3288 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3290 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3300 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3309 ShapedType valueType = getValueOperandType();
3311 int64_t alphaH = valueShape[getValueAlphaHDim()];
3312 int64_t alphaW = valueShape[getValueAlphaWDim()];
3316 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3317 offsets[getValueTileWDim()], offsets[getValueNDim()],
3318 offsets[getValueFDim()]});
3319 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3320 sizes[getValueTileWDim()], sizes[getValueNDim()],
3321 sizes[getValueFDim()]});
3322 int64_t valueRank = getValueOperandRank();
3324 tiledOperands.emplace_back(builder.
create<tensor::ExtractSliceOp>(
3325 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
3328 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3332 int64_t outputRank = getOutputOperandRank();
3334 tiledOperands.emplace_back(builder.
create<tensor::ExtractSliceOp>(
3335 loc, getOutput(), resultOffsets, resultSizes, strides));
3338 resultTypes.push_back(tiledOperands[1].
getType());
3340 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3349 void LinalgDialect::getCanonicalizationPatterns(
3351 results.
add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3358 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. N-1], where N is the number of expressions.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation vector.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.