40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/SmallVectorExtras.h"
45#include "llvm/ADT/StringSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/FormatVariadic.h"
48#include "llvm/Support/InterleavedRange.h"
49#include "llvm/Support/LogicalResult.h"
50#include "llvm/Support/MathExtras.h"
51#include "llvm/Support/raw_ostream.h"
61 auto type = cast<ShapedType>(v.
getType());
62 if (!type.isDynamicDim(dim))
67 .Case([&](RankedTensorType t) ->
Value {
68 return tensor::DimOp::create(builder, loc, v, dim);
70 .Case([&](MemRefType t) ->
Value {
71 return memref::DimOp::create(builder, loc, v, dim);
82 .Case([&](RankedTensorType t) ->
Operation * {
83 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
86 .Case([&](MemRefType type) ->
Operation * {
87 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
100 return b.createOrFold<memref::DimOp>(loc, source, dim);
101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
102 return b.createOrFold<tensor::DimOp>(loc, source, dim);
103 llvm_unreachable(
"Expected MemRefType or TensorType");
108 auto shapedType = llvm::cast<ShapedType>(source.
getType());
109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
111 return b.getIndexAttr(shapedType.getDimSize(dim));
134 for (
auto containers : {inputTypes, outputTypes}) {
135 for (
auto t : containers) {
147 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
163 std::optional<TypeRange> resultTensorTypes,
170 if (!resultTensorTypes)
171 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
180 "operandSegmentSizes",
181 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
182 static_cast<int32_t>(outputs.size())}));
192 std::optional<TypeRange> resultTensorTypes,
199 return attr.
getName() ==
"indexing_maps";
202 indexingMapsAttrVal = llvm::map_to_vector(
205 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
208 attributes, regionBuilder);
212 std::optional<TypeRange> resultTensorTypes,
219 return attr.
getName() ==
"indexing_maps";
222 indexingMapsAttrVal = llvm::map_to_vector(
225 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
228 attributes, regionBuilder);
232 std::optional<TypeRange> resultTensorTypes,
239 indexingMapsAttrVal =
241 return AffineMapAttr::get(map);
243 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
245 attributes, regionBuilder);
254 bool addOperandSegmentSizes =
true) {
255 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
284 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
286 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
290 if (addOperandSegmentSizes) {
297 if (
result.propertiesAttr) {
299 attrs.
append(
"operandSegmentSizes",
301 {static_cast<int32_t>(inputsOperands.size()),
302 static_cast<int32_t>(outputsOperands.size())}));
305 result.addAttribute(
"operandSegmentSizes",
307 {static_cast<int32_t>(inputsOperands.size()),
308 static_cast<int32_t>(outputsOperands.size())}));
311 if (!
result.propertiesAttr) {
312 std::optional<RegisteredOperationName> info =
313 result.name.getRegisteredInfo();
315 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
316 return parser.emitError(attrsLoc)
317 <<
"'" << result.name.getStringRef() <<
"' op ";
328 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
329 if (!outputs.empty())
330 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
341 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
344 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
345 "region expects {0} args, got {1}",
346 numRegionArgs, inputTypes.size() + outputTypes.size()));
352 opBuilder, region, inputTypes, outputTypes, attrs,
371 unsigned numRegionArgs,
388 result.addTypes(outputTensorsTypes);
390 std::unique_ptr<Region> region = std::make_unique<Region>();
392 outputTypes,
result.attributes.getAttrs(),
395 result.addRegion(std::move(region));
402 if (resultTypes.empty())
447class RegionBuilderHelper {
449 RegionBuilderHelper(OpBuilder &builder,
Block &block)
450 : builder(builder), block(block) {}
453 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
455 if (!isFloatingPoint(arg)) {
457 emitError() <<
"unsupported non numeric type";
460 llvm_unreachable(
"unsupported non numeric type");
462 OpBuilder::InsertionGuard g(builder);
463 builder.setInsertionPointToEnd(&block);
466 return math::ExpOp::create(builder, arg.
getLoc(), arg);
468 return math::LogOp::create(builder, arg.
getLoc(), arg);
470 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
472 return math::CeilOp::create(builder, arg.
getLoc(), arg);
474 return math::FloorOp::create(builder, arg.
getLoc(), arg);
476 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
477 case UnaryFn::reciprocal: {
478 Attribute oneAttr = builder.getOneAttr(arg.
getType());
479 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
480 ::cast<TypedAttr>(oneAttr));
481 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
484 return math::RoundOp::create(builder, arg.
getLoc(), arg);
486 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
488 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
489 case UnaryFn::square:
490 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
492 return math::TanhOp::create(builder, arg.
getLoc(), arg);
494 return math::ErfOp::create(builder, arg.
getLoc(), arg);
497 emitError() <<
"unsupported unary function";
500 llvm_unreachable(
"unsupported unary function");
507 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
509 bool allComplex = isComplex(arg0) && isComplex(arg1);
510 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
511 bool allInteger = isInteger(arg0) && isInteger(arg1);
514 if (!allComplex && !allFloatingPoint && !allInteger) {
517 <<
"Cannot build binary Linalg operation: expects allComplex, "
518 "allFloatingPoint, or allInteger, got "
522 llvm_unreachable(
"unsupported non numeric type");
524 OpBuilder::InsertionGuard g(builder);
525 builder.setInsertionPointToEnd(&block);
529 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
530 if (allFloatingPoint)
531 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
533 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
534 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
537 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
538 if (allFloatingPoint)
539 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
542 emitError() <<
"unsupported operation: sub with bools";
545 llvm_unreachable(
"unsupported operation: sub with bools");
547 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
550 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
551 if (allFloatingPoint)
552 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
554 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
555 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
558 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
559 if (allFloatingPoint)
560 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
563 emitError() <<
"unsupported operation: div with bools";
566 llvm_unreachable(
"unsupported operation: div with bools");
568 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
569 case BinaryFn::div_unsigned:
570 if (!allInteger || allBool) {
572 emitError() <<
"unsupported operation: unsigned div not on uint";
575 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
577 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
578 case BinaryFn::max_signed:
580 if (allFloatingPoint)
581 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
582 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
583 case BinaryFn::min_signed:
585 if (allFloatingPoint)
586 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
587 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
588 case BinaryFn::max_unsigned:
590 if (!allInteger || allBool) {
592 emitError() <<
"unsupported operation: unsigned max not on uint";
595 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
597 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
598 case BinaryFn::min_unsigned:
600 if (!allInteger || allBool) {
602 emitError() <<
"unsupported operation: unsigned min not on uint";
605 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
607 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
609 assert(allFloatingPoint);
610 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
613 emitError() <<
"unsupported binary function";
616 llvm_unreachable(
"unsupported binary function");
620 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
622 OpBuilder::InsertionGuard g(builder);
623 builder.setInsertionPointToEnd(&block);
625 case TernaryFn::select:
626 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
629 emitError() <<
"unsupported ternary function";
632 llvm_unreachable(
"unsupported ternary function");
636 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
639 case TypeFn::cast_signed:
640 return cast(toType, operand,
false);
641 case TypeFn::cast_unsigned:
642 return cast(toType, operand,
true);
645 emitError() <<
"unsupported type conversion function";
648 llvm_unreachable(
"unsupported type conversion function");
652 OpBuilder::InsertionGuard g(builder);
653 builder.setInsertionPointToEnd(&block);
654 Location loc = builder.getUnknownLoc();
655 YieldOp::create(builder, loc, values);
658 Value constant(
const std::string &value) {
659 OpBuilder::InsertionGuard g(builder);
660 builder.setInsertionPointToEnd(&block);
661 Location loc = builder.getUnknownLoc();
662 Attribute valueAttr =
parseAttribute(value, builder.getContext());
663 return arith::ConstantOp::create(builder, loc,
664 ::cast<TypedAttr>(valueAttr));
667 Value index(int64_t dim) {
668 OpBuilder::InsertionGuard g(builder);
669 builder.setInsertionPointToEnd(&block);
670 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
673 Type getIntegerType(
unsigned width) {
674 return IntegerType::get(builder.getContext(), width);
677 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
678 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
685 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
686 OpBuilder::InsertionGuard g(builder);
687 builder.setInsertionPointToEnd(&block);
688 auto loc = operand.
getLoc();
689 if (isa<UnknownLoc>(loc)) {
699 bool isComplex(Value value) {
700 return llvm::isa<ComplexType>(value.
getType());
702 bool isFloatingPoint(Value value) {
703 return llvm::isa<FloatType>(value.
getType());
705 bool isInteger(Value value) {
706 return llvm::isa<IntegerType>(value.
getType());
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter)
const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
727 if (copyOp.hasPureBufferSemantics())
730 rewriter.
replaceOp(copyOp, copyOp.getInputs());
740 results.
add<EraseSelfCopy>(context);
753template <
typename TensorReshapeOp>
754struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter)
const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
783struct FoldFillWithPad final :
public OpRewritePattern<tensor::PadOp> {
786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
787 PatternRewriter &rewriter)
const override {
788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
794 Value padValue = padOp.getConstantPaddingValue();
795 if (!padValue || fillOp.value() != padValue)
801 padOp,
"failed to reify tensor.pad op result shape");
804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
805 padOp.getResultType().getElementType());
807 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
810 if (
replacement.getType() != padOp.getResultType()) {
811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
822struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
826 PatternRewriter &rewriter)
const override {
827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
836 Value firstDest = insertOp.getDest();
837 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
843 bool disjoint =
false;
844 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
853 int64_t prevStart = prevOp.getStaticOffset(i);
854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
855 prevOp.getStaticStride(i);
856 int64_t nextStart = insertOp.getStaticOffset(i);
857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
858 insertOp.getStaticStride(i);
859 if (prevEnd < nextStart || nextEnd < prevStart) {
867 firstDest = prevOp.getDest();
878 Value padValue = srcPadOp.getConstantPaddingValue();
879 if (!padValue || dstFillOp.value() != padValue)
882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
885 Location loc = insertOp.getLoc();
888 AffineExpr sym0, sym1;
894 SmallVector<OpFoldResult, 4> newOffsets;
895 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
896 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
900 RankedTensorType srcPadType = srcPadOp.getSourceType();
901 SmallVector<OpFoldResult, 4> newSizes;
902 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
903 if (srcPadType.isDynamicDim(i)) {
905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
908 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
914 newSizes, insertOp.getMixedStrides());
920struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter)
const override {
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
933 Value extractedScalar = fillOp.getInputs()[0];
936 rewriter.
replaceOp(extractOp, extractedScalar);
944static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
950 if (
auto paddingValue = packOp.getPaddingValue())
954 Value packOpDest = packOp.getDest();
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
963struct FoldFillWithPack :
public OpRewritePattern<linalg::PackOp> {
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter)
const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
973 rewriter.
replaceOp(packOp, fillOp.value().result());
979struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter)
const override {
984 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
987 copyOp.getOutputs());
990 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
992 fillOp.getOutputs());
1000struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter)
const override {
1005 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1017struct FoldConcatsOfFill :
public OpRewritePattern<tensor::ConcatOp> {
1020 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1021 PatternRewriter &rewriter)
const override {
1022 auto concatOperands = concatOp.getInputs();
1023 if (concatOperands.empty()) {
1027 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1032 OpFoldResult firstFillVal =
1035 SmallVector<Value> allOuts;
1036 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1038 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1039 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1044 OpFoldResult fillVal =
1046 if (fillVal != firstFillVal)
1049 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1052 if (!llvm::all_of(concatOperands.drop_front(),
1053 isDefinedByCompatibleFillOp)) {
1055 concatOp,
"not all operands are defined by a compatible fill op");
1058 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1059 concatOp.getDim(), allOuts);
1061 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
1068void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1069 MLIRContext *context) {
1070 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1071 FoldFillWithPack, FoldFillWithPad,
1072 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1073 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1074 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1087 for (
ValueRange container : {inputs, outputs}) {
1088 for (
Value v : container) {
1089 Type t = v.getType();
1090 blockArgTypes.push_back(
1092 blockArgLocs.push_back(v.getLoc());
1098 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1102void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1104 for (Value v : getRegionInputArgs())
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v,
"out");
1110void GenericOp::build(
1111 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1113 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1115 ArrayRef<NamedAttribute> attributes) {
1116 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1117 iteratorTypes, doc, libraryCall);
1118 result.addAttributes(attributes);
1121 inputs, outputs, bodyBuild);
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder,
result, resultTensorTypes, inputs, outputs,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1139 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1143void GenericOp::build(
1145 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1146 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1147 StringRef libraryCall,
1149 ArrayRef<NamedAttribute> attributes) {
1151 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1154void GenericOp::build(
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes,
1159 ArrayRef<NamedAttribute> attributes) {
1160 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1162 "", bodyBuild, attributes);
1165void GenericOp::build(
1166 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1174 "", bodyBuild, attributes);
1177void GenericOp::print(OpAsmPrinter &p) {
1181 auto genericAttrNames = linalgTraitAttrNames();
1183 llvm::StringSet<> genericAttrNamesSet;
1184 genericAttrNamesSet.insert_range(genericAttrNames);
1185 SmallVector<NamedAttribute, 8> genericAttrs;
1186 for (
auto attr : (*this)->getAttrs()) {
1187 if (attr.getName() == getIteratorTypesAttrName()) {
1188 auto iteratorTypes =
1189 llvm::cast<ArrayAttr>(attr.getValue())
1190 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1195 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1196 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1197 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1200 genericAttrs.emplace_back(
1201 getIteratorTypesAttrName(),
1202 ArrayAttr::get(
getContext(), iteratorTypeNames));
1203 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1204 genericAttrs.push_back(attr);
1207 if (!genericAttrs.empty()) {
1208 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1209 p << genericDictAttr;
1215 genericAttrNames.push_back(
"operandSegmentSizes");
1216 genericAttrNamesSet.insert(genericAttrNames.back());
1218 bool hasExtraAttrs =
false;
1219 for (NamedAttribute n : (*this)->getAttrs()) {
1220 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1223 if (hasExtraAttrs) {
1230 if (!getRegion().empty()) {
1239ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1240 DictionaryAttr dictAttr;
1248 result.attributes.assign(dictAttr.getValue().begin(),
1249 dictAttr.getValue().end());
1255 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1256 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1257 if (!iteratorTypes) {
1258 return parser.
emitError(attributeLocation)
1259 <<
"expected " << getIteratorTypesAttrName(
result.name)
1260 <<
" array attribute";
1263 SmallVector<Attribute> iteratorTypeAttrs;
1265 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1266 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1267 if (!maybeIteratorType.has_value())
1269 <<
"unexpected iterator_type (" << s <<
")";
1271 iteratorTypeAttrs.push_back(
1272 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1274 result.attributes.set(getIteratorTypesAttrName(
result.name),
1278 SmallVector<Type, 1> inputTypes, outputTypes;
1288 std::unique_ptr<Region> region = std::make_unique<Region>();
1291 result.addRegion(std::move(region));
1297 SmallVector<Type, 1> outputTensorsTypes;
1300 result.addTypes(outputTensorsTypes);
1308 LinalgOp linalgOp) {
1309 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1310 if (!llvm::isa<MemRefType>(operand.
getType()))
1312 effects.emplace_back(
1317 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1318 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1320 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1331void GenericOp::getEffects(
1332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1341 if (!linalgOp.hasPureTensorSemantics())
1359template <
typename OpTy>
1360struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1361 using OpRewritePattern<OpTy>::OpRewritePattern;
1363 LogicalResult matchAndRewrite(OpTy linalgOp,
1364 PatternRewriter &rewriter)
const override {
1366 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1371 Block &body = linalgOp->getRegion(0).front();
1372 if (!llvm::hasSingleElement(body))
1374 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1379 if (linalgOp.hasPureBufferSemantics()) {
1380 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1381 linalgOp.getDpsInputOperand(0)->get() !=
1382 linalgOp.getDpsInitOperand(0)->get()) {
1384 linalgOp,
"expected single input and output to be the same value");
1387 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1388 if (!yieldArg || yieldArg.getOwner() != &body) {
1390 "cannot fold fill-like op");
1397 if (!linalgOp.hasPureTensorSemantics()) {
1399 linalgOp,
"mixed semantics is not supported yet");
1404 SmallVector<Value> returnedArgs;
1405 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1406 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1407 if (!yieldArg || yieldArg.getOwner() != &body)
1409 unsigned argumentNumber = yieldArg.getArgNumber();
1410 Value returnedArg = linalgOp->getOperand(argumentNumber);
1411 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1414 Type returnType = returnedArg.
getType();
1415 if (returnType != resultType) {
1420 returnedArg = sparse_tensor::ConvertOp::create(
1421 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1423 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1426 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1427 resultType, returnedArg);
1430 returnedArgs.push_back(returnedArg);
1433 if (returnedArgs.size() != linalgOp->getNumResults())
1435 rewriter.
replaceOp(linalgOp, returnedArgs);
1442void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443 MLIRContext *context) {
1444 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1447LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1466 for (
Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1472 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1481void MapOp::getAsmBlockArgumentNames(Region ®ion,
1483 for (Value v : getRegionInputArgs())
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v,
"init");
1489void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(),
"mapped");
1497 ArrayRef<NamedAttribute> attributes) {
1499 result.addAttributes(attributes);
1502 Type initType = init.
getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1508 inputs, {init}, bodyBuild);
1515 bool initFirst =
false,
bool mapInit =
true) {
1519 b.setInsertionPointToStart(&block);
1520 for (
auto &operand : operands) {
1522 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1530 payloadOpOperands.push_back(block.
getArguments().back());
1531 for (
const auto &arg : block.
getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1547ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1548 std::optional<OperationName> payloadOpName;
1549 NamedAttrList payloadOpAttrs;
1552 if (
failed(operationName))
1556 payloadOpName = operationName.value();
1564 if (payloadOpName.has_value()) {
1565 if (!
result.operands.empty())
1567 payloadOpAttrs, ArrayRef(
result.operands),
false,
1572 SmallVector<OpAsmParser::Argument> regionArgs;
1577 Region *body =
result.addRegion();
1585 bool mapInit =
true) {
1587 if (initFirst && !mapInit)
1611 for (
const auto &[operand, bbArg] :
1613 if (bbArg != operand)
1617 for (
const auto &[operand, bbArg] :
1620 if (bbArg != operand)
1627 return yieldOp.getNumOperands() == 1 &&
1628 yieldOp.getOperand(0).getDefiningOp() &&
1629 yieldOp.getOperand(0).getDefiningOp() == &payload;
1634 std::string attrToElide;
1636 for (
const auto &attr : payloadOp->
getAttrs()) {
1638 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1639 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1640 attrToElide = attr.getName().str();
1641 elidedAttrs.push_back(attrToElide);
1649void MapOp::print(OpAsmPrinter &p) {
1650 Block *mapper = getBody();
1660 if (!useShortForm) {
1666 [&](
auto arg) { p.printRegionArgument(arg); });
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() <<
"expects number of operands to match the arity of "
1683 << getInputs().size() + 1 <<
" and "
1684 << blockArgs.size();
1687 for (
const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() <<
"expected element type of input " << inputElemType
1693 <<
" to match bbArg type " << bbArgType;
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType :
TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() <<
"expected shape of input (" << inputElemShape
1703 <<
") to match shape of output (" << outputShape
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1738void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1740 for (Value v : getRegionInputArgs())
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v,
"init");
1746void ReduceOp::getAsmResultNames(
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(),
"reduced");
1752void ReduceOp::build(
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1756 ArrayRef<NamedAttribute> attributes) {
1758 result.addAttributes(attributes);
1761 for (Value init : inits) {
1762 Type initType = init.
getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1769 inputs, inits, bodyBuild);
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1774 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1784 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1785 SmallVector<AffineMap> affineMaps(
1788 AffineMap resultMap =
1791 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1792 affineMaps.push_back(resultMap);
1793 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1808 StringRef attributeName) {
1816ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1817 std::optional<OperationName> payloadOpName;
1818 NamedAttrList payloadOpAttrs;
1821 if (
failed(operationName))
1825 payloadOpName = operationName.value();
1831 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1836 if (payloadOpName.has_value()) {
1838 ArrayRef(
result.operands),
true);
1840 SmallVector<OpAsmParser::Argument> regionArgs;
1846 Region *body =
result.addRegion();
1856 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1859void ReduceOp::print(OpAsmPrinter &p) {
1860 Block *mapper = getBody();
1869 if (!useShortForm) {
1875 [&](
auto arg) { p.printRegionArgument(arg); });
1883LogicalResult ReduceOp::verify() {
1884 ArrayRef<int64_t> dimensionsRef = getDimensions();
1891 if (getInputs().size() !=
static_cast<size_t>(getNumDpsInputs()))
1893 <<
"expected equal number of inputs and outputs (required by "
1894 "SameVariadicOperandSize), got "
1895 << getNumDpsInputs() <<
" input(s) and " << getNumDpsInits()
1898 if (getInputs().empty())
1899 return emitOpError() <<
"expected at least one input";
1901 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1904 return emitOpError() <<
"expects all inputs to have the same shapes. "
1905 "Shape at input-index "
1907 <<
" is not equal to the shape at input-index 0.";
1910 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1913 return emitOpError() <<
"expects all outputs to have the same shapes. "
1914 "Shape at output-index "
1916 <<
" is not equal to the shape at output-index 0.";
1919 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1920 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1923 for (int64_t dimension : dimensionsRef) {
1924 if (dimension < 0 || dimension >= inputType.getRank()) {
1926 <<
"dimensions for reduction should be in the range [0, "
1927 << inputType.getRank() - 1 <<
"].";
1929 dimensionsToReduce.insert(dimension);
1932 auto inputDims = inputType.getShape();
1933 auto initDims = initType.getShape();
1936 SmallVector<int64_t> reducedInputDims;
1937 for (
const auto &en : llvm::enumerate(inputDims)) {
1938 if (!dimensionsToReduce.count(en.index()))
1939 reducedInputDims.push_back(en.value());
1942 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1943 return emitOpError() <<
"number of dimensions after reduction "
1944 << reducedInputDims.size()
1945 <<
" doesn't match the init rank "
1946 << initType.getRank();
1949 if (reducedInputDims != initDims)
1950 return emitOpError() <<
"init dimensions [" << initDims
1951 <<
"] doesn't match input dimensions after reduction ["
1952 << reducedInputDims <<
"]";
1954 Block *block = getBody();
1957 <<
"mismatching number of operands and block arguments";
1960 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1961 Type inputElementType =
1962 llvm::cast<ShapedType>(input.getType()).getElementType();
1963 if (inputElementType != bbArg.getType())
1965 <<
"input element type " << inputElementType
1966 <<
" does not match corresponding block argument type "
1971 for (
auto [output, bbArg] : llvm::zip(
1972 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1973 auto outputElementType =
1974 llvm::cast<ShapedType>(output.getType()).getElementType();
1975 if (outputElementType != bbArg.getType())
1977 <<
"output element type " << outputElementType
1978 <<
" does not match corresponding block argument type "
1994 linalg::YieldOp::create(
b, loc, args[0]);
1998void TransposeOp::build(::mlir::OpBuilder &builder,
1999 ::mlir::OperationState &
result, Value input, Value init,
2001 ArrayRef<NamedAttribute> attributes) {
2002 result.addOperands(input);
2003 result.addOperands(init);
2004 result.addAttribute(getPermutationAttrName(
result.name), permutation);
2005 result.addAttributes(attributes);
2008 Type initType = init.
getType();
2009 if (llvm::isa<RankedTensorType>(initType))
2010 result.addTypes(initType);
2016void TransposeOp::build(::mlir::OpBuilder &builder,
2017 ::mlir::OperationState &
result, Value input, Value init,
2018 ArrayRef<int64_t> permutation,
2019 ArrayRef<NamedAttribute> attributes) {
2024ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2026 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2038void TransposeOp::getAsmResultNames(
2040 if (!getResults().empty())
2041 setNameFn(getResults().front(),
"transposed");
2044void TransposeOp::print(OpAsmPrinter &p) {
2050LogicalResult TransposeOp::verify() {
2051 ArrayRef<int64_t> permutationRef = getPermutation();
2056 auto inputType = getInput().getType();
2057 auto initType = getInit().getType();
2059 int64_t rank = inputType.getRank();
2065 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2066 return emitOpError() <<
"size of permutation " << permutationRef.size()
2067 <<
" does not match the argument rank " << rank;
2069 auto inputDims = inputType.getShape();
2070 auto initDims = initType.getShape();
2072 for (int64_t i = 0; i < rank; ++i) {
2073 int64_t inputDim = inputDims[permutationRef[i]];
2074 int64_t initDim = initDims[i];
2076 if (inputDim != initDim) {
2077 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2078 <<
" doesn't match dim(input, permutation[" << i
2079 <<
"]) = " << inputDim;
2086SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2087 int64_t rank = getInit().getType().getRank();
2088 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2091ArrayAttr TransposeOp::getIndexingMaps() {
2093 int64_t rank = getInit().getType().getRank();
2096 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2100void TransposeOp::getEffects(
2101 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2110LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2111 SmallVectorImpl<OpFoldResult> &
result) {
2113 if (!isa<TensorType>(getInput().
getType()))
2117 if (getPermutation().empty()) {
2118 result.push_back(getInput());
2123 result.push_back(getInput());
2136 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2137 if (!defTransposeOp)
2142 foldedPerms.reserve(perms.size());
2144 foldedPerms.push_back(defPerms[perm]);
2147 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2161 Value input = transposeOp.getInput();
2162 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2173 unsigned dimensionSize = dimensions.size();
2174 for (
unsigned i = 0; i < dimensionSize; ++i)
2175 resultDimensions.push_back(invertPerm[dimensions[i]]);
2178 Value broadcastInput = broadcastOp.getInput();
2179 Location loc = transposeOp.getLoc();
2182 auto broadcastInputTy =
2183 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2184 unsigned inputRank = broadcastInputTy.getRank();
2185 for (
unsigned i = 0; i < inputRank; ++i) {
2186 if (broadcastInputTy.isDynamicDim(i)) {
2187 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2190 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2191 broadcastInputTy.getDimSize(i)));
2196 Value transposeInit = tensor::EmptyOp::create(
2197 rewriter, transposeOp.getLoc(), transposeResultShapes,
2198 broadcastInputTy.getElementType());
2201 Value transposeResult =
2202 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2203 transposeInit, resultPerms)
2206 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2211void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2212 MLIRContext *context) {
2213 results.
add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2220void BroadcastOp::build(::mlir::OpBuilder &builder,
2221 ::mlir::OperationState &
result, Value input, Value init,
2223 ArrayRef<NamedAttribute> attributes) {
2224 result.addOperands(input);
2225 result.addOperands(init);
2226 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2227 result.addAttributes(attributes);
2230 Type initType = init.
getType();
2231 if (llvm::isa<RankedTensorType>(initType))
2232 result.addTypes(initType);
2238void BroadcastOp::build(::mlir::OpBuilder &builder,
2239 ::mlir::OperationState &
result, Value input, Value init,
2240 ArrayRef<int64_t> dimensions,
2241 ArrayRef<NamedAttribute> attributes) {
2246ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2248 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2260void BroadcastOp::getAsmResultNames(
2262 if (!getResults().empty())
2263 setNameFn(getResults().front(),
"broadcasted");
2266void BroadcastOp::print(OpAsmPrinter &p) {
2272LogicalResult BroadcastOp::verify() {
2273 ArrayRef<int64_t> dimensionsRef = getDimensions();
2275 auto inputType = getInput().getType();
2276 auto initType = getInit().getType();
2278 int64_t inputRank = inputType.getRank();
2279 int64_t initRank = initType.getRank();
2281 auto inputShape = inputType.getShape();
2282 auto initShape = initType.getShape();
2284 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2285 return emitOpError() <<
"input rank plus added dimensions does not "
2286 "match init rank. input rank: "
2288 <<
", dimensions size: " << dimensionsRef.size()
2289 <<
", init rank: " << initRank;
2291 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2292 if (dim < 0 || dim >= initRank)
2294 <<
" is out of range. expected range: [0, "
2295 << initRank - 1 <<
"], got: " << dim;
2299 SmallVector<int64_t> dimMap;
2300 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2301 if (!llvm::is_contained(dimensionsRef, dim))
2302 dimMap.push_back(dim);
2305 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2308 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2309 return emitOpError() <<
"input dim " << inputDimIdx
2310 <<
" should match init dim " << initDimIdx
2311 <<
". input: " << inputShape[inputDimIdx]
2312 <<
", init: " << initShape[initDimIdx];
2318SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2319 int64_t rank = getInit().getType().getRank();
2320 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2323ArrayAttr BroadcastOp::getIndexingMaps() {
2325 int64_t rank = getInit().getType().getRank();
2331void BroadcastOp::getEffects(
2332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2347 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2348 if (!defBroadcastOp)
2353 Value init = broadcastOp.getInit();
2357 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2358 if (!llvm::is_contained(dimensions, dim))
2359 dimMap.push_back(dim);
2361 for (
auto dim : defDimensions)
2362 foldedDims.push_back(dimMap[dim]);
2364 llvm::sort(foldedDims);
2366 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2371void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2372 MLIRContext *context) {
2373 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2380void linalg::YieldOp::print(OpAsmPrinter &p) {
2381 if (getNumOperands() > 0)
2382 p <<
' ' << getOperands();
2384 if (getNumOperands() > 0)
2385 p <<
" : " << getOperandTypes();
2388ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2389 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2390 SmallVector<Type, 2> types;
2400static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2401 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2402 return op.emitOpError(
"expected number of yield values (")
2403 << op.getNumOperands()
2404 <<
") to match the number of inits / outs operands of the enclosing "
2405 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2407 for (
OpOperand &opOperand : op->getOpOperands()) {
2409 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2411 if (isa<MemRefType, RankedTensorType>(elementType))
2413 if (opOperand.get().getType() != elementType)
2414 return op.emitOpError(
"type of yield operand ")
2415 << (opOperand.getOperandNumber() + 1) <<
" ("
2416 << opOperand.get().getType() <<
") doesn't match "
2417 <<
"the element type of the enclosing linalg.generic op ("
2418 << elementType <<
")";
2423LogicalResult linalg::YieldOp::verify() {
2424 auto *parentOp = (*this)->getParentOp();
2425 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2426 return emitOpError(
"expected single non-empty parent region");
2428 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2431 return emitOpError(
"expected parent op with LinalgOp interface");
2438LogicalResult IndexOp::verify() {
2439 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2441 return emitOpError(
"expected parent op with LinalgOp interface");
2442 if (linalgOp.getNumLoops() <= getDim())
2444 << getDim() <<
") to be lower than the number of loops ("
2445 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2449OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2450 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2455 return OpFoldResult{};
2458 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2459 uint64_t dim = getDim();
2460 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2461 if (loopBounds[dim] == 1)
2462 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2464 return OpFoldResult{};
2469#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2471#define GET_OP_CLASSES
2472#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2474#define GET_OP_CLASSES
2475#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2476#define GET_OP_CLASSES
2477#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2494 for (
unsigned i = 0; i < num; ++i)
2501 auto rangeA = llvm::make_range(a.begin(), a.end());
2502 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2503 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2504 return llvm::to_vector<4>(concatRanges);
2508 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2510 for (
auto size :
memref.getShape())
2517 if (
auto as =
memref.getMemorySpace()) {
2518 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2519 ss <<
"as" << attr.getInt();
2525 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2528 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2541 assert(isa<LinalgOp>(op));
2543 std::string fun =
"";
2545 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2546 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2547 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2548 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2552 llvm::replace(name,
'.',
'_');
2553 llvm::raw_string_ostream ss(name);
2557 return std::string();
2572 LogicalResult matchAndRewrite(LinalgOp op,
2574 for (
OpOperand &opOperand : op->getOpOperands()) {
2578 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2581 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2592struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2593 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2595 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2596 PatternRewriter &rewriter)
const override {
2600 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2607 if (castOp->getBlock() != linalgOp->getBlock())
2610 OpBuilder::InsertionGuard guard(rewriter);
2613 Location loc = linalgOp.getLoc();
2614 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2617 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2623 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2625 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2626 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2627 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2628 linalgOp.getDpsInits().end());
2629 outputOperands[resultNumber] = newOperand;
2630 newOperands.append(outputOperands.begin(), outputOperands.end());
2632 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2633 linalgOp->result_type_end());
2634 resultTypes[resultNumber] = resultType;
2635 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2638 Value castBack = tensor::CastOp::create(
2642 results[resultNumber] = castBack;
2651static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2652 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2653 for (OpOperand &opOperand : operands) {
2654 if (linalgOp.isScalar(&opOperand))
2656 Value src = opOperand.get();
2657 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2658 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2664 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2666 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2667 Value castSource = castOp.getSource();
2668 auto castSourceType =
2669 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2670 if (castSourceType && castSourceType.hasStaticShape())
2671 sourceShape = castSourceType.getShape();
2677 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2678 if (sourceType.isDynamicDim(i))
2680 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2681 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2691static void createNewOperandWithStaticSizes(
2692 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2693 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2694 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2695 bool &changeNeeded) {
2696 Value src = opOperand->
get();
2697 newOperands.push_back(src);
2698 if (linalgOp.isScalar(opOperand))
2700 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2701 Type resultType = sourceType;
2702 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2703 resultTypes.push_back(resultType);
2706 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2707 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2708 SmallVector<int64_t> newShape;
2711 bool newOperandNeeded =
false;
2712 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2713 int64_t dimShape = sourceShape[i];
2714 AffineExpr dimExpr = sourceMap.
getResult(i);
2715 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2716 newShape.push_back(dimShape);
2722 newShape.push_back(affineExprToSize[dimExpr]);
2723 newOperandNeeded =
true;
2725 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2726 sourceType.getEncoding());
2727 if (newOperandNeeded) {
2728 changeNeeded =
true;
2731 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2733 newOperands[index] = newOperand;
2735 if (linalgOp.isDpsInit(opOperand))
2736 resultTypes.push_back(resultType);
2742struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2743 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2745 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2746 PatternRewriter &rewriter)
const override {
2747 if (!linalgOp.hasPureTensorSemantics())
2751 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2752 return !map.isProjectedPermutation();
2757 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2758 Location loc = linalgOp.getLoc();
2762 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2764 SmallVector<Value> newOperands;
2765 SmallVector<Type> resultTypes;
2769 bool changeNeeded =
false;
2770 newOperands.reserve(linalgOp->getNumOperands());
2771 resultTypes.reserve(linalgOp.getNumDpsInits());
2774 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2775 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2776 affineExprToSize, linalgOp, newOperands,
2777 resultTypes, changeNeeded);
2786 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2787 SmallVector<Value> replacements;
2789 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2790 Value newResult = std::get<1>(it);
2791 Value oldResult = std::get<0>(it);
2792 Type newType = newResult.
getType();
2793 Type oldType = oldResult.
getType();
2794 replacements.push_back(
2795 (newType != oldType)
2796 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2799 rewriter.
replaceOp(linalgOp, replacements);
2813LogicalResult SoftmaxOp::verify() {
2814 ShapedType inputType = getInputOperandType();
2815 ShapedType outputType = getOutputOperandType();
2817 ArrayRef<int64_t> inputShape = inputType.getShape();
2818 ArrayRef<int64_t> outputShape = outputType.getShape();
2822 int64_t inputRank = getInputOperandRank();
2823 int64_t dimension = getDimension();
2824 if ((dimension < 0) || (dimension >= inputRank))
2825 return emitOpError(
"incorrect dimension specified");
2830SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2831 int64_t operandRank = getInputOperandRank();
2832 SmallVector<Range> loopBounds(operandRank);
2833 Location loc = getLoc();
2836 Value source = getInput();
2837 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2838 loopBounds[dim].offset = zero;
2839 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2840 loopBounds[dim].stride = one;
2845SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2846 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2847 utils::IteratorType::parallel);
2848 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2849 return iteratorTypes;
2852FailureOr<TilingResult>
2853SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2854 ArrayRef<OpFoldResult> offsets,
2855 ArrayRef<OpFoldResult> sizes) {
2856 int64_t rank = getInputOperandRank();
2858 SmallVector<OpFoldResult> strides(rank, oneAttr);
2859 SmallVector<Value> tiledOperands;
2860 Operation *inputSlice =
2861 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2863 return emitOpError(
"failed to compute input slice");
2865 tiledOperands.emplace_back(inputSlice->
getResult(0));
2866 Operation *outputSlice =
2867 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2869 return emitOpError(
"failed to compute output slice");
2871 tiledOperands.emplace_back(outputSlice->
getResult(0));
2873 SmallVector<Type, 4> resultTypes;
2874 if (hasPureTensorSemantics())
2875 resultTypes.push_back(tiledOperands[1].
getType());
2876 Operation *tiledOp =
2877 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2879 return TilingResult{
2882 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2885LogicalResult SoftmaxOp::getResultTilePosition(
2886 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2887 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2888 SmallVector<OpFoldResult> &resultSizes) {
2889 if (resultNumber == 0) {
2890 resultOffsets.assign(offsets.begin(), offsets.end());
2891 resultSizes.assign(sizes.begin(), sizes.end());
2898LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2903SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2905 SmallVector<OpFoldResult> shapes;
2906 Location loc = getOperation()->getLoc();
2907 IRRewriter rewriter(
b);
2908 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2909 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2910 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2911 if (!outputShapedType.isDynamicDim(dim)) {
2913 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2920 reifiedReturnShapes.emplace_back(std::move(shapes));
2924void SoftmaxOp::getEffects(
2925 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2927 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2928 if (!llvm::isa<MemRefType>(operand.
getType()))
2931 &getOperation()->getOpOperand(index), 0,
2936 for (OpOperand &operand : getDpsInitsMutable()) {
2937 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2968static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2970 int64_t dim,
bool allParallel =
false) {
2972 utils::IteratorType::parallel);
2974 iteratorTypes[dim] = utils::IteratorType::reduction;
2978 for (
int i = 0; i < inputRank; i++) {
2985 return std::make_tuple(iteratorTypes, indexingMaps);
2990template <
typename T>
2993 auto inputType = cast<ShapedType>(input.
getType());
2995 int64_t inputRank = inputShape.size();
2996 auto [iteratorTypes, indexingMaps] =
2998 assert(indexingMaps.size() == 2 &&
2999 "We should have two maps: 1 for the input, 1 for the output");
3000 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3002 auto genericOp = linalg::GenericOp::create(
3003 builder, loc, output.
getType(), input, output, indexingMaps,
3005 Value result = T::create(b, loc, args[0], args[1]);
3006 linalg::YieldOp::create(b, loc, result);
3008 return genericOp.getResult(0);
3016 auto inputType = cast<ShapedType>(input.
getType());
3018 int64_t inputRank = inputShape.size();
3020 builder, inputRank, dim,
true);
3021 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3022 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3024 indexingMaps.push_back(indexingMaps[0]);
3025 auto genericOp = linalg::GenericOp::create(
3027 indexingMaps, iteratorTypes,
3029 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3030 Value result = math::ExpOp::create(b, loc, diff);
3031 linalg::YieldOp::create(b, loc, result);
3033 return genericOp.getResult(0);
3043 auto inputType = cast<ShapedType>(numerator.
getType());
3045 int64_t inputRank = inputShape.size();
3047 builder, inputRank, dim,
true);
3048 assert(indexingMaps.size() == 2 &&
3049 "We should have one map for each input (2)");
3050 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3052 indexingMaps.push_back(indexingMaps[0]);
3053 auto genericOp = linalg::GenericOp::create(
3055 output, indexingMaps, iteratorTypes,
3057 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3058 linalg::YieldOp::create(b, loc, result);
3060 return genericOp.getResult(0);
3082FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3083 OpBuilder::InsertionGuard guard(
b);
3084 b.setInsertionPoint(*
this);
3085 Location loc = getLoc();
3086 Value input = getInput();
3087 ShapedType inputType = getInputOperandType();
3088 Type elementType = inputType.getElementType();
3089 int64_t reductionDim = getDimension();
3091 Value output = getOutput();
3092 dims.erase(dims.begin() + reductionDim);
3094 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3096 elementType,
b, loc,
3098 Value neutralForMaxFInit =
3099 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3111 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3117 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3118 return SmallVector<Value>{
result};
3125LogicalResult WinogradFilterTransformOp::verify() {
3126 auto filterType = cast<ShapedType>(getFilter().
getType());
3127 ArrayRef<int64_t> filterShape = filterType.getShape();
3128 int64_t filterH = filterShape[getFilterHDim()];
3129 int64_t filterW = filterShape[getFilterWDim()];
3130 WinogradConv2DFmr fmr = getFmr();
3134 if (filterH != r && filterH != 1)
3135 return emitOpError(
"expect filter height either equals to r or 1");
3136 if (filterW != r && filterW != 1)
3137 return emitOpError(
"expect filter width either equals to r or 1");
3138 if (filterH == 1 && filterW == 1)
3139 return emitOpError(
"expect either filter height or width equals to r");
3141 SmallVector<int64_t> expectedOutputShape;
3142 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3143 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3144 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3145 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3147 auto outputType = cast<ShapedType>(getOutput().
getType());
3148 ArrayRef<int64_t> outputShape = outputType.getShape();
3150 return emitOpError(
"the output shape is not expected");
3156WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3157 Location loc = getLoc();
3160 Value filter = getFilter();
3161 int64_t filterRank = getFilterOperandRank();
3162 SmallVector<Range> loopBounds(filterRank);
3163 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3164 loopBounds[dim].offset = zeroAttr;
3165 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3166 loopBounds[dim].stride = oneAttr;
3171SmallVector<utils::IteratorType>
3172WinogradFilterTransformOp::getLoopIteratorTypes() {
3173 int64_t filterRank = getFilterOperandRank();
3174 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3175 utils::IteratorType::parallel);
3176 return iteratorTypes;
3179LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3180 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3181 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3182 SmallVector<OpFoldResult> &resultSizes) {
3184 ShapedType filterType = getFilterOperandType();
3185 ArrayRef<int64_t> filterShape = filterType.getShape();
3186 int64_t filterH = filterShape[getFilterHDim()];
3187 int64_t filterW = filterShape[getFilterWDim()];
3188 WinogradConv2DFmr fmr = getFmr();
3191 int64_t alpha = m + r - 1;
3192 int64_t alphaH = filterH != 1 ? alpha : 1;
3193 int64_t alphaW = filterW != 1 ? alpha : 1;
3197 resultOffsets.append(
3198 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3200 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3211FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3212 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3213 ArrayRef<OpFoldResult> sizes) {
3216 ShapedType filterType = getFilterOperandType();
3217 ArrayRef<int64_t> filterShape = filterType.getShape();
3218 int64_t filterH = filterShape[getFilterHDim()];
3219 int64_t filterW = filterShape[getFilterWDim()];
3222 SmallVector<Value> tiledOperands;
3223 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3225 sliceOffsets.append(
3226 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3227 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3228 sizes[getFilterCDim()]});
3229 int64_t filterRank = getFilterOperandRank();
3230 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3231 Location loc = getLoc();
3232 auto filterSlice = tensor::ExtractSliceOp::create(
3233 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3234 tiledOperands.emplace_back(filterSlice);
3236 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3241 int64_t outputRank = getOutputOperandRank();
3242 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3243 auto outputSlice = tensor::ExtractSliceOp::create(
3244 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3245 tiledOperands.emplace_back(outputSlice);
3247 SmallVector<Type> resultTypes;
3248 resultTypes.push_back(tiledOperands[1].
getType());
3249 Operation *tiledOp =
3250 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3252 return TilingResult{
3255 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3262LogicalResult WinogradInputTransformOp::verify() {
3263 auto inputType = cast<ShapedType>(getInput().
getType());
3264 ArrayRef<int64_t> inputShape = inputType.getShape();
3265 int64_t inputH = inputShape[getInputHDim()];
3266 int64_t inputW = inputShape[getInputWDim()];
3267 WinogradConv2DFmr fmr = getFmr();
3270 int64_t tileSize = m + r - 1;
3272 auto outputType = cast<ShapedType>(getOutput().
getType());
3273 ArrayRef<int64_t> outputShape = outputType.getShape();
3274 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3275 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3277 SmallVector<int64_t> expectedOutputShape(6, inputH);
3278 if (ShapedType::isDynamic(inputH)) {
3279 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3280 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3282 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3283 expectedOutputShape[getOutputTileHDim()] =
3284 leftTransform ? (inputH - (r - 1)) / m : inputH;
3286 if (ShapedType::isDynamic(inputW)) {
3287 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3288 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3290 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3291 expectedOutputShape[getOutputTileWDim()] =
3292 rightTransform ? (inputW - (r - 1)) / m : inputW;
3294 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3295 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3298 return emitOpError(
"the output shape is not expected");
3304WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3305 Location loc = getLoc();
3308 Value output = getOutput();
3309 int64_t outputRank = getOutputOperandRank();
3310 SmallVector<Range> loopBounds(outputRank);
3311 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3312 loopBounds[dim].offset = zeroAttr;
3314 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3315 loopBounds[dim].stride = oneAttr;
3320SmallVector<utils::IteratorType>
3321WinogradInputTransformOp::getLoopIteratorTypes() {
3322 int64_t outputRank = getOutputOperandRank();
3323 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3324 utils::IteratorType::parallel);
3325 return iteratorTypes;
3328LogicalResult WinogradInputTransformOp::getResultTilePosition(
3329 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3330 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3331 SmallVector<OpFoldResult> &resultSizes) {
3333 ShapedType outputType = getOutputOperandType();
3334 ArrayRef<int64_t> outputShape = outputType.getShape();
3335 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3336 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3338 WinogradConv2DFmr fmr = getFmr();
3341 int64_t alpha = m + r - 1;
3342 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3343 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3348 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3349 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3350 offsets[getOutputCDim()]});
3351 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3352 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3353 sizes[getOutputCDim()]});
3364FailureOr<TilingResult>
3365WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3366 ArrayRef<OpFoldResult> offsets,
3367 ArrayRef<OpFoldResult> sizes) {
3369 WinogradConv2DFmr fmr = getFmr();
3373 ShapedType outputType = getOutputOperandType();
3374 ArrayRef<int64_t> outputShape = outputType.getShape();
3375 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3376 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3378 Location loc = getLoc();
3380 auto identityAffineMap =
3382 auto offsetAffineMap =
3385 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3386 offsets[getOutputTileHDim()]);
3388 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3389 offsets[getOutputTileWDim()]);
3393 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3395 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3397 SmallVector<Value> tiledOperands;
3398 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3400 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3401 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3402 sliceOffsets.append(
3403 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3404 OpFoldResult sizeH =
3405 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3406 OpFoldResult sizeW =
3407 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3409 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3410 int64_t inputRank = getInputOperandRank();
3411 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3412 auto inputSlice = tensor::ExtractSliceOp::create(
3413 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3414 tiledOperands.emplace_back(inputSlice);
3416 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3421 int64_t outputRank = getOutputOperandRank();
3422 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3423 auto outputSlice = tensor::ExtractSliceOp::create(
3424 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3425 tiledOperands.emplace_back(outputSlice);
3427 SmallVector<Type> resultTypes;
3428 resultTypes.push_back(tiledOperands[1].
getType());
3429 Operation *tiledOp =
3430 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3432 return TilingResult{
3435 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3442LogicalResult WinogradOutputTransformOp::verify() {
3443 auto valueType = cast<ShapedType>(getValue().
getType());
3444 ArrayRef<int64_t> valueShape = valueType.getShape();
3445 int64_t valueH = valueShape[getValueAlphaHDim()];
3446 int64_t valueW = valueShape[getValueAlphaWDim()];
3447 int64_t valueTileH = valueShape[getValueTileHDim()];
3448 int64_t valueTileW = valueShape[getValueTileWDim()];
3449 WinogradConv2DFmr fmr = getFmr();
3452 bool leftTransform = valueH != 1;
3453 bool rightTransform = valueW != 1;
3455 int64_t outputRank = getOutputOperandRank();
3456 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3457 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3458 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3460 if (valueH != (leftTransform ? m + r - 1 : 1))
3461 return emitOpError(
"expect input height equals to input tile size");
3462 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3464 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3465 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3467 if (valueW != (rightTransform ? m + r - 1 : 1))
3468 return emitOpError(
"expect input width equals to input tile size");
3469 expectedOutputShape[getOutputWDim()] =
3470 (rightTransform ? m : 1) * valueTileW;
3472 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3473 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3475 auto outputType = cast<ShapedType>(getOutput().
getType());
3476 ArrayRef<int64_t> outputShape = outputType.getShape();
3478 return emitOpError(
"the output shape is not expected");
3484WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3485 Location loc = getLoc();
3488 Value value = getValue();
3489 int64_t valueRank = getValueOperandRank();
3490 SmallVector<Range> loopBounds(valueRank);
3491 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3492 loopBounds[dim].offset = zeroAttr;
3494 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3495 loopBounds[dim].stride = oneAttr;
3500SmallVector<utils::IteratorType>
3501WinogradOutputTransformOp::getLoopIteratorTypes() {
3502 int64_t valueRank = getValueOperandRank();
3503 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3504 utils::IteratorType::parallel);
3505 return iteratorTypes;
3508LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3509 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3510 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3511 SmallVector<OpFoldResult> &resultSizes) {
3512 WinogradConv2DFmr fmr = getFmr();
3516 Location loc = getLoc();
3518 auto identityAffineMap =
3523 ShapedType valueType = getValueOperandType();
3524 ArrayRef<int64_t> valueShape = valueType.getShape();
3525 int64_t valueH = valueShape[0];
3526 int64_t valueW = valueShape[1];
3528 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3529 offsets[getValueTileHDim()]);
3531 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3532 offsets[getValueTileWDim()]);
3534 builder, loc, affineMap, sizes[getValueTileHDim()]);
3536 builder, loc, affineMap, sizes[getValueTileWDim()]);
3539 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3540 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3541 OpFoldResult sizeH =
3542 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3543 OpFoldResult sizeW =
3544 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3546 resultOffsets.append(
3547 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3549 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3559FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3560 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3561 ArrayRef<OpFoldResult> sizes) {
3564 Location loc = getLoc();
3565 SmallVector<Value> tiledOperands;
3566 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3568 ShapedType valueType = getValueOperandType();
3569 ArrayRef<int64_t> valueShape = valueType.getShape();
3570 int64_t alphaH = valueShape[getValueAlphaHDim()];
3571 int64_t alphaW = valueShape[getValueAlphaWDim()];
3575 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3576 offsets[getValueTileWDim()], offsets[getValueNDim()],
3577 offsets[getValueFDim()]});
3578 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3579 sizes[getValueTileWDim()], sizes[getValueNDim()],
3580 sizes[getValueFDim()]});
3581 int64_t valueRank = getValueOperandRank();
3582 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3583 auto valueSlice = tensor::ExtractSliceOp::create(
3584 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3585 tiledOperands.emplace_back(valueSlice);
3587 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3592 int64_t outputRank = getOutputOperandRank();
3593 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3594 auto outputSlice = tensor::ExtractSliceOp::create(
3595 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3596 tiledOperands.emplace_back(outputSlice);
3598 SmallVector<Type> resultTypes;
3599 resultTypes.push_back(tiledOperands[1].
getType());
3600 Operation *tiledOp =
3601 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3603 return TilingResult{
3606 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3620 llvm::set_union(explicitSet, defaultSet);
3621 return explicitSet == defaultSet;
3641 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3643 auto opIndexingMap = opIndexingMaps[opIndex];
3644 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3647 return matmulOp->emitOpError()
3648 <<
"Unexpected dim expression in map result.";
3651 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3652 return matmulOp->emitOpError()
3653 <<
"Invalid broadcast requested, should be (d2).";
3662template <
typename OpTy>
3665 AffineMap defaultIndexingMap,
bool isLHS) {
3666 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3667 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3668 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3671 return batchVariantMatmulOp->emitOpError()
3672 <<
"Unexpected result dim expression (outside the set of default "
3677 return batchVariantMatmulOp->emitOpError()
3678 <<
"no. of result dim expressions exceeds 3.";
3680 auto hasValidBatchDim = [](
AffineMap map) {
3687 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3688 return batchVariantMatmulOp->emitOpError()
3689 <<
"Invalid broadcast requested.";
3690 }
else if (!hasValidBatchDim(opIndexingMap)) {
3691 return batchVariantMatmulOp->emitOpError()
3692 <<
"Invalid batch dimension expression.";
3700template <
typename OpTy>
3703 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3704 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3705 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3706 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3709 return batchVariantMatmulOp->emitOpError()
3710 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3713 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3715 return batchVariantMatmulOp->emitOpError()
3716 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3720 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3721 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3722 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3723 outputMap.getResult(1).isFunctionOfDim(1) &&
3724 outputMap.getResult(2).isFunctionOfDim(2)
3725 : outputMap.getResult(0).isFunctionOfDim(1) &&
3726 outputMap.getResult(1).isFunctionOfDim(2);
3729 if (!areValidOutputResultDim(opIndexingMap)) {
3730 return batchVariantMatmulOp->emitOpError()
3731 <<
"Invalid output map result dimension.";
3740template <
typename OpTy>
3745 batchVariantMatmulOp.getIndexingMapsArray();
3747 batchVariantMatmulOp.getDefaultIndexingMaps(
3748 batchVariantMatmulOp->getContext());
3750 if (opIndexingMaps.size() != 3)
3751 return batchVariantMatmulOp->emitOpError()
3752 <<
"Indexing_map attribute must have 3 affine maps.";
3754 auto opIndexingMap = opIndexingMaps[opIndex];
3755 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3763 defaultIndexingMap, opIndex == 0)))
3773 if (m == 2 && r == 3)
3774 return WinogradConv2DFmr::F_2_3;
3775 if (m == 4 && r == 3)
3776 return WinogradConv2DFmr::F_4_3;
3777 if (m == 2 && r == 5)
3778 return WinogradConv2DFmr::F_2_5;
3779 return std::nullopt;
3784 case WinogradConv2DFmr::F_2_3:
3786 case WinogradConv2DFmr::F_4_3:
3788 case WinogradConv2DFmr::F_2_5:
3791 llvm_unreachable(
"Unkown WinogradConv2DFmr");
3798static FailureOr<SmallVector<SmallVector<int64_t>>>
3801 for (
auto map : maps) {
3802 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3806 for (
auto result : attr.getAffineMap().getResults()) {
3807 auto dim = dyn_cast<AffineDimExpr>(
result);
3810 pos.push_back(dim.getPosition());
3812 positions.push_back(pos);
3825 return indexingMaps;
3828bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3829 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3832 if (maps.size() != 3)
3837 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3838 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3839 (*positions)[2] == SmallVector<int64_t>{0, 1};
3842SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3843 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3844 utils::IteratorType::parallel,
3845 utils::IteratorType::reduction};
3848unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3850std::string MatmulOp::getLibraryCallName() {
3854bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3858bool MatmulOp::hasUserDefinedMaps() {
3859 SmallVector<AffineMap, 3> defaultMaps =
3861 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3862 return defaultMaps != explicitMaps;
3867void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3868 ArrayRef<NamedAttribute> attrs,
3871 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3876 "MatmulOp regionBuilder expects 3 args");
3877 RegionBuilderHelper helper(
b, block);
3878 SmallVector<Value> yields;
3880 TypeFn castVal = TypeFn::cast_signed;
3881 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3882 return attr.
getName() ==
"cast";
3884 if (castIter != attrs.end()) {
3885 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3893 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
3894 if (!value1 || !value2 || !value3)
3896 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3900 yields.push_back(value4);
3901 helper.yieldOutputs(yields);
3911bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3912 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3913 AffineExpr expr = bcastMap.
getResult(0);
3923 ArrayAttr arrayAttr;
3927 if (llvm::any_of(arrayAttr,
3928 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3930 <<
"element of indexing_maps array is not an affine_map";
3937 if (failed(indexingMapsAttr))
3940 if (*indexingMapsAttr ==
nullptr) {
3941 auto indexingMapAttrs = llvm::map_to_vector(
3942 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3947 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
3949 MatmulOp::getRegionBuilder());
3952void MatmulOp::print(OpAsmPrinter &p) {
3953 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3954 MatmulOp::getDefaultIndexingMaps(
getContext()),
3955 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
3956 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3957 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3959 std::array<StringRef, 3> elidedAttrs = {
3960 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3966LogicalResult MatmulOp::verify() {
3968 if (!hasUserDefinedMaps())
3971 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3978LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3982void MatmulOp::getEffects(
3983 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3985 if (hasPureTensorSemantics())
3994SmallVector<AffineMap>
3995MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3996 AffineExpr d0, d1, d2;
4002 return {mapLHS, mapRHS, mapOut};
4006 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4009 if (maps.size() != 3)
4012 if (failed(positions))
4024 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4032 build(builder, state, inputs, outputs, attributes);
4033 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4034 assert(res &&
"builder didn't return the right type");
4044 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4053 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4054 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4055 assert(res &&
"builder didn't return the right type");
4065 result.addAttribute(
"cast", cast);
4067 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4076 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4077 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4078 assert(res &&
"builder didn't return the right type");
4083 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4085 op->
getAttr(
"indexing_maps"));
4089MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4096 return {mapLHS, mapRHS, mapOut};
4100 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4103 if (maps.size() != 3)
4106 if (failed(positions))
4118 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4126 build(builder, state, inputs, outputs, attributes);
4127 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4128 assert(res &&
"builder didn't return the right type");
4138 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4147 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4148 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4149 assert(res &&
"builder didn't return the right type");
4159 result.addAttribute(
"cast", cast);
4161 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4170 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4171 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4172 assert(res &&
"builder didn't return the right type");
4177 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4179 op->
getAttr(
"indexing_maps"));
4183BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4190 return {mapLHS, mapRHS, mapOut};
4194 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4197 if (maps.size() != 3)
4200 if (failed(positions))
4211 BatchMatmulOp::getRegionBuilder(),
4212 getDefaultIndexingMaps(builder));
4220 build(builder, state, inputs, outputs, attributes);
4221 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4222 assert(res &&
"builder didn't return the right type");
4231 BatchMatmulOp::getRegionBuilder(),
4232 getDefaultIndexingMaps(builder));
4241 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4242 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4243 assert(res &&
"builder didn't return the right type");
4251 result.addAttribute(
"cast", cast);
4253 BatchMatmulOp::getRegionBuilder(),
4254 getDefaultIndexingMaps(builder));
4263 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4264 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4265 assert(res &&
"builder didn't return the right type");
4270 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4272 op->
getAttr(
"indexing_maps"));
4276BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4283 return {mapLHS, mapRHS, mapOut};
4287 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4290 if (maps.size() != 3)
4293 if (failed(positions))
4304 BatchMatmulOp::getRegionBuilder(),
4305 getDefaultIndexingMaps(builder));
4313 build(builder, state, inputs, outputs, attributes);
4314 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4315 assert(res &&
"builder didn't return the right type");
4324 BatchMatmulOp::getRegionBuilder(),
4325 getDefaultIndexingMaps(builder));
4334 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4335 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4336 assert(res &&
"builder didn't return the right type");
4344 result.addAttribute(
"cast", cast);
4346 BatchMatmulOp::getRegionBuilder(),
4347 getDefaultIndexingMaps(builder));
4356 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4357 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4358 assert(res &&
"builder didn't return the right type");
4363 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4365 op->
getAttr(
"indexing_maps"));
4373 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4384 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4385 assert(dimExpr &&
"affine_map is a projected permutation");
4386 dimsInOutput[dimExpr.getPosition()] =
true;
4390 for (
auto dimOccursInOutput : dimsInOutput)
4391 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4392 : utils::IteratorType::reduction);
4394 return iteratorTypes;
4397unsigned ContractOp::getNumRegionArgs() {
return 3; }
4400void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4401 ArrayRef<NamedAttribute> attrs,
4404 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4409 "ContractOp regionBuilder expects 3 args");
4410 RegionBuilderHelper helper(
b, block);
4412 TypeFn castSignedness = TypeFn::cast_signed;
4413 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4414 return attr.
getName() ==
"cast";
4416 if (castIter != attrs.end()) {
4417 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4423 Value lhsAtOutType =
4424 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4425 Value rhsAtOutType =
4426 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4427 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4429 if (!productAtOutType)
4435 helper.yieldOutputs({
result});
4438ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4440 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4442 "expected 'indexing_maps' attribute");
4443 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4449void ContractOp::print(OpAsmPrinter &p) {
4450 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4452 p, getOperation(), getInputs(), getOutputs(),
4453 {
"indexing_maps",
"operandSegmentSizes"});
4456LogicalResult ContractOp::verify() {
4457 int iterationSpaceDims = -1;
4462 SmallVector<size_t> inOccurrences;
4463 SmallVector<size_t> outOccurrences;
4466 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4467 bool isInput) -> LogicalResult {
4470 return emitError(
"provided affine_map is not a projected permutation");
4473 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4475 return emitError(
"ranks of shaped operand and results of corresponding "
4476 "affine_map differ");
4478 return emitError(
"affine_map specifies shaped access while operand has "
4483 if (iterationSpaceDims == -1) {
4485 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4486 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4487 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4488 return emitError(
"iteration spaces of provided affine_maps differ");
4492 for (AffineExpr affineExpr : affineMap.
getResults()) {
4493 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4495 llvm_unreachable(
"affine_map is a projected permutation");
4498 inOccurrences[affineDimExpr.getPosition()] += 1;
4500 outOccurrences[affineDimExpr.getPosition()] += 1;
4506 for (
auto &&[affineMap, operandType, isInput] :
4507 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4508 SmallVector<bool>{
true,
true,
false})) {
4509 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4513 bool hasContractingDim =
false;
4514 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4515 size_t inOccCount = inOccurrences[dimIndex];
4516 size_t outOccCount = outOccurrences[dimIndex];
4519 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4521 if (inOccCount == 0 && outOccCount == 0)
4522 return emitError() <<
"iteration space dim at index " << dimIndex
4523 <<
" not used to access any operand";
4534 if (inOccCount == 1 && outOccCount != 1)
4536 <<
"iteration space dim at index " << dimIndex
4537 <<
" is neither a contracting dim nor of parallel iteration type";
4540 if (!hasContractingDim)
4541 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4546LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4550void ContractOp::getEffects(
4551 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4553 if (hasPureTensorSemantics())
4565SmallVector<AffineMap>
4566BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4567 AffineExpr d0, d1, d2, d3;
4568 SmallVector<AffineMap> indexingMaps;
4570 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4571 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4572 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4573 return indexingMaps;
4576bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4577 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4580 if (maps.size() != 3)
4585 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4586 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4587 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4590SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4591 return SmallVector<utils::IteratorType>{
4592 utils::IteratorType::parallel, utils::IteratorType::parallel,
4593 utils::IteratorType::parallel, utils::IteratorType::reduction};
4596unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4598std::string BatchMatmulOp::getLibraryCallName() {
4604bool BatchMatmulOp::hasUserDefinedMaps() {
4605 SmallVector<AffineMap, 3> defaultMaps =
4607 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4608 return defaultMaps != explicitMaps;
4618bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4620 "Expected less than 3 result dim expr.");
4621 bool isValid =
false;
4622 enum Indices { batchPos, mPos, nPos, kPos };
4624 AffineExpr expr = bcastMap.
getResult(0);
4627 AffineExpr expr0 = bcastMap.
getResult(0);
4628 AffineExpr expr1 = bcastMap.
getResult(1);
4633 : ((expr0.isFunctionOfDim(batchPos) &&
4634 expr1.isFunctionOfDim(kPos)) ||
4635 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4640void BatchMatmulOp::regionBuilder(
4641 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4644 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4649 "BatchMatmulOp regionBuilder expects 3 args");
4650 RegionBuilderHelper helper(
b, block);
4651 SmallVector<Value> yields;
4653 TypeFn castVal = TypeFn::cast_signed;
4654 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4655 return attr.
getName() ==
"cast";
4657 if (castIter != attrs.end()) {
4658 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4663 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4664 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4666 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
4667 if (!castValA || !castValB || !mulVal)
4669 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
4673 yields.push_back(addVal);
4674 helper.yieldOutputs(yields);
4677ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4678 SmallVector<Attribute, 3> indexingMapsAttr;
4690 if (!isa<AffineMapAttr>(mapAttr)) {
4692 "expected affine map attribute");
4694 indexingMapsAttr.push_back(mapAttr);
4704 if (indexingMapsAttr.empty()) {
4705 indexingMapsAttr = llvm::map_to_vector(
4706 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4707 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4709 result.addAttribute(
"indexing_maps",
4712 return ::parseNamedStructuredOp(parser,
result,
4713 BatchMatmulOp::getNumRegionArgs(),
4714 BatchMatmulOp::getRegionBuilder());
4717void BatchMatmulOp::print(OpAsmPrinter &p) {
4718 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4719 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4720 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4721 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4722 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4724 std::array<StringRef, 3> elidedAttrs = {
4725 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4731LogicalResult BatchMatmulOp::verify() {
4734 if (!hasUserDefinedMaps())
4737 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4744LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4745 SmallVectorImpl<OpFoldResult> &) {
4749void BatchMatmulOp::getEffects(
4750 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4752 if (hasPureTensorSemantics())
4766struct ArityGroupAndKind {
4768 ElementwiseArityGroup arityGroup;
4774 TernaryFn ternaryFn;
4778unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4779 return static_cast<unsigned>(arityGroup);
4784 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4785 constexpr int lastBinary =
4786 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4787 constexpr int lastTernary =
4788 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4790 int val =
static_cast<int>(kind);
4791 ArityGroupAndKind
result;
4793 if (val < lastUnary) {
4794 result.arityGroup = ElementwiseArityGroup::Unary;
4795 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4798 if (val < lastBinary) {
4799 result.arityGroup = ElementwiseArityGroup::Binary;
4800 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4803 if (val >= lastTernary) {
4804 llvm_unreachable(
"unhandled ElementwiseFn");
4806 result.arityGroup = ElementwiseArityGroup::Ternary;
4807 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4812 auto rank = getResultRank();
4817ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4823ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4826 mlir::linalg::ElementwiseKind elemwiseKindVal;
4831 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4832 if (!elemwiseKindAttr)
4834 "expected ElementwiseKind attribute");
4835 elemwiseKindVal = elemwiseKindAttr.getValue();
4838 "expected operation 'kind' attribute");
4841 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4844 SmallVector<Attribute, 3> indexingMapsAttr;
4854 if (!isa<AffineMapAttr>(mapAttr))
4856 "expected affine map attribute");
4857 indexingMapsAttr.push_back(mapAttr);
4868 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4870 ElementwiseOp::getRegionBuilder())) {
4872 "unable to parse elemwise op");
4876 if (indexingMapsAttr.empty()) {
4879 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4880 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4883 "return type needs to be shaped type");
4884 auto numDims = shapedType.getRank();
4885 indexingMapsAttr = llvm::map_to_vector(
4886 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4888 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4891 result.addAttribute(
"indexing_maps",
4896void ElementwiseOp::print(OpAsmPrinter &p) {
4899 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4903 unsigned numDims = getResultRank();
4905 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4906 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4908 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4910 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4911 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4919void ElementwiseOp::regionBuilder(
4920 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4922 ElementwiseKind elemwiseKind;
4923 for (
auto attr : attrs) {
4924 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4925 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4926 assert(kindAttr &&
"op kind attribute incorrectly set");
4927 elemwiseKind = kindAttr.getValue();
4933 auto arityGroup = groupAndKind.arityGroup;
4934 auto kind = groupAndKind.kind;
4936 getArityGroupAsUInt(arityGroup) + 1 ) {
4937 emitError() <<
"Elementwise regionBuilder expects "
4938 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
4943 getArityGroupAsUInt(arityGroup) + 1
4944 &&
"Elementwise regionBuilder number of block args mismatch");
4946 RegionBuilderHelper helper(
b, block);
4947 SmallVector<Value> yields;
4950 if (arityGroup == ElementwiseArityGroup::Unary) {
4953 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
4957 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
4962 assert(
false &&
"found unhandled category in elemwise");
4965 yields.push_back(
result);
4966 helper.yieldOutputs(yields);
4969LogicalResult ElementwiseOp::fold(FoldAdaptor,
4970 SmallVectorImpl<OpFoldResult> &) {
4974void ElementwiseOp::getEffects(
4975 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4977 if (hasPureTensorSemantics())
4990template <
typename OpTy,
typename>
4993 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4994 ? packOrUnPack.getDestType()
4995 : packOrUnPack.getSourceType();
4996 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4997 ? packOrUnPack.getSourceType()
4998 : packOrUnPack.getDestType();
5000 packedType.getShape().take_front(unpackedType.getRank()));
5001 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5023 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5025 .take_back(mixedTiles.size()),
5027 int64_t dimSize = std::get<0>(it);
5028 if (dimSize == ShapedType::kDynamic) {
5029 newMixedTileSizes.push_back(std::get<1>(it));
5036 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
5038 newMixedTileSizes.push_back(
tile);
5041 "tile size and dim size don't match!");
5042 newMixedTileSizes.push_back(
5047 return newMixedTileSizes;
5050template <
typename OpTy>
5054 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5055 "applies to only pack or unpack operations");
5056 int64_t destRank = op.getDestRank();
5058 for (
auto dim : llvm::seq<int64_t>(0, destRank))
5059 reifiedReturnShapes[0][dim] =
5064template <
typename OpTy>
5066 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5067 "applies to only pack or unpack operations");
5071 assert(tiles.size() == dimsToTile.size() &&
5072 "tiles must match indices of dimension to block");
5074 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5075 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5076 return dimAndTileMapping;
5079template <
typename OpTy>
5081 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5082 "applies to only pack or unpack operations");
5085 unsigned dynamicValIndex = 0;
5086 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5087 if (ShapedType::isStatic(staticTile))
5090 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5092 return mixedInnerTiles;
5095template <
typename OpTy>
5097 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5098 "applies to only pack or unpack operations");
5111 size_t dimsPosSize = dimsPos.size();
5112 if (dimsPosSize > rank)
5115 if (dimsPosSize != uniqued.size())
5117 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5118 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5122template <
typename OpTy>
5124 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5125 "applies to only pack or unpack operations");
5126 Operation *op = packOrUnPack.getOperation();
5136 if (!packOrUnPack.getSourceType().hasRank() ||
5137 !packOrUnPack.getDestType().hasRank())
5138 return op->
emitError(
"expected both source and destination to have rank");
5141 if (!packOrUnPack.hasPureBufferSemantics() &&
5142 !packOrUnPack.hasPureTensorSemantics())
5143 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
5144 const unsigned numResults = packOrUnPack.getNumResults();
5145 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5146 return op->
emitError(
"expected 1 result, got ") << numResults;
5147 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5148 return op->
emitError(
"expected 0 results, got ") << numResults;
5152 if (hasZeros(mixedTiles))
5153 return op->
emitError(
"invalid zero tile factor");
5156 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5157 ? packOrUnPack.getSourceType()
5158 : packOrUnPack.getDestType();
5159 size_t unpackedRank = unpackedType.getRank();
5163 return op->
emitError(
"invalid inner_dims_pos vector");
5165 return op->
emitError(
"invalid outer_dims_perm vector");
5166 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5167 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5171 if (mixedTiles.size() > unpackedRank) {
5172 return op->
emitError(
"tiling factors must be less than or equal to the "
5173 "input rank for pack or output rank for unpack");
5175 if (mixedTiles.size() != innerDimsPos.size()) {
5177 "tiling factors must equal the number of dimensions to tile");
5180 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5181 ? packOrUnPack.getDestType()
5182 : packOrUnPack.getSourceType();
5183 size_t packedRank = packedType.getRank();
5185 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5186 if (expectedPackedRank != packedRank) {
5188 "packed rank != (unpacked rank + num tiling factors), got ")
5189 << packedRank <<
" != " << expectedPackedRank;
5196 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5197 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5198 for (
auto it : llvm::enumerate(llvm::zip(
5199 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5200 int64_t dimSize = std::get<0>(it.value());
5202 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5203 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5204 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5205 if (dimSize != staticTileSize)
5207 "mismatch in inner tile sizes specified and shaped of "
5208 "tiled dimension in the packed type at index ")
5209 << it.index() <<
": got " << dimSize <<
" != " << staticTileSize;
5210 }
else if (!ShapedType::isDynamic(dimSize)) {
5211 return op->
emitError(
"mismatch in inner tile sizes specified at index ")
5212 << it.index() <<
": got static shape " << dimSize
5213 <<
" but dynamic tile size";
5218 auto elementType = unpackedType.getElementType();
5219 Type expectedType, actualType;
5220 if (packOrUnPack.hasPureTensorSemantics()) {
5221 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5222 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5224 expectedType = MemRefType::get(expectedPackedShape, elementType);
5225 actualType = MemRefType::get(packedType.getShape(), elementType);
5228 << expectedType <<
" for the packed domain value, got "
5241struct PackOrUnPackTransposeResult {
5248template <
typename OpTy>
5249static PackOrUnPackTransposeResult
5253 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5254 "applies to only pack or unpack operations");
5255 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5256 "some permutation must be non-empty");
5257 PackOrUnPackTransposeResult metadata;
5258 metadata.innerDimsPos =
5260 metadata.innerTiles =
5262 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5263 ? packOrUnPackOp.getSourceRank()
5264 : packOrUnPackOp.getDestRank();
5265 metadata.outerDimsPerm =
5266 packOrUnPackOp.getOuterDimsPerm().empty()
5267 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5269 if (!innerPermutation.empty()) {
5270 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5272 "invalid inner permutation");
5276 if (!outerPermutation.empty()) {
5277 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5279 "invalid outer permutation");
5290 if (!getResults().empty())
5291 setNameFn(getResult(),
"pack");
5301 Type sourceType, destType, resultType;
5318 SmallVector<int64_t> outerDimsPermVec;
5321 if (parser.parseInteger(value))
5323 outerDimsPermVec.push_back(value);
5333 SmallVector<int64_t> innerDimsPosVec;
5336 if (parser.parseInteger(value))
5338 innerDimsPosVec.push_back(value);
5350 for (
auto val : staticTilesAttr.
asArrayRef())
5351 staticTiles.push_back(val);
5368 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5371 "pack/unpack requires '->' and destination type");
5375 resultType = destType;
5381 if (!paddingValue.empty() &&
5386 if (!dynamicTiles.empty() &&
5391 result.addAttribute(
"static_inner_tiles",
5393 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5395 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5397 SmallVector<int32_t> segmentSizes = {
5398 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5399 static_cast<int32_t
>(dynamicTiles.size())};
5400 result.addAttribute(
"operandSegmentSizes",
5404 result.addTypes(resultType);
5409void PackOp::print(OpAsmPrinter &p) {
5410 p <<
" " << getSource();
5412 if (getPaddingValue()) {
5413 p <<
" padding_value(" << getPaddingValue() <<
" : "
5414 << getPaddingValue().getType() <<
")";
5417 if (!getOuterDimsPerm().empty()) {
5418 p <<
" outer_dims_perm = [";
5419 llvm::interleaveComma(getOuterDimsPerm(), p);
5423 p <<
" inner_dims_pos = [";
5424 llvm::interleaveComma(getInnerDimsPos(), p);
5427 p <<
" inner_tiles = ";
5430 p <<
" into " << getDest();
5433 {
"static_inner_tiles",
"inner_dims_pos",
5434 "outer_dims_perm",
"operandSegmentSizes"});
5436 p <<
" : " << getSource().getType();
5437 p <<
" -> " << getDest().getType();
5440void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5441 Value dest, ArrayRef<int64_t> innerDimsPos,
5442 ArrayRef<OpFoldResult> innerTiles,
5443 std::optional<Value> paddingValue,
5444 ArrayRef<int64_t> outerDimsPerm) {
5445 assert(innerDimsPos.size() == innerTiles.size() &&
5446 "number of tile sizes specified must match the specified number of "
5447 "original dimensions to be tiled");
5448 SmallVector<int64_t> staticTileSizes;
5449 SmallVector<Value> dynamicTileSizes;
5451 build(builder, state, dest.
getType(), source, dest,
5452 paddingValue ? *paddingValue :
nullptr,
5453 outerDimsPerm.empty() ?
nullptr
5460PackOp::reifyResultShapes(OpBuilder &builder,
5469SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5473SmallVector<int64_t> PackOp::getStaticTiles() {
5477ArrayRef<int64_t> PackOp::getAllOuterDims() {
5478 ShapedType inputType = getSourceType();
5479 int64_t inputRank = inputType.getRank();
5480 return getDestType().getShape().take_front(inputRank);
5483SmallVector<int64_t> PackOp::getTiledOuterDims() {
5484 auto innerDimsPos = getInnerDimsPos();
5485 SmallVector<int64_t> outerDims(getAllOuterDims());
5486 SmallVector<int64_t> res;
5489 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5491 if (!outerDimPermInv.empty())
5495 for (
auto index : innerDimsPos)
5496 res.push_back(outerDims[index]);
5501bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5502 ArrayRef<int64_t> innerDimsPos,
5503 ArrayRef<int64_t> outputShape,
5504 ArrayRef<int64_t> outerDimsPerm,
5505 ArrayRef<OpFoldResult> innerTiles) {
5506 SmallVector<int64_t> outputTileSizes(
5507 outputShape.take_front(inputShape.size()));
5508 if (!outerDimsPerm.empty()) {
5509 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5510 "expected output and outer_dims_perm to have same size");
5514 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5515 if (ShapedType::isDynamic(inputShape[pos]))
5518 if (!constantTile) {
5519 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5520 (inputShape[pos] % outputTileSizes[pos] != 0))
5523 assert(*constantTile != 0 &&
"static tile size can't be zero");
5524 if (inputShape[pos] % (*constantTile) != 0) {
5532bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5533 ArrayRef<int64_t> innerDimsPos,
5534 ArrayRef<int64_t> outputShape,
5535 ArrayRef<int64_t> outerDimsPerm,
5536 ArrayRef<OpFoldResult> innerTiles) {
5537 SmallVector<int64_t> outputTileSizes(
5538 outputShape.take_front(inputShape.size()));
5539 if (!outerDimsPerm.empty()) {
5540 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5541 "expected output and outer_dims_perm to have same size");
5545 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5546 if (ShapedType::isDynamic(inputShape[pos]) ||
5547 ShapedType::isDynamic(outputTileSizes[pos]))
5552 assert(*constantTile != 0 &&
"static tile size can't be zero");
5553 if (inputShape[pos] % (*constantTile) != 0)
5559LogicalResult PackOp::verify() {
5566 auto paddingValue = getPaddingValue();
5570 << getSourceType().getElementType()
5571 <<
" but got: " << paddingValue.getType();
5574 if (!paddingValue &&
5575 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5576 getDestType().
getShape(), getOuterDimsPerm(),
5579 "invalid tile factor or output size provided. Only full tiles are "
5580 "supported when padding_value is not set");
5587static SmallVector<int64_t>
5590 for (
auto o : ofrs) {
5592 if (llvm::dyn_cast_if_present<Value>(o))
5593 result.push_back(ShapedType::kDynamic);
5605 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5606 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5608 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5609 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5612 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5613 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5617 if (!outerDimsPerm.empty())
5621 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5625SmallVector<OpFoldResult> PackOp::getResultShape(
5626 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5627 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5628 ArrayRef<int64_t> outerDimsPerm) {
5629 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5633 AffineExpr ceilDivExpr = s0.
ceilDiv(s1);
5634 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5636 builder, loc, ceilDivExpr,
5637 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5639 if (!outerDimsPerm.empty())
5641 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5643 SmallVector<int64_t> resultTypeShape =
5646 innerDimsPos, outerDimsPerm);
5652 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5653 if (ShapedType::isStatic(resultTypeShape[i]))
5662RankedTensorType PackOp::inferPackedTensorType(
5663 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5664 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5665 SmallVector<int64_t> resultShape = inferPackedShape(
5666 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5667 return RankedTensorType::get(resultShape, sourceType.getElementType());
5670MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5671 ArrayRef<int64_t> innerTileSizes,
5672 ArrayRef<int64_t> innerDimsPos,
5673 ArrayRef<int64_t> outerDimsPerm) {
5674 SmallVector<int64_t> resultShape = inferPackedShape(
5675 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5676 return MemRefType::get(resultShape, sourceType.getElementType());
5679Value PackOp::createDestinationTensor(OpBuilder &
b, Location loc, Value source,
5680 ArrayRef<OpFoldResult> innerTileSizes,
5681 ArrayRef<int64_t> innerDimsPos,
5682 ArrayRef<int64_t> outerDimsPerm) {
5683 AffineExpr dim0, dim1;
5685 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5690 SmallVector<OpFoldResult> mixedSizes;
5691 for (
auto [index, value] : llvm::enumerate(
5692 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5693 if (ShapedType::isDynamic(value))
5694 mixedSizes.push_back(
5695 tensor::DimOp::create(
b, loc, source, index).getResult());
5697 mixedSizes.push_back(
b.getIndexAttr(value));
5699 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5700 int64_t dimPos = std::get<0>(it);
5701 OpFoldResult tileSize = std::get<1>(it);
5702 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5704 if (!outerDimsPerm.empty())
5707 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5708 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5709 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5712PackOp PackOp::createTransposedClone(OpBuilder &
b, Location loc,
5713 ArrayRef<int64_t> innerPermutation,
5714 ArrayRef<int64_t> outerPermutation) {
5716 *
this, innerPermutation, outerPermutation);
5717 Value transposedDest =
5718 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5719 metadata.innerDimsPos, metadata.outerDimsPerm);
5720 return PackOp::create(
b, loc, getSource(), transposedDest,
5721 metadata.innerDimsPos, metadata.innerTiles,
5722 getPaddingValue(), metadata.outerDimsPerm);
5725template <
typename OpTy>
5730 if (op.hasPureTensorSemantics())
5733 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5734 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5737 if (&opOperand == &op.getSourceMutable()) {
5741 }
else if (&opOperand == &op.getDestMutable()) {
5752void PackOp::getEffects(
5758void UnPackOp::getEffects(
5765template <
typename OpTy>
5767 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5768 "applies to only pack or unpack operations");
5769 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5771 : op.getSourceType();
5773 for (
auto [dimDest,
tile] : llvm::zip(
5774 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5776 if (!constTileSize || ShapedType::isDynamic(dimDest))
5783 if (!hasPureTensorSemantics())
5785 if (getPaddingValue())
5800 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5802 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5814 auto packTiles = packOp.getMixedTiles();
5815 auto unPackTiles = unPackOp.getMixedTiles();
5816 if (packTiles.size() != unPackTiles.size())
5818 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5827 auto srcType = op.getSourceType();
5828 if (llvm::any_of(op.getInnerDimsPos(),
5829 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5831 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5833 return !PackOp::requirePaddingValue(
5834 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5835 op.getOuterDimsPerm(), op.getMixedTiles());
5842 bool changeNeeded =
false;
5843 srcShape.assign(packOp.getSourceType().getShape().begin(),
5844 packOp.getSourceType().getShape().end());
5845 destShape.assign(packOp.getDestType().getShape().begin(),
5846 packOp.getDestType().getShape().end());
5847 llvm::SmallSetVector<int64_t, 4> innerDims;
5848 innerDims.insert_range(packOp.getInnerDimsPos());
5850 if (!packOp.getOuterDimsPerm().empty())
5852 int srcRank = packOp.getSourceRank();
5853 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5854 if (innerDims.contains(i))
5858 if (!inverseOuterDimsPerm.empty())
5859 destPos = inverseOuterDimsPerm[srcPos];
5860 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5861 ShapedType::isDynamic(destShape[destPos])) {
5864 int64_t size = srcShape[srcPos];
5865 if (ShapedType::isDynamic(size))
5866 size = destShape[destPos];
5867 srcShape[srcPos] = size;
5868 destShape[destPos] = size;
5869 changeNeeded =
true;
5871 return changeNeeded;
5874LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5876 if (!packOp.hasPureTensorSemantics())
5880 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5881 if (unPackOp.getSourceType() == packOp.getDestType() &&
5882 !packOp.getPaddingValue() &&
5885 rewriter.
replaceOp(packOp, unPackOp.getSource());
5893 packOp.getPaddingValueMutable().clear();
5899 SmallVector<int64_t> srcShape, destShape;
5901 Location loc = packOp.getLoc();
5902 Value source = packOp.getSource();
5903 if (srcShape != packOp.getSourceType().getShape()) {
5904 auto newSrcType = packOp.getSourceType().clone(srcShape);
5906 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5908 Value dest = packOp.getDest();
5909 ShapedType originalResultType = packOp.getDestType();
5910 bool needUpdateDestType = (destShape != originalResultType.getShape());
5911 if (needUpdateDestType) {
5912 auto newDestType = packOp.getDestType().clone(destShape);
5914 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5917 packOp.getSourceMutable().assign(source);
5918 packOp.getDestMutable().assign(dest);
5919 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5922 if (needUpdateDestType) {
5924 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5925 packOp.getResult());
5934template <
typename PackOrUnpackOp>
5936 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5937 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5938 "Function meant for pack/unpack");
5943 int64_t numPackedDims = innerDimsPos.size();
5944 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5945 if (orderedDims != innerDimsPos) {
5951 int64_t packedRank = packedTensorType.getRank();
5961 return llvm::all_of(
5962 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5963 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
5966bool PackOp::isLikePad() {
5967 auto packedTensorType =
5968 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5972::mlir::LogicalResult
5973PackOp::fold(FoldAdaptor adaptor,
5975 if (!hasPureTensorSemantics())
5977 std::optional<Attribute> paddingValue;
5978 if (
auto pad = adaptor.getPaddingValue())
5980 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5981 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5982 cast<TensorType>(getDestType()), paddingValue)) {
5983 results.push_back(reshapedSource);
6009 if (!op.hasPureTensorSemantics())
6030 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6031 op.getInnerDimsPos(), newMixedTileSizes,
6032 op.getPaddingValue(), op.getOuterDimsPerm());
6033 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6036 Value oldResult = op.getResult();
6037 Value newResult = newOp.getResult();
6040 ? tensor::CastOp::create(rewriter, op->getLoc(),
6041 oldResult.
getType(), newResult)
6054void UnPackOp::getAsmResultNames(
6056 if (!getResults().empty())
6057 setNameFn(getResult(),
"unpack");
6066 Type sourceType, destType, resultType;
6078 if (parser.parseInteger(value))
6080 outerDimsPermVec.push_back(value);
6090 SmallVector<int64_t> innerDimsPosVec;
6093 if (parser.parseInteger(value))
6095 innerDimsPosVec.push_back(value);
6107 for (
auto val : staticTilesAttr.
asArrayRef())
6108 staticTiles.push_back(val);
6125 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6128 "pack/unpack requires '->' and destination type");
6132 resultType = destType;
6138 if (!dynamicTiles.empty() &&
6143 result.addAttribute(
"static_inner_tiles",
6145 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6147 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6149 SmallVector<int32_t> segmentSizes = {
6150 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6151 result.addAttribute(
"operandSegmentSizes",
6155 result.addTypes(resultType);
6160void UnPackOp::print(OpAsmPrinter &p) {
6161 p <<
" " << getSource();
6163 if (!getOuterDimsPerm().empty()) {
6164 p <<
" outer_dims_perm = [";
6165 llvm::interleaveComma(getOuterDimsPerm(), p);
6169 p <<
" inner_dims_pos = [";
6170 llvm::interleaveComma(getInnerDimsPos(), p);
6173 p <<
" inner_tiles = ";
6176 p <<
" into " << getDest();
6179 {
"static_inner_tiles",
"inner_dims_pos",
6180 "outer_dims_perm",
"operandSegmentSizes"});
6182 p <<
" : " << getSource().getType();
6183 p <<
" -> " << getDest().getType();
6187UnPackOp::reifyResultShapes(OpBuilder &builder,
6196SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6200SmallVector<int64_t> UnPackOp::getStaticTiles() {
6204ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6205 ShapedType destType = getDestType();
6206 int64_t destRank = destType.getRank();
6207 return getSourceType().getShape().take_front(destRank);
6210SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6211 auto innerDimsPos = getInnerDimsPos();
6212 SmallVector<int64_t> outerDims(getAllOuterDims());
6213 SmallVector<int64_t> res;
6216 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6218 if (!outerDimPermInv.empty())
6222 for (
auto index : innerDimsPos)
6223 res.push_back(outerDims[index]);
6228LogicalResult UnPackOp::verify() {
6233 if (!hasPureTensorSemantics())
6242void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6243 Value dest, ArrayRef<int64_t> innerDimsPos,
6244 ArrayRef<OpFoldResult> innerTiles,
6245 ArrayRef<int64_t> outerDimsPerm) {
6246 assert(innerDimsPos.size() == innerTiles.size() &&
6247 "number of tile sizes specified must match the specified number of "
6248 "original dimensions to be tiled");
6249 SmallVector<int64_t> staticTileSizes;
6250 SmallVector<Value> dynamicTileSizes;
6252 build(builder, state, dest.
getType(), source, dest,
6253 outerDimsPerm.empty() ?
nullptr
6259Value UnPackOp::createDestinationTensor(OpBuilder &
b, Location loc,
6261 ArrayRef<OpFoldResult> innerTileSizes,
6262 ArrayRef<int64_t> innerDimsPos,
6263 ArrayRef<int64_t> outerDimsPerm) {
6264 AffineExpr sym0, sym1;
6266 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6270 SmallVector<OpFoldResult> mixedSizes;
6271 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
6273 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6274 if (srcType.isDynamicDim(i))
6275 mixedSizes.push_back(
6276 tensor::DimOp::create(
b, loc, source, i).getResult());
6278 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6280 if (!outerDimsPerm.empty()) {
6285 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6286 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6288 auto elemType = srcType.getElementType();
6289 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
6292UnPackOp UnPackOp::createTransposedClone(OpBuilder &
b, Location loc,
6293 Value transposedSource,
6294 ArrayRef<int64_t> innerPermutation,
6295 ArrayRef<int64_t> outerPermutation) {
6297 *
this, innerPermutation, outerPermutation);
6298 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6299 metadata.innerDimsPos, metadata.innerTiles,
6300 metadata.outerDimsPerm);
6307 bool changeNeeded =
false;
6308 srcShape.assign(op.getSourceType().getShape().begin(),
6309 op.getSourceType().getShape().end());
6310 destShape.assign(op.getDestType().getShape().begin(),
6311 op.getDestType().getShape().end());
6312 llvm::SmallSetVector<int64_t, 4> innerDims;
6313 innerDims.insert_range(op.getInnerDimsPos());
6315 if (!op.getOuterDimsPerm().empty())
6317 int destRank = op.getDestRank();
6318 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6319 if (innerDims.contains(i))
6323 if (!inverseOuterDimsPerm.empty())
6324 srcPos = inverseOuterDimsPerm[destPos];
6325 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6326 ShapedType::isDynamic(destShape[destPos])) {
6329 int64_t size = srcShape[srcPos];
6330 if (ShapedType::isDynamic(size))
6331 size = destShape[destPos];
6332 srcShape[srcPos] = size;
6333 destShape[destPos] = size;
6334 changeNeeded =
true;
6336 return changeNeeded;
6339LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6342 if (!unPackOp.hasPureTensorSemantics())
6346 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6347 if (packOp.getSourceType() != unPackOp.getDestType())
6349 if (packOp.getPaddingValue() ||
6353 rewriter.
replaceOp(unPackOp, packOp.getSource());
6357 if (
auto dstStyleOp =
6358 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6359 auto destValue = cast<OpResult>(unPackOp.getDest());
6360 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6362 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6366 if (unPackOp->hasOneUse()) {
6367 auto extractSliceUser =
6368 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6369 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6370 OpBuilder::InsertionGuard g(rewriter);
6372 auto newDest = tensor::ExtractSliceOp::create(
6373 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6374 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6375 extractSliceUser.getMixedStrides());
6377 unPackOp.setDpsInitOperand(0, newDest);
6378 unPackOp.getResult().setType(newDest.
getType());
6380 rewriter.
replaceOp(extractSliceUser, unPackOp);
6386 SmallVector<int64_t> srcShape, destShape;
6388 Location loc = unPackOp.getLoc();
6389 Value source = unPackOp.getSource();
6390 if (srcShape != unPackOp.getSourceType().getShape()) {
6391 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6392 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6393 unPackOp.getSource());
6395 Value dest = unPackOp.getDest();
6396 if (destShape != unPackOp.getDestType().getShape()) {
6397 auto newDestType = unPackOp.getDestType().clone(destShape);
6398 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6399 unPackOp.getDest());
6401 UnPackOp newOp = UnPackOp::create(
6402 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6403 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6405 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
6412bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6414 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6419 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6420 SmallVector<int64_t> outerShapeWithoutTranspose =
6422 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(),
false);
6423 for (
auto [pos, tileSize] :
6424 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6425 areOuterDimsTiled[pos] =
true;
6426 if (unpackedTypeAfterFold.isDynamicDim(pos))
6428 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6430 if (ShapedType::isDynamic(tileSize))
6432 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6433 unpackedTypeAfterFold.getDimSize(pos);
6434 if (paddingSize >= tileSize)
6438 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6439 if (areOuterDimsTiled[pos])
6441 int64_t dim = outerShapeWithoutTranspose[pos];
6442 if (ShapedType::isDynamic(dim))
6444 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6450bool UnPackOp::isLikeUnPad() {
6451 ShapedType packedTensorType = getSourceType();
6455::mlir::LogicalResult
6456UnPackOp::fold(FoldAdaptor adaptor,
6457 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6459 if (!hasPureTensorSemantics())
6462 if (OpFoldResult reshapedSource = reshapeConstantSource(
6463 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6464 cast<TensorType>(getResult().
getType()))) {
6465 results.push_back(reshapedSource);
6491 if (!op.hasPureTensorSemantics())
6500 Value sourceTensor = newOperands[0];
6504 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6510 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6511 newOperands[1], op.getInnerDimsPos(),
6512 newMixedTileSizes, op.getOuterDimsPerm());
6513 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6516 Value oldResult = op.getResult();
6517 Value newResult = newOp.getResult();
6520 ? tensor::CastOp::create(rewriter, op->getLoc(),
6521 oldResult.
getType(), newResult)
6535 utils::IteratorType::reduction, utils::IteratorType::parallel,
6536 utils::IteratorType::parallel, utils::IteratorType::reduction};
6539SmallVector<AffineMap>
6540BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6541 AffineExpr d0, d1, d2, d3;
6542 SmallVector<AffineMap> indexingMaps;
6544 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6545 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6547 return indexingMaps;
6550bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6551 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6554 if (maps.size() != 3)
6559 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6560 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6561 (*positions)[2] == SmallVector<int64_t>{1, 2};
6563unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6565std::string BatchReduceMatmulOp::getLibraryCallName() {
6571bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6572 SmallVector<AffineMap, 3> defaultMaps =
6574 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6575 return defaultMaps != explicitMaps;
6585bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6588 "Expected less than 3 result dim expr.");
6589 bool isValid =
false;
6590 enum Indices { batchPos, mPos, nPos, kPos };
6592 AffineExpr expr = bcastMap.
getResult(0);
6595 AffineExpr expr0 = bcastMap.
getResult(0);
6596 AffineExpr expr1 = bcastMap.
getResult(1);
6601 : ((expr0.isFunctionOfDim(batchPos) &&
6602 expr1.isFunctionOfDim(kPos)) ||
6603 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6608void BatchReduceMatmulOp::regionBuilder(
6609 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
6612 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6617 "BatchReduceMatmulOp regionBuilder expects 3 args");
6618 RegionBuilderHelper helper(
b, block);
6619 SmallVector<Value> yields;
6623 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6625 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6627 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
6628 if (!castValA || !castValB || !mulVal)
6631 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6634 yields.push_back(addVal);
6635 helper.yieldOutputs(yields);
6638ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6639 OperationState &
result) {
6640 SmallVector<Attribute, 3> indexingMapsAttr;
6651 if (!isa<AffineMapAttr>(mapAttr)) {
6653 "expected affine map attribute");
6655 indexingMapsAttr.push_back(mapAttr);
6665 if (indexingMapsAttr.empty()) {
6666 indexingMapsAttr = llvm::map_to_vector(
6667 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6668 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6670 result.addAttribute(
"indexing_maps",
6672 return ::parseNamedStructuredOp(parser,
result,
6673 BatchReduceMatmulOp::getNumRegionArgs(),
6674 BatchReduceMatmulOp::getRegionBuilder());
6677void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6678 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6679 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6680 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
6682 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6683 p <<
" indexing_maps = [";
6684 llvm::interleaveComma(getIndexingMaps(), p,
6689 SmallVector<StringRef, 3> elidedAttrs = {
6690 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6696LogicalResult BatchReduceMatmulOp::verify() {
6699 if (!hasUserDefinedMaps())
6702 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6708LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6709 SmallVectorImpl<OpFoldResult> &) {
6712void BatchReduceMatmulOp::getEffects(
6713 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6715 if (hasPureTensorSemantics())
6731void LinalgDialect::getCanonicalizationPatterns(
6740 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
Returns failure if copies could not be generated due to yet-unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be placed. The output argument nBegin is set to its replacement (set to `begin` if no invalidation happens), since outgoing copies could have been inserted at `end`.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)).
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. sizeof...(exprs)).
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override