41#include "llvm/ADT/DenseMap.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SetOperations.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/Support/InterleavedRange.h"
50#include "llvm/Support/LogicalResult.h"
51#include "llvm/Support/MathExtras.h"
52#include "llvm/Support/raw_ostream.h"
62 auto type = cast<ShapedType>(v.
getType());
63 if (!type.isDynamicDim(dim))
68 .Case([&](RankedTensorType t) ->
Value {
69 return tensor::DimOp::create(builder, loc, v, dim);
71 .Case([&](MemRefType t) ->
Value {
72 return memref::DimOp::create(builder, loc, v, dim);
83 .Case([&](RankedTensorType t) ->
Operation * {
84 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
87 .Case([&](MemRefType type) ->
Operation * {
88 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
94static std::optional<TypedAttr>
98 if (!splatAttr || !splatAttr.
isSplat())
110 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
111 return b.createOrFold<memref::DimOp>(loc, source, dim);
112 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
113 return b.createOrFold<tensor::DimOp>(loc, source, dim);
114 llvm_unreachable(
"Expected MemRefType or TensorType");
119 auto shapedType = llvm::cast<ShapedType>(source.
getType());
120 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
122 return b.getIndexAttr(shapedType.getDimSize(dim));
145 for (
auto containers : {inputTypes, outputTypes}) {
146 for (
auto t : containers) {
158 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
174 std::optional<TypeRange> resultTensorTypes,
181 if (!resultTensorTypes)
182 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
183 llvm::IsaPred<RankedTensorType>);
191 "operandSegmentSizes",
192 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
193 static_cast<int32_t>(outputs.size())}));
203 std::optional<TypeRange> resultTensorTypes,
210 return attr.
getName() ==
"indexing_maps";
213 indexingMapsAttrVal = llvm::map_to_vector(
216 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
219 attributes, regionBuilder);
223 std::optional<TypeRange> resultTensorTypes,
230 return attr.
getName() ==
"indexing_maps";
233 indexingMapsAttrVal = llvm::map_to_vector(
236 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
239 attributes, regionBuilder);
243 std::optional<TypeRange> resultTensorTypes,
250 indexingMapsAttrVal =
252 return AffineMapAttr::get(map);
254 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
256 attributes, regionBuilder);
265 bool addOperandSegmentSizes =
true) {
266 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
295 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
297 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
301 if (addOperandSegmentSizes) {
308 if (
result.propertiesAttr) {
310 attrs.
append(
"operandSegmentSizes",
312 {static_cast<int32_t>(inputsOperands.size()),
313 static_cast<int32_t>(outputsOperands.size())}));
316 result.addAttribute(
"operandSegmentSizes",
318 {static_cast<int32_t>(inputsOperands.size()),
319 static_cast<int32_t>(outputsOperands.size())}));
322 if (!
result.propertiesAttr) {
323 std::optional<RegisteredOperationName> info =
324 result.name.getRegisteredInfo();
326 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
327 return parser.emitError(attrsLoc)
328 <<
"'" << result.name.getStringRef() <<
"' op ";
339 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
340 if (!outputs.empty())
341 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
352 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
355 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
356 "region expects {0} args, got {1}",
357 numRegionArgs, inputTypes.size() + outputTypes.size()));
363 opBuilder, region, inputTypes, outputTypes, attrs,
382 unsigned numRegionArgs,
399 result.addTypes(outputTensorsTypes);
401 std::unique_ptr<Region> region = std::make_unique<Region>();
403 outputTypes,
result.attributes.getAttrs(),
406 result.addRegion(std::move(region));
413 if (resultTypes.empty())
458class RegionBuilderHelper {
460 RegionBuilderHelper(OpBuilder &builder,
Block &block)
461 : builder(builder), block(block) {}
464 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
466 if (!isFloatingPoint(arg)) {
468 emitError() <<
"unsupported non numeric type";
471 llvm_unreachable(
"unsupported non numeric type");
473 OpBuilder::InsertionGuard g(builder);
474 builder.setInsertionPointToEnd(&block);
477 return math::ExpOp::create(builder, arg.
getLoc(), arg);
479 return math::LogOp::create(builder, arg.
getLoc(), arg);
481 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
483 return math::CeilOp::create(builder, arg.
getLoc(), arg);
485 return math::FloorOp::create(builder, arg.
getLoc(), arg);
487 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
488 case UnaryFn::reciprocal: {
489 Attribute oneAttr = builder.getOneAttr(arg.
getType());
490 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
491 ::cast<TypedAttr>(oneAttr));
492 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
495 return math::RoundOp::create(builder, arg.
getLoc(), arg);
497 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
499 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
500 case UnaryFn::square:
501 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
503 return math::TanhOp::create(builder, arg.
getLoc(), arg);
505 return math::ErfOp::create(builder, arg.
getLoc(), arg);
508 emitError() <<
"unsupported unary function";
511 llvm_unreachable(
"unsupported unary function");
518 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
520 bool allComplex = isComplex(arg0) && isComplex(arg1);
521 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
522 bool allInteger = isInteger(arg0) && isInteger(arg1);
525 if (!allComplex && !allFloatingPoint && !allInteger) {
528 <<
"Cannot build binary Linalg operation: expects allComplex, "
529 "allFloatingPoint, or allInteger, got "
533 llvm_unreachable(
"unsupported non numeric type");
535 OpBuilder::InsertionGuard g(builder);
536 builder.setInsertionPointToEnd(&block);
540 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
541 if (allFloatingPoint)
542 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
544 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
545 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
548 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
549 if (allFloatingPoint)
550 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
553 emitError() <<
"unsupported operation: sub with bools";
556 llvm_unreachable(
"unsupported operation: sub with bools");
558 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
561 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
562 if (allFloatingPoint)
563 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
565 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
566 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
569 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
570 if (allFloatingPoint)
571 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
574 emitError() <<
"unsupported operation: div with bools";
577 llvm_unreachable(
"unsupported operation: div with bools");
579 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
580 case BinaryFn::div_unsigned:
581 if (!allInteger || allBool) {
583 emitError() <<
"unsupported operation: unsigned div not on uint";
586 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
588 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
589 case BinaryFn::max_signed:
591 if (allFloatingPoint)
592 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
593 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
594 case BinaryFn::min_signed:
596 if (allFloatingPoint)
597 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
598 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
599 case BinaryFn::max_unsigned:
601 if (!allInteger || allBool) {
603 emitError() <<
"unsupported operation: unsigned max not on uint";
606 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
608 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
609 case BinaryFn::min_unsigned:
611 if (!allInteger || allBool) {
613 emitError() <<
"unsupported operation: unsigned min not on uint";
616 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
618 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
620 assert(allFloatingPoint);
621 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
624 emitError() <<
"unsupported binary function";
627 llvm_unreachable(
"unsupported binary function");
631 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
633 OpBuilder::InsertionGuard g(builder);
634 builder.setInsertionPointToEnd(&block);
636 case TernaryFn::select:
637 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
640 emitError() <<
"unsupported ternary function";
643 llvm_unreachable(
"unsupported ternary function");
647 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
650 case TypeFn::cast_signed:
651 return cast(toType, operand,
false);
652 case TypeFn::cast_unsigned:
653 return cast(toType, operand,
true);
656 emitError() <<
"unsupported type conversion function";
659 llvm_unreachable(
"unsupported type conversion function");
663 OpBuilder::InsertionGuard g(builder);
664 builder.setInsertionPointToEnd(&block);
665 Location loc = builder.getUnknownLoc();
666 YieldOp::create(builder, loc, values);
669 Value constant(
const std::string &value) {
670 OpBuilder::InsertionGuard g(builder);
671 builder.setInsertionPointToEnd(&block);
672 Location loc = builder.getUnknownLoc();
673 Attribute valueAttr =
parseAttribute(value, builder.getContext());
674 return arith::ConstantOp::create(builder, loc,
675 ::cast<TypedAttr>(valueAttr));
678 Value index(int64_t dim) {
679 OpBuilder::InsertionGuard g(builder);
680 builder.setInsertionPointToEnd(&block);
681 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
684 Type getIntegerType(
unsigned width) {
685 return IntegerType::get(builder.getContext(), width);
// Returns the Float32Type obtained from the context of the helper's builder.
688 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
// Returns the Float64Type obtained from the context of the helper's builder.
689 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
696 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
697 OpBuilder::InsertionGuard g(builder);
698 builder.setInsertionPointToEnd(&block);
699 auto loc = operand.
getLoc();
700 if (isa<UnknownLoc>(loc)) {
710 bool isComplex(Value value) {
711 return llvm::isa<ComplexType>(value.
getType());
713 bool isFloatingPoint(Value value) {
714 return llvm::isa<FloatType>(value.
getType());
716 bool isInteger(Value value) {
717 return llvm::isa<IntegerType>(value.
getType());
733 using OpRewritePattern<CopyOp>::OpRewritePattern;
734 LogicalResult matchAndRewrite(CopyOp copyOp,
735 PatternRewriter &rewriter)
const override {
736 if (copyOp.getInputs() != copyOp.getOutputs())
738 if (copyOp.hasPureBufferSemantics())
741 rewriter.
replaceOp(copyOp, copyOp.getInputs());
751 results.
add<EraseSelfCopy>(context);
764template <
typename TensorReshapeOp>
765struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
766 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
767 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
768 PatternRewriter &rewriter)
const override {
769 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
773 Location loc = oldFill.getLoc();
774 TensorReshapeOp newInit;
775 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
777 newInit = TensorReshapeOp::create(
778 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
779 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
780 reshapeOp.getStaticOutputShape());
782 newInit = TensorReshapeOp::create(
783 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
784 reshapeOp.getReassociation());
794struct FoldFillWithPad final :
public OpRewritePattern<tensor::PadOp> {
797 LogicalResult matchAndRewrite(tensor::PadOp padOp,
798 PatternRewriter &rewriter)
const override {
799 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
805 Value padValue = padOp.getConstantPaddingValue();
806 if (!padValue || fillOp.value() != padValue)
812 padOp,
"failed to reify tensor.pad op result shape");
815 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
816 padOp.getResultType().getElementType());
818 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
821 if (
replacement.getType() != padOp.getResultType()) {
822 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
833struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
836 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
837 PatternRewriter &rewriter)
const override {
838 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
842 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
847 Value firstDest = insertOp.getDest();
848 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
849 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
854 bool disjoint =
false;
855 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
858 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
859 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
860 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
864 int64_t prevStart = prevOp.getStaticOffset(i);
865 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
866 prevOp.getStaticStride(i);
867 int64_t nextStart = insertOp.getStaticOffset(i);
868 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
869 insertOp.getStaticStride(i);
870 if (prevEnd < nextStart || nextEnd < prevStart) {
878 firstDest = prevOp.getDest();
889 Value padValue = srcPadOp.getConstantPaddingValue();
890 if (!padValue || dstFillOp.value() != padValue)
893 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
894 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
896 Location loc = insertOp.getLoc();
899 AffineExpr sym0, sym1;
905 SmallVector<OpFoldResult, 4> newOffsets;
906 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
908 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
911 RankedTensorType srcPadType = srcPadOp.getSourceType();
912 SmallVector<OpFoldResult, 4> newSizes;
913 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
914 if (srcPadType.isDynamicDim(i)) {
916 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
919 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
924 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
925 newSizes, insertOp.getMixedStrides());
931struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
933 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
935 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
936 PatternRewriter &rewriter)
const override {
939 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
944 Value extractedScalar = fillOp.getInputs()[0];
947 rewriter.
replaceOp(extractOp, extractedScalar);
955static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
956 linalg::PackOp packOp) {
957 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
961 if (
auto paddingValue = packOp.getPaddingValue())
965 Value packOpDest = packOp.getDest();
969 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
974struct FoldFillWithPack :
public OpRewritePattern<linalg::PackOp> {
976 FoldFillWithPack(MLIRContext *context)
977 : OpRewritePattern<linalg::PackOp>(context) {}
979 LogicalResult matchAndRewrite(linalg::PackOp packOp,
980 PatternRewriter &rewriter)
const override {
981 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
984 rewriter.
replaceOp(packOp, fillOp.value().result());
990struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
991 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
993 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
994 PatternRewriter &rewriter)
const override {
995 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
998 copyOp.getOutputs());
1001 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
1003 fillOp.getOutputs());
1011struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1012 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1014 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1015 PatternRewriter &rewriter)
const override {
1016 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1018 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1019 transposeOp.getDpsInitOperand(0)->get());
1028struct FoldConcatsOfFill :
public OpRewritePattern<tensor::ConcatOp> {
1031 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1032 PatternRewriter &rewriter)
const override {
1033 auto concatOperands = concatOp.getInputs();
1034 if (concatOperands.empty()) {
1038 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1043 OpFoldResult firstFillVal =
1046 SmallVector<Value> allOuts;
1047 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1049 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1050 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1055 OpFoldResult fillVal =
1057 if (fillVal != firstFillVal)
1060 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1063 if (!llvm::all_of(concatOperands.drop_front(),
1064 isDefinedByCompatibleFillOp)) {
1066 concatOp,
"not all operands are defined by a compatible fill op");
1069 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1070 concatOp.getDim(), allOuts);
1072 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
1079void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1080 MLIRContext *context) {
1081 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1082 FoldFillWithPack, FoldFillWithPad,
1083 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1084 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1085 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1098 for (
ValueRange container : {inputs, outputs}) {
1099 for (
Value v : container) {
1100 Type t = v.getType();
1101 blockArgTypes.push_back(
1103 blockArgLocs.push_back(v.getLoc());
1109 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1113void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1115 for (Value v : getRegionInputArgs())
1117 for (Value v : getRegionOutputArgs())
1118 setNameFn(v,
"out");
1121void GenericOp::build(
1122 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1124 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1126 ArrayRef<NamedAttribute> attributes) {
1127 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1128 iteratorTypes, doc, libraryCall);
1129 result.addAttributes(attributes);
1132 inputs, outputs, bodyBuild);
1135void GenericOp::build(
1136 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1138 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1139 StringRef libraryCall,
1141 ArrayRef<NamedAttribute> attributes) {
1142 build(builder,
result, resultTensorTypes, inputs, outputs,
1146 [&](utils::IteratorType iter) -> mlir::Attribute {
1147 return IteratorTypeAttr::get(builder.getContext(), iter);
1150 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1151 bodyBuild, attributes);
1154void GenericOp::build(
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1158 StringRef libraryCall,
1160 ArrayRef<NamedAttribute> attributes) {
1162 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1165void GenericOp::build(
1167 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1173 "", bodyBuild, attributes);
1176void GenericOp::build(
1177 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1179 ArrayRef<utils::IteratorType> iteratorTypes,
1181 ArrayRef<NamedAttribute> attributes) {
1182 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1185 "", bodyBuild, attributes);
1188void GenericOp::print(OpAsmPrinter &p) {
1192 auto genericAttrNames = linalgTraitAttrNames();
1194 llvm::StringSet<> genericAttrNamesSet;
1195 genericAttrNamesSet.insert_range(genericAttrNames);
1196 SmallVector<NamedAttribute, 8> genericAttrs;
1197 for (
auto attr : (*this)->getAttrs()) {
1198 if (attr.getName() == getIteratorTypesAttrName()) {
1199 auto iteratorTypes =
1200 llvm::cast<ArrayAttr>(attr.getValue())
1201 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1206 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1207 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1208 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1211 genericAttrs.emplace_back(
1212 getIteratorTypesAttrName(),
1213 ArrayAttr::get(
getContext(), iteratorTypeNames));
1214 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1215 genericAttrs.push_back(attr);
1218 if (!genericAttrs.empty()) {
1219 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1220 p << genericDictAttr;
1226 genericAttrNames.push_back(
"operandSegmentSizes");
1227 genericAttrNamesSet.insert(genericAttrNames.back());
1229 bool hasExtraAttrs =
false;
1230 for (NamedAttribute n : (*this)->getAttrs()) {
1231 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1234 if (hasExtraAttrs) {
1241 if (!getRegion().empty()) {
1250ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1251 DictionaryAttr dictAttr;
1259 result.attributes.assign(dictAttr.getValue().begin(),
1260 dictAttr.getValue().end());
1266 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1267 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1268 if (!iteratorTypes) {
1269 return parser.
emitError(attributeLocation)
1270 <<
"expected " << getIteratorTypesAttrName(
result.name)
1271 <<
" array attribute";
1274 SmallVector<Attribute> iteratorTypeAttrs;
1276 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1277 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1278 if (!maybeIteratorType.has_value())
1280 <<
"unexpected iterator_type (" << s <<
")";
1282 iteratorTypeAttrs.push_back(
1283 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1285 result.attributes.set(getIteratorTypesAttrName(
result.name),
1289 SmallVector<Type, 1> inputTypes, outputTypes;
1299 std::unique_ptr<Region> region = std::make_unique<Region>();
1302 result.addRegion(std::move(region));
1308 SmallVector<Type, 1> outputTensorsTypes;
1311 result.addTypes(outputTensorsTypes);
1319 LinalgOp linalgOp) {
1320 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1321 if (!llvm::isa<MemRefType>(operand.
getType()))
1323 effects.emplace_back(
1328 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1329 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1331 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1342void GenericOp::getEffects(
1343 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1352 if (!linalgOp.hasPureTensorSemantics())
1370template <
typename OpTy>
1371struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1372 using OpRewritePattern<OpTy>::OpRewritePattern;
1374 LogicalResult matchAndRewrite(OpTy linalgOp,
1375 PatternRewriter &rewriter)
const override {
1377 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1382 Block &body = linalgOp->getRegion(0).front();
1383 if (!llvm::hasSingleElement(body))
1385 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1390 if (linalgOp.hasPureBufferSemantics()) {
1391 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1392 linalgOp.getDpsInputOperand(0)->get() !=
1393 linalgOp.getDpsInitOperand(0)->get()) {
1395 linalgOp,
"expected single input and output to be the same value");
1398 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1399 if (!yieldArg || yieldArg.getOwner() != &body) {
1401 "cannot fold fill-like op");
1408 if (!linalgOp.hasPureTensorSemantics()) {
1410 linalgOp,
"mixed semantics is not supported yet");
1415 SmallVector<Value> returnedArgs;
1416 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1417 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1418 if (!yieldArg || yieldArg.getOwner() != &body)
1420 unsigned argumentNumber = yieldArg.getArgNumber();
1421 Value returnedArg = linalgOp->getOperand(argumentNumber);
1422 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1425 Type returnType = returnedArg.
getType();
1426 if (returnType != resultType) {
1431 returnedArg = sparse_tensor::ConvertOp::create(
1432 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1434 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1437 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1438 resultType, returnedArg);
1441 returnedArgs.push_back(returnedArg);
1444 if (returnedArgs.size() != linalgOp->getNumResults())
1446 rewriter.
replaceOp(linalgOp, returnedArgs);
1453void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1454 MLIRContext *context) {
1455 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1458LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1477 for (
Type outputType : outputTypes) {
1478 if (llvm::isa<RankedTensorType>(outputType))
1479 result.addTypes(outputType);
1483 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1492void MapOp::getAsmBlockArgumentNames(Region ®ion,
1494 for (Value v : getRegionInputArgs())
1496 for (Value v : getRegionOutputArgs())
1497 setNameFn(v,
"init");
1500void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1501 if (!getResults().empty())
1502 setNameFn(getResults().front(),
"mapped");
1508 ArrayRef<NamedAttribute> attributes) {
1510 result.addAttributes(attributes);
1513 Type initType = init.
getType();
1514 if (llvm::isa<RankedTensorType>(initType))
1515 result.addTypes(initType);
1519 inputs, {init}, bodyBuild);
1526 bool initFirst =
false,
bool mapInit =
true) {
1530 b.setInsertionPointToStart(&block);
1531 for (
auto &operand : operands) {
1533 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1541 payloadOpOperands.push_back(block.
getArguments().back());
1542 for (
const auto &arg : block.
getArguments().drop_back())
1543 payloadOpOperands.push_back(arg);
1552 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1558ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1559 std::optional<OperationName> payloadOpName;
1560 NamedAttrList payloadOpAttrs;
1563 if (
failed(operationName))
1567 payloadOpName = operationName.value();
1575 if (payloadOpName.has_value()) {
1576 if (!
result.operands.empty())
1578 payloadOpAttrs, ArrayRef(
result.operands),
false,
1583 SmallVector<OpAsmParser::Argument> regionArgs;
1588 Region *body =
result.addRegion();
1596 bool mapInit =
true) {
1598 if (initFirst && !mapInit)
1622 for (
const auto &[operand, bbArg] :
1624 if (bbArg != operand)
1628 for (
const auto &[operand, bbArg] :
1631 if (bbArg != operand)
1638 return yieldOp.getNumOperands() == 1 &&
1639 yieldOp.getOperand(0).getDefiningOp() &&
1640 yieldOp.getOperand(0).getDefiningOp() == &payload;
1645 std::string attrToElide;
1647 for (
const auto &attr : payloadOp->
getAttrs()) {
1649 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1650 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1651 attrToElide = attr.getName().str();
1652 elidedAttrs.push_back(attrToElide);
1660void MapOp::print(OpAsmPrinter &p) {
1661 Block *mapper = getBody();
1671 if (!useShortForm) {
1677 [&](
auto arg) { p.printRegionArgument(arg); });
1685LogicalResult MapOp::verify() {
1686 auto *bodyBlock = getBody();
1687 auto blockArgs = bodyBlock->getArguments();
1691 if (getInputs().size() + 1 != blockArgs.size())
1692 return emitOpError() <<
"expects number of operands to match the arity of "
1694 << getInputs().size() + 1 <<
" and "
1695 << blockArgs.size();
1698 for (
const auto &[bbArgType, inputArg] :
1699 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1700 auto inputElemType =
1701 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1702 if (bbArgType != inputElemType) {
1703 return emitOpError() <<
"expected element type of input " << inputElemType
1704 <<
" to match bbArg type " << bbArgType;
1709 auto outputShape = getInit().getType().getShape();
1710 for (Type inputArgType :
TypeRange{getInputs()}) {
1711 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1712 if (inputElemShape != outputShape) {
1713 return emitOpError() <<
"expected shape of input (" << inputElemShape
1714 <<
") to match shape of output (" << outputShape
1722SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1723 int64_t rank = getInit().getType().getRank();
1724 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1729 int64_t rank = getInit().getType().getRank();
1730 int64_t numIndexingMaps = getOperands().size();
1735void MapOp::getEffects(
1736 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1749void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1751 for (Value v : getRegionInputArgs())
1753 for (Value v : getRegionOutputArgs())
1754 setNameFn(v,
"init");
1757void ReduceOp::getAsmResultNames(
1759 if (!getResults().empty())
1760 setNameFn(getResults().front(),
"reduced");
1763void ReduceOp::build(
1765 ValueRange inits, ArrayRef<int64_t> dimensions,
1767 ArrayRef<NamedAttribute> attributes) {
1769 result.addAttributes(attributes);
1772 for (Value init : inits) {
1773 Type initType = init.
getType();
1774 if (llvm::isa<RankedTensorType>(initType))
1775 result.addTypes(initType);
1780 inputs, inits, bodyBuild);
1783SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1785 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1786 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1787 utils::IteratorType::parallel);
1788 for (int64_t reductionDim : getDimensions())
1789 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1790 return iteratorTypes;
1795 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1796 SmallVector<AffineMap> affineMaps(
1799 AffineMap resultMap =
1802 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1803 affineMaps.push_back(resultMap);
1804 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1807void ReduceOp::getEffects(
1808 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1819 StringRef attributeName) {
1827ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1828 std::optional<OperationName> payloadOpName;
1829 NamedAttrList payloadOpAttrs;
1832 if (
failed(operationName))
1836 payloadOpName = operationName.value();
1842 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1847 if (payloadOpName.has_value()) {
1849 ArrayRef(
result.operands),
true);
1851 SmallVector<OpAsmParser::Argument> regionArgs;
1857 Region *body =
result.addRegion();
1867 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1870void ReduceOp::print(OpAsmPrinter &p) {
1871 Block *mapper = getBody();
1880 if (!useShortForm) {
1886 [&](
auto arg) { p.printRegionArgument(arg); });
// Verifier for ReduceOp. The visible checks, in order:
//   1. getInputs().size() must agree with getNumDpsInputs(); a mismatch is
//      reported as an unequal number of inputs and outputs.
//   2. There must be at least one input.
//   3. All inputs share one shape, and all outputs share one shape.
//   4. Every reduction dimension lies in [0, input rank - 1].
//   5. The input shape with the reduced dimensions removed must equal the
//      init shape, both in rank and in extents.
//   6. The body block's arguments must match the operands in number, and
//      their types must match the element types of the inputs (leading
//      arguments) and of the inits (trailing getNumDpsInits() arguments).
// NOTE(review): several lines (conditions, `return emitOpError()` heads and
// the declaration of `dimensionsToReduce`) are missing from this listing.
1894LogicalResult ReduceOp::verify() {
1895 ArrayRef<int64_t> dimensionsRef = getDimensions();
// Operand-count consistency check between inputs and inits.
1902 if (getInputs().size() !=
static_cast<size_t>(getNumDpsInputs()))
1904 <<
"expected equal number of inputs and outputs (required by "
1905 "SameVariadicOperandSize), got "
1906 << getNumDpsInputs() <<
" input(s) and " << getNumDpsInits()
1909 if (getInputs().empty())
1910 return emitOpError() <<
"expected at least one input";
// Inputs 1..N-1 are compared against input 0 (loop starts at index 1).
1912 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1915 return emitOpError() <<
"expects all inputs to have the same shapes. "
1916 "Shape at input-index "
1918 <<
" is not equal to the shape at input-index 0.";
// Same comparison for the outputs, against output 0.
1921 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1924 return emitOpError() <<
"expects all outputs to have the same shapes. "
1925 "Shape at output-index "
1927 <<
" is not equal to the shape at output-index 0.";
1930 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1931 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
// Range-check each reduction dimension and collect it for the shape check.
1934 for (int64_t dimension : dimensionsRef) {
1935 if (dimension < 0 || dimension >= inputType.getRank()) {
1937 <<
"dimensions for reduction should be in the range [0, "
1938 << inputType.getRank() - 1 <<
"].";
1940 dimensionsToReduce.insert(dimension);
1943 auto inputDims = inputType.getShape();
1944 auto initDims = initType.getShape();
// Keep only the input extents whose dimension index is not being reduced.
1947 SmallVector<int64_t> reducedInputDims;
1948 for (
const auto &en : llvm::enumerate(inputDims)) {
1949 if (!dimensionsToReduce.count(en.index()))
1950 reducedInputDims.push_back(en.value());
// The surviving dimensions must match the init in rank ...
1953 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1954 return emitOpError() <<
"number of dimensions after reduction "
1955 << reducedInputDims.size()
1956 <<
" doesn't match the init rank "
1957 << initType.getRank();
// ... and in every extent.
1960 if (reducedInputDims != initDims)
1961 return emitOpError() <<
"init dimensions [" << initDims
1962 <<
"] doesn't match input dimensions after reduction ["
1963 << reducedInputDims <<
"]";
// Remaining checks concern the body block's arguments.
1965 Block *block = getBody();
1968 <<
"mismatching number of operands and block arguments";
// Inputs are zipped with the block arguments from the front.
1971 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1972 Type inputElementType =
1973 llvm::cast<ShapedType>(input.getType()).getElementType();
1974 if (inputElementType != bbArg.getType())
1976 <<
"input element type " << inputElementType
1977 <<
" does not match corresponding block argument type "
1982 for (
auto [output, bbArg] : llvm::zip(
1983 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1984 auto outputElementType =
1985 llvm::cast<ShapedType>(output.getType()).getElementType();
1986 if (outputElementType != bbArg.getType())
1988 <<
"output element type " << outputElementType
1989 <<
" does not match corresponding block argument type "
2005 linalg::YieldOp::create(
b, loc, args[0]);
2009void TransposeOp::build(::mlir::OpBuilder &builder,
2010 ::mlir::OperationState &
result, Value input, Value init,
2012 ArrayRef<NamedAttribute> attributes) {
2013 result.addOperands(input);
2014 result.addOperands(init);
2015 result.addAttribute(getPermutationAttrName(
result.name), permutation);
2016 result.addAttributes(attributes);
2019 Type initType = init.
getType();
2020 if (llvm::isa<RankedTensorType>(initType))
2021 result.addTypes(initType);
2027void TransposeOp::build(::mlir::OpBuilder &builder,
2028 ::mlir::OperationState &
result, Value input, Value init,
2029 ArrayRef<int64_t> permutation,
2030 ArrayRef<NamedAttribute> attributes) {
2035ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2037 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2049void TransposeOp::getAsmResultNames(
2051 if (!getResults().empty())
2052 setNameFn(getResults().front(),
"transposed");
2055void TransposeOp::print(OpAsmPrinter &p) {
// Verifier for TransposeOp. From the visible checks: the permutation must
// have exactly one entry per input dimension, and for every result dimension
// i the init extent must equal the input extent at permutation[i].
// NOTE(review): lines between the statements below are missing from this
// listing, so further checks may exist that are not described here.
2061LogicalResult TransposeOp::verify() {
2062 ArrayRef<int64_t> permutationRef = getPermutation();
2067 auto inputType = getInput().getType();
2068 auto initType = getInit().getType();
2070 int64_t rank = inputType.getRank();
// Permutation length vs. input rank.
2076 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2077 return emitOpError() <<
"size of permutation " << permutationRef.size()
2078 <<
" does not match the argument rank " << rank;
2080 auto inputDims = inputType.getShape();
2081 auto initDims = initType.getShape();
// Result dimension i reads its extent from input dimension permutation[i].
2083 for (int64_t i = 0; i < rank; ++i) {
2084 int64_t inputDim = inputDims[permutationRef[i]];
2085 int64_t initDim = initDims[i];
2087 if (inputDim != initDim) {
2088 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2089 <<
" doesn't match dim(input, permutation[" << i
2090 <<
"]) = " << inputDim;
// Returns one parallel iterator type per dimension of the init operand's
// type; the transpose has no reduction iterators.
2097SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2098 int64_t rank = getInit().getType().getRank();
2099 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2102ArrayAttr TransposeOp::getIndexingMaps() {
2104 int64_t rank = getInit().getType().getRank();
2107 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2111void TransposeOp::getEffects(
2112 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2121LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2122 SmallVectorImpl<OpFoldResult> &
result) {
2124 if (!isa<TensorType>(getInput().
getType()))
2128 if (getPermutation().empty()) {
2129 result.push_back(getInput());
2134 result.push_back(getInput());
2147 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2148 if (!defTransposeOp)
2153 foldedPerms.reserve(perms.size());
2155 foldedPerms.push_back(defPerms[perm]);
2158 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2171 if (!transposeOp.hasPureTensorSemantics())
2176 if (!splatValue.has_value())
2180 cast<RankedTensorType>(transposeOp.getResult()[0].getType());
2197 Value input = transposeOp.getInput();
2198 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2209 unsigned dimensionSize = dimensions.size();
2210 for (
unsigned i = 0; i < dimensionSize; ++i)
2211 resultDimensions.push_back(invertPerm[dimensions[i]]);
2214 Value broadcastInput = broadcastOp.getInput();
2215 Location loc = transposeOp.getLoc();
2218 auto broadcastInputTy =
2219 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2220 unsigned inputRank = broadcastInputTy.getRank();
2221 for (
unsigned i = 0; i < inputRank; ++i) {
2222 if (broadcastInputTy.isDynamicDim(i)) {
2223 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2226 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2227 broadcastInputTy.getDimSize(i)));
2232 Value transposeInit = tensor::EmptyOp::create(
2233 rewriter, transposeOp.getLoc(), transposeResultShapes,
2234 broadcastInputTy.getElementType());
2237 Value transposeResult =
2238 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2239 transposeInit, resultPerms)
2242 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
// Adds the TransposeOp canonicalization patterns to `results`:
// FoldTransposeWithTranspose, FoldTransposeSplatConstant and
// SwapTransposeWithBroadcast, all constructed with `context`.
2247void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2248 MLIRContext *context) {
2249 results.
add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
2250 SwapTransposeWithBroadcast>(context);
2257void BroadcastOp::build(::mlir::OpBuilder &builder,
2258 ::mlir::OperationState &
result, Value input, Value init,
2260 ArrayRef<NamedAttribute> attributes) {
2261 result.addOperands(input);
2262 result.addOperands(init);
2263 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2264 result.addAttributes(attributes);
2267 Type initType = init.
getType();
2268 if (llvm::isa<RankedTensorType>(initType))
2269 result.addTypes(initType);
2275void BroadcastOp::build(::mlir::OpBuilder &builder,
2276 ::mlir::OperationState &
result, Value input, Value init,
2277 ArrayRef<int64_t> dimensions,
2278 ArrayRef<NamedAttribute> attributes) {
2283ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2285 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2297void BroadcastOp::getAsmResultNames(
2299 if (!getResults().empty())
2300 setNameFn(getResults().front(),
"broadcasted");
2303void BroadcastOp::print(OpAsmPrinter &p) {
// Verifier for BroadcastOp. From the visible checks:
//   - input rank + number of added dimensions must equal the init rank;
//   - every added dimension index must lie in [0, init rank - 1];
//   - each input dimension must have the same extent as the init dimension
//     it maps to, where the mapping is "init dimensions not listed in
//     getDimensions(), in increasing order".
2309LogicalResult BroadcastOp::verify() {
2310 ArrayRef<int64_t> dimensionsRef = getDimensions();
2312 auto inputType = getInput().getType();
2313 auto initType = getInit().getType();
2315 int64_t inputRank = inputType.getRank();
2316 int64_t initRank = initType.getRank();
2318 auto inputShape = inputType.getShape();
2319 auto initShape = initType.getShape();
// Rank bookkeeping: the added dimensions account for the rank difference.
2321 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2322 return emitOpError() <<
"input rank plus added dimensions does not "
2323 "match init rank. input rank: "
2325 <<
", dimensions size: " << dimensionsRef.size()
2326 <<
", init rank: " << initRank;
// Each added dimension must be a valid index into the init shape.
2328 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2329 if (dim < 0 || dim >= initRank)
2331 <<
" is out of range. expected range: [0, "
2332 << initRank - 1 <<
"], got: " << dim;
// dimMap[k] is the init dimension that input dimension k maps to: the init
// dimensions that are NOT added by the broadcast, in increasing order.
2336 SmallVector<int64_t> dimMap;
2337 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2338 if (!llvm::is_contained(dimensionsRef, dim))
2339 dimMap.push_back(dim);
// Extents of mapped dimensions must agree between input and init.
2342 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2345 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2346 return emitOpError() <<
"input dim " << inputDimIdx
2347 <<
" should match init dim " << initDimIdx
2348 <<
". input: " << inputShape[inputDimIdx]
2349 <<
", init: " << initShape[initDimIdx];
// Returns one parallel iterator type per dimension of the init operand's
// type; the broadcast has no reduction iterators.
2355SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2356 int64_t rank = getInit().getType().getRank();
2357 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2360ArrayAttr BroadcastOp::getIndexingMaps() {
2362 int64_t rank = getInit().getType().getRank();
2368void BroadcastOp::getEffects(
2369 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2384 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2385 if (!defBroadcastOp)
2390 Value init = broadcastOp.getInit();
2394 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2395 if (!llvm::is_contained(dimensions, dim))
2396 dimMap.push_back(dim);
2398 for (
auto dim : defDimensions)
2399 foldedDims.push_back(dimMap[dim]);
2401 llvm::sort(foldedDims);
2403 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2415 if (!broadcastOp.hasPureTensorSemantics())
2421 if (!splatValue.has_value())
2425 cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
2426 if (!resultType.hasStaticShape())
2428 "result type has dynamic shape");
// Adds the BroadcastOp canonicalization patterns to `results`:
// EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts and
// FoldBroadcastSplatConstant, all constructed with `context`.
2437void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2438 MLIRContext *context) {
2439 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
2440 FoldBroadcastSplatConstant>(context);
// Custom printer for linalg::YieldOp. When the op has operands it prints a
// space followed by the operand list and, later, " : " followed by the
// operand types; with no operands neither part is emitted.
// NOTE(review): a statement between the two guarded prints is missing from
// this listing.
2447void linalg::YieldOp::print(OpAsmPrinter &p) {
2448 if (getNumOperands() > 0)
2449 p <<
' ' << getOperands();
2451 if (getNumOperands() > 0)
2452 p <<
" : " << getOperandTypes();
2455ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2456 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2457 SmallVector<Type, 2> types;
// Checks a linalg::YieldOp against its enclosing LinalgOp: the number of
// yielded values must equal the number of DPS inits, and the type of each
// yielded value must equal `elementType` for the init operand at the same
// position.
// NOTE(review): the lines that declare `elementType` and unwrap it for
// memref / ranked-tensor inits are missing from this listing; presumably it
// is the element type of the matching init operand — confirm.
2467static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2468 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2469 return op.emitOpError(
"expected number of yield values (")
2470 << op.getNumOperands()
2471 <<
") to match the number of inits / outs operands of the enclosing "
2472 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
// Yield operand k is checked against DPS init operand k.
2474 for (
OpOperand &opOperand : op->getOpOperands()) {
2476 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2478 if (isa<MemRefType, RankedTensorType>(elementType))
2480 if (opOperand.get().getType() != elementType)
2481 return op.emitOpError(
"type of yield operand ")
2482 << (opOperand.getOperandNumber() + 1) <<
" ("
2483 << opOperand.get().getType() <<
") doesn't match "
2484 <<
"the element type of the enclosing linalg.generic op ("
2485 << elementType <<
")";
// Verifier for linalg::YieldOp. The parent op must have exactly one region
// and that region must be non-empty. The parent is then tested for the
// LinalgOp interface; the final error is emitted when the parent is not
// handled by that branch.
// NOTE(review): the body of the `dyn_cast<LinalgOp>` branch is missing from
// this listing.
2490LogicalResult linalg::YieldOp::verify() {
2491 auto *parentOp = (*this)->getParentOp();
2492 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2493 return emitOpError(
"expected single non-empty parent region");
2495 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2498 return emitOpError(
"expected parent op with LinalgOp interface");
// Verifier for IndexOp. The parent op must implement the LinalgOp interface,
// and getDim() must be strictly lower than the parent's number of loops
// (an error is raised when getNumLoops() <= getDim()).
// NOTE(review): the null check on `linalgOp` and the head of the second
// diagnostic are missing from this listing.
2505LogicalResult IndexOp::verify() {
2506 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2508 return emitOpError(
"expected parent op with LinalgOp interface");
2509 if (linalgOp.getNumLoops() <= getDim())
2511 << getDim() <<
") to be lower than the number of loops ("
2512 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
// Folder for IndexOp. When the static loop range of the enclosing LinalgOp
// for dimension getDim() is exactly 1, the op folds to the index constant 0;
// otherwise an empty OpFoldResult is returned (no fold).
// NOTE(review): the condition guarding the first `return OpFoldResult{}` is
// missing from this listing; presumably it handles a parent that is not a
// LinalgOp, since the cast uses dyn_cast_or_null — confirm.
2516OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2517 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2522 return OpFoldResult{};
2525 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2526 uint64_t dim = getDim();
2527 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
// A loop with a single iteration can only ever produce index 0.
2528 if (loopBounds[dim] == 1)
2529 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2531 return OpFoldResult{};
2536#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2538#define GET_OP_CLASSES
2539#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2541#define GET_OP_CLASSES
2542#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2543#define GET_OP_CLASSES
2544#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2561 for (
unsigned i = 0; i < num; ++i)
2568 auto rangeA = llvm::make_range(a.begin(), a.end());
2569 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2570 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2571 return llvm::to_vector<4>(concatRanges);
2575 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2577 for (
auto size :
memref.getShape())
2584 if (
auto as =
memref.getMemorySpace()) {
2585 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2586 ss <<
"as" << attr.getInt();
2592 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2595 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2608 assert(isa<LinalgOp>(op));
2610 std::string fun =
"";
2612 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2613 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2614 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2615 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2619 llvm::replace(name,
'.',
'_');
2620 llvm::raw_string_ostream ss(name);
2624 return std::string();
2639 LogicalResult matchAndRewrite(LinalgOp op,
2641 for (
OpOperand &opOperand : op->getOpOperands()) {
2645 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2648 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2659struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2660 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2662 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2663 PatternRewriter &rewriter)
const override {
2667 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2674 if (castOp->getBlock() != linalgOp->getBlock())
2677 OpBuilder::InsertionGuard guard(rewriter);
2680 Location loc = linalgOp.getLoc();
2681 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2684 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2690 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2692 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2693 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2694 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2695 linalgOp.getDpsInits().end());
2696 outputOperands[resultNumber] = newOperand;
2697 newOperands.append(outputOperands.begin(), outputOperands.end());
2699 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2700 linalgOp->result_type_end());
2701 resultTypes[resultNumber] = resultType;
2702 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2705 Value castBack = tensor::CastOp::create(
2709 results[resultNumber] = castBack;
// Records, for each of `operands`, the statically known size of every loop
// dimension it indexes directly. For each operand that is not a scalar
// according to linalgOp.isScalar(), the operand's indexing map is walked
// result by result; when result i is a plain AffineDimExpr and dimension i of
// the operand type is static, (dim expr -> size) is inserted into
// `affineExprToSize` with try_emplace, so the first recorded size wins.
// When the operand is produced by a tensor::CastOp whose source has a fully
// static ranked tensor type, the sizes are read from the cast source's shape.
// NOTE(review): the definition of `parentOp` and the statements following the
// guards are missing from this listing.
2718static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2719 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2720 for (OpOperand &opOperand : operands) {
2721 if (linalgOp.isScalar(&opOperand))
2723 Value src = opOperand.get();
2724 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2725 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
// Default to the operand's own shape; may be replaced by the cast source's.
2731 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2733 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2734 Value castSource = castOp.getSource();
2735 auto castSourceType =
2736 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2737 if (castSourceType && castSourceType.hasStaticShape())
2738 sourceShape = castSourceType.getShape();
// Note the dynamic-dim test is made on `sourceType`, not on `sourceShape`.
2744 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2745 if (sourceType.isDynamicDim(i))
2747 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2748 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
// Appends to `newOperands` the operand to use for `opOperand` in a rewritten
// `linalgOp`, and, when `opOperand` is a DPS init, appends its (possibly
// refined) type to `resultTypes`.
//
// The original value is pushed first. For operands that are not scalars, a
// new shape is computed: a dimension keeps its current extent unless it is
// dynamic and `affineExprToSize` holds a size for its indexing-map result, in
// which case that size is used. If at least one dimension was refined,
// `changeNeeded` is set, a tensor::CastOp to the refined type is created at
// `loc`, and the entry in `newOperands` is replaced by the cast result.
// NOTE(review): early `return`s after the scalar / already-static guards and
// the definition of `index` are missing from this listing.
2758static void createNewOperandWithStaticSizes(
2759 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2760 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2761 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2762 bool &changeNeeded) {
2763 Value src = opOperand->
get();
2764 newOperands.push_back(src);
2765 if (linalgOp.isScalar(opOperand))
2767 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2768 Type resultType = sourceType;
// A fully static init needs no refinement; its type is recorded as-is.
2769 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2770 resultTypes.push_back(resultType);
2773 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2774 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2775 SmallVector<int64_t> newShape;
2778 bool newOperandNeeded =
false;
// Build the refined shape dimension by dimension.
2779 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2780 int64_t dimShape = sourceShape[i];
2781 AffineExpr dimExpr = sourceMap.
getResult(i);
// Keep the extent when no size is known or the dimension is already static.
2782 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2783 newShape.push_back(dimShape);
2789 newShape.push_back(affineExprToSize[dimExpr]);
2790 newOperandNeeded =
true;
// The refined type keeps the element type and encoding of the source.
2792 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2793 sourceType.getEncoding());
2794 if (newOperandNeeded) {
2795 changeNeeded =
true;
2798 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2800 newOperands[index] = newOperand;
2802 if (linalgOp.isDpsInit(opOperand))
2803 resultTypes.push_back(resultType);
2809struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2810 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2812 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2813 PatternRewriter &rewriter)
const override {
2814 if (!linalgOp.hasPureTensorSemantics())
2818 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2819 return !map.isProjectedPermutation();
2824 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2825 Location loc = linalgOp.getLoc();
2829 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2831 SmallVector<Value> newOperands;
2832 SmallVector<Type> resultTypes;
2836 bool changeNeeded =
false;
2837 newOperands.reserve(linalgOp->getNumOperands());
2838 resultTypes.reserve(linalgOp.getNumDpsInits());
2841 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2842 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2843 affineExprToSize, linalgOp, newOperands,
2844 resultTypes, changeNeeded);
2853 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2854 SmallVector<Value> replacements;
2856 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2857 Value newResult = std::get<1>(it);
2858 Value oldResult = std::get<0>(it);
2859 Type newType = newResult.
getType();
2860 Type oldType = oldResult.
getType();
2861 replacements.push_back(
2862 (newType != oldType)
2863 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2866 rewriter.
replaceOp(linalgOp, replacements);
2880LogicalResult SoftmaxOp::verify() {
2881 ShapedType inputType = getInputOperandType();
2882 ShapedType outputType = getOutputOperandType();
2884 ArrayRef<int64_t> inputShape = inputType.getShape();
2885 ArrayRef<int64_t> outputShape = outputType.getShape();
2889 int64_t inputRank = getInputOperandRank();
2890 int64_t dimension = getDimension();
2891 if ((dimension < 0) || (dimension >= inputRank))
2892 return emitOpError(
"incorrect dimension specified");
// Builds the iteration domain of the softmax op: one Range per dimension of
// the input operand, with offset `zero`, stride `one`, and a size obtained
// from getDimValue() on the input for that dimension.
// NOTE(review): the definitions of `zero` and `one` and the final return are
// missing from this listing.
2897SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2898 int64_t operandRank = getInputOperandRank();
2899 SmallVector<Range> loopBounds(operandRank);
2900 Location loc = getLoc();
2903 Value source = getInput();
2904 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2905 loopBounds[dim].offset = zero;
2906 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2907 loopBounds[dim].stride = one;
// Returns one iterator type per input dimension: all parallel, except the
// dimension given by getDimension(), which is marked as a reduction.
2912SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2913 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2914 utils::IteratorType::parallel);
2915 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2916 return iteratorTypes;
2919FailureOr<TilingResult>
2920SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2921 ArrayRef<OpFoldResult> offsets,
2922 ArrayRef<OpFoldResult> sizes) {
2923 int64_t rank = getInputOperandRank();
2925 SmallVector<OpFoldResult> strides(rank, oneAttr);
2926 SmallVector<Value> tiledOperands;
2927 Operation *inputSlice =
2928 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2930 return emitOpError(
"failed to compute input slice");
2932 tiledOperands.emplace_back(inputSlice->
getResult(0));
2933 Operation *outputSlice =
2934 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2936 return emitOpError(
"failed to compute output slice");
2938 tiledOperands.emplace_back(outputSlice->
getResult(0));
2940 SmallVector<Type, 4> resultTypes;
2941 if (hasPureTensorSemantics())
2942 resultTypes.push_back(tiledOperands[1].
getType());
2943 Operation *tiledOp =
2944 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2946 return TilingResult{
2949 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
// Maps an iteration-space tile to the tile of result `resultNumber`. For
// result 0 the mapping is the identity: `offsets` and `sizes` are copied
// verbatim into `resultOffsets` and `resultSizes`.
// NOTE(review): the return statements (for result 0 and for any other result
// number) are missing from this listing.
2952LogicalResult SoftmaxOp::getResultTilePosition(
2953 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2954 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2955 SmallVector<OpFoldResult> &resultSizes) {
2956 if (resultNumber == 0) {
2957 resultOffsets.assign(offsets.begin(), offsets.end());
2958 resultSizes.assign(sizes.begin(), sizes.end());
2965LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2970SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2972 SmallVector<OpFoldResult> shapes;
2973 Location loc = getOperation()->getLoc();
2974 IRRewriter rewriter(
b);
2975 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2976 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2977 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2978 if (!outputShapedType.isDynamicDim(dim)) {
2980 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2987 reifiedReturnShapes.emplace_back(std::move(shapes));
2991void SoftmaxOp::getEffects(
2992 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2994 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2995 if (!llvm::isa<MemRefType>(operand.
getType()))
2998 &getOperation()->getOpOperand(index), 0,
3003 for (OpOperand &operand : getDpsInitsMutable()) {
3004 if (!llvm::isa<MemRefType>(operand.get().
getType()))
3035static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
3037 int64_t dim,
bool allParallel =
false) {
3039 utils::IteratorType::parallel);
3041 iteratorTypes[dim] = utils::IteratorType::reduction;
3045 for (
int i = 0; i < inputRank; i++) {
3052 return std::make_tuple(iteratorTypes, indexingMaps);
3057template <
typename T>
3060 auto inputType = cast<ShapedType>(input.
getType());
3062 int64_t inputRank = inputShape.size();
3063 auto [iteratorTypes, indexingMaps] =
3065 assert(indexingMaps.size() == 2 &&
3066 "We should have two maps: 1 for the input, 1 for the output");
3067 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3069 auto genericOp = linalg::GenericOp::create(
3070 builder, loc, output.
getType(), input, output, indexingMaps,
3072 Value result = T::create(b, loc, args[0], args[1]);
3073 linalg::YieldOp::create(b, loc, result);
3075 return genericOp.getResult(0);
3083 auto inputType = cast<ShapedType>(input.
getType());
3085 int64_t inputRank = inputShape.size();
3087 builder, inputRank, dim,
true);
3088 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3089 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3091 indexingMaps.push_back(indexingMaps[0]);
3092 auto genericOp = linalg::GenericOp::create(
3094 indexingMaps, iteratorTypes,
3096 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3097 Value result = math::ExpOp::create(b, loc, diff);
3098 linalg::YieldOp::create(b, loc, result);
3100 return genericOp.getResult(0);
3110 auto inputType = cast<ShapedType>(numerator.
getType());
3112 int64_t inputRank = inputShape.size();
3114 builder, inputRank, dim,
true);
3115 assert(indexingMaps.size() == 2 &&
3116 "We should have one map for each input (2)");
3117 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3119 indexingMaps.push_back(indexingMaps[0]);
3120 auto genericOp = linalg::GenericOp::create(
3122 output, indexingMaps, iteratorTypes,
3124 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3125 linalg::YieldOp::create(b, loc, result);
3127 return genericOp.getResult(0);
3149FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3150 OpBuilder::InsertionGuard guard(
b);
3151 b.setInsertionPoint(*
this);
3152 Location loc = getLoc();
3153 Value input = getInput();
3154 ShapedType inputType = getInputOperandType();
3155 Type elementType = inputType.getElementType();
3156 int64_t reductionDim = getDimension();
3158 Value output = getOutput();
3159 dims.erase(dims.begin() + reductionDim);
3161 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3163 elementType,
b, loc,
3165 Value neutralForMaxFInit =
3166 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3178 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3184 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3185 return SmallVector<Value>{
result};
// Verifier for WinogradFilterTransformOp. The filter height and width must
// each be either `r` or 1, and they may not both be 1. The expected output
// shape is then assembled as
//   [filterH == r ? m + r - 1 : 1, filterW == r ? m + r - 1 : 1, C, F]
// where C and F are the filter extents at getFilterCDim() / getFilterFDim(),
// and a mismatch with the actual output shape is reported.
// NOTE(review): the definitions of `m` and `r` (presumably derived from
// `fmr`) and the shape comparison guarding the last error are missing from
// this listing.
3192LogicalResult WinogradFilterTransformOp::verify() {
3193 auto filterType = cast<ShapedType>(getFilter().
getType());
3194 ArrayRef<int64_t> filterShape = filterType.getShape();
3195 int64_t filterH = filterShape[getFilterHDim()];
3196 int64_t filterW = filterShape[getFilterWDim()];
3197 WinogradConv2DFmr fmr = getFmr();
// Each spatial extent is either the full kernel size r or degenerate (1).
3201 if (filterH != r && filterH != 1)
3202 return emitOpError(
"expect filter height either equals to r or 1");
3203 if (filterW != r && filterW != 1)
3204 return emitOpError(
"expect filter width either equals to r or 1");
3205 if (filterH == 1 && filterW == 1)
3206 return emitOpError(
"expect either filter height or width equals to r");
// A transformed spatial dimension has extent m + r - 1; otherwise it stays 1.
3208 SmallVector<int64_t> expectedOutputShape;
3209 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3210 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3211 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3212 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3214 auto outputType = cast<ShapedType>(getOutput().
getType());
3215 ArrayRef<int64_t> outputShape = outputType.getShape();
3217 return emitOpError(
"the output shape is not expected");
3223WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3224 Location loc = getLoc();
3227 Value filter = getFilter();
3228 int64_t filterRank = getFilterOperandRank();
3229 SmallVector<Range> loopBounds(filterRank);
3230 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3231 loopBounds[dim].offset = zeroAttr;
3232 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3233 loopBounds[dim].stride = oneAttr;
// Returns one parallel iterator type per dimension of the filter operand.
3238SmallVector<utils::IteratorType>
3239WinogradFilterTransformOp::getLoopIteratorTypes() {
3240 int64_t filterRank = getFilterOperandRank();
3241 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3242 utils::IteratorType::parallel);
3243 return iteratorTypes;
3246LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3247 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3248 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3249 SmallVector<OpFoldResult> &resultSizes) {
3251 ShapedType filterType = getFilterOperandType();
3252 ArrayRef<int64_t> filterShape = filterType.getShape();
3253 int64_t filterH = filterShape[getFilterHDim()];
3254 int64_t filterW = filterShape[getFilterWDim()];
3255 WinogradConv2DFmr fmr = getFmr();
3258 int64_t alpha = m + r - 1;
3259 int64_t alphaH = filterH != 1 ? alpha : 1;
3260 int64_t alphaW = filterW != 1 ? alpha : 1;
3264 resultOffsets.append(
3265 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3267 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3278FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3279 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3280 ArrayRef<OpFoldResult> sizes) {
3283 ShapedType filterType = getFilterOperandType();
3284 ArrayRef<int64_t> filterShape = filterType.getShape();
3285 int64_t filterH = filterShape[getFilterHDim()];
3286 int64_t filterW = filterShape[getFilterWDim()];
3289 SmallVector<Value> tiledOperands;
3290 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3292 sliceOffsets.append(
3293 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3294 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3295 sizes[getFilterCDim()]});
3296 int64_t filterRank = getFilterOperandRank();
3297 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3298 Location loc = getLoc();
3299 auto filterSlice = tensor::ExtractSliceOp::create(
3300 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3301 tiledOperands.emplace_back(filterSlice);
3303 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3308 int64_t outputRank = getOutputOperandRank();
3309 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3310 auto outputSlice = tensor::ExtractSliceOp::create(
3311 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3312 tiledOperands.emplace_back(outputSlice);
3314 SmallVector<Type> resultTypes;
3315 resultTypes.push_back(tiledOperands[1].
getType());
3316 Operation *tiledOp =
3317 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3319 return TilingResult{
3322 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
// Verifier for WinogradInputTransformOp. Builds the expected 6-D output shape
// from the input shape and compares it with the actual output shape.
//   - tileSize = m + r - 1.
//   - A transform along H (resp. W) is considered active when the output's
//     alpha-H (resp. alpha-W) extent is not 1.
//   - Dynamic input H/W: alpha extent is tileSize, tile-count extent dynamic.
//   - Static input H/W: alpha extent is tileSize when the transform is active
//     and 1 otherwise; the tile-count extent is (input - (r - 1)) / m when
//     active and the input extent otherwise.
//   - N and C extents are copied from the input.
// NOTE(review): the definitions of `m` and `r` (presumably derived from
// `fmr`), the `else` lines and the final shape comparison are missing from
// this listing.
3329LogicalResult WinogradInputTransformOp::verify() {
3330 auto inputType = cast<ShapedType>(getInput().
getType());
3331 ArrayRef<int64_t> inputShape = inputType.getShape();
3332 int64_t inputH = inputShape[getInputHDim()];
3333 int64_t inputW = inputShape[getInputWDim()];
3334 WinogradConv2DFmr fmr = getFmr();
3337 int64_t tileSize = m + r - 1;
3339 auto outputType = cast<ShapedType>(getOutput().
getType());
3340 ArrayRef<int64_t> outputShape = outputType.getShape();
3341 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3342 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
// All six entries start as inputH; every position is overwritten below.
3344 SmallVector<int64_t> expectedOutputShape(6, inputH);
// Height: dynamic case first, then the static case.
3345 if (ShapedType::isDynamic(inputH)) {
3346 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3347 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3349 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3350 expectedOutputShape[getOutputTileHDim()] =
3351 leftTransform ? (inputH - (r - 1)) / m : inputH;
// Width: same scheme as for the height.
3353 if (ShapedType::isDynamic(inputW)) {
3354 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3355 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3357 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3358 expectedOutputShape[getOutputTileWDim()] =
3359 rightTransform ? (inputW - (r - 1)) / m : inputW;
3361 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3362 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3365 return emitOpError(
"the output shape is not expected");
// Builds the iteration domain: one Range per dimension of the output operand.
// NOTE(review): `zeroAttr` and `oneAttr` are defined on lines not part of this
// listing -- presumably index attributes 0 and 1; confirm.
3371WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3372 Location loc = getLoc();
3375 Value output = getOutput();
3376 int64_t outputRank = getOutputOperandRank();
3377 SmallVector<Range> loopBounds(outputRank);
// Each range uses zeroAttr as offset, the output's size along `dim` (via
// getDimValue) as size, and oneAttr as stride.
3378 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3379 loopBounds[dim].offset = zeroAttr;
3381 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3382 loopBounds[dim].stride = oneAttr;
// Returns one iterator type per output dimension; every entry is `parallel`.
3387SmallVector<utils::IteratorType>
3388WinogradInputTransformOp::getLoopIteratorTypes() {
3389 int64_t outputRank = getOutputOperandRank();
3390 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3391 utils::IteratorType::parallel);
3392 return iteratorTypes;
// Appends to `resultOffsets` / `resultSizes` the six-entry offsets and sizes
// of the output tile for the iteration-space tile given by `offsets`/`sizes`.
// NOTE(review): `m`, `r`, `zeroAttr`, `alphaHAttr`, `alphaWAttr` and the
// return statement are on lines not part of this listing.
3395LogicalResult WinogradInputTransformOp::getResultTilePosition(
3396 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3397 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3398 SmallVector<OpFoldResult> &resultSizes) {
3400 ShapedType outputType = getOutputOperandType();
3401 ArrayRef<int64_t> outputShape = outputType.getShape();
3402 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3403 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3405 WinogradConv2DFmr fmr = getFmr();
// An alpha extent of 1 in the output is kept as 1; any other extent is
// replaced by m + r - 1.
3408 int64_t alpha = m + r - 1;
3409 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3410 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
// The first two entries use zeroAttr offsets and the alpha attributes as
// sizes; the tile, batch and channel entries are forwarded from the
// iteration-space tile.
3415 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3416 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3417 offsets[getOutputCDim()]});
3418 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3419 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3420 sizes[getOutputCDim()]});
// Produces the tiled version of this op: extracts a slice of the input and a
// slice of the output, clones the op onto those slices, and returns them in a
// TilingResult.
// NOTE(review): several statements (the affine map definitions, the
// `mapped*` values, `oneAttr`, the call that fills resultOffsets/resultSizes)
// are on lines not part of this listing.
3431FailureOr<TilingResult>
3432WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3433 ArrayRef<OpFoldResult> offsets,
3434 ArrayRef<OpFoldResult> sizes) {
3436 WinogradConv2DFmr fmr = getFmr();
3440 ShapedType outputType = getOutputOperandType();
3441 ArrayRef<int64_t> outputShape = outputType.getShape();
3442 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3443 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3445 Location loc = getLoc();
// The tile offsets along H/W are mapped through offsetAffineMap when the
// corresponding alpha extent is not 1, and through identityAffineMap
// otherwise; the tile sizes always go through sizeAffineMap.
3447 auto identityAffineMap =
3449 auto offsetAffineMap =
3452 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3453 offsets[getOutputTileHDim()]);
3455 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3456 offsets[getOutputTileWDim()]);
3460 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3462 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3464 SmallVector<Value> tiledOperands;
3465 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
// Input slice: four entries ordered batch, H, W, channel.
3467 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3468 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3469 sliceOffsets.append(
3470 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
// When the alpha extent is 1 the slice size along that axis is oneAttr.
3471 OpFoldResult sizeH =
3472 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3473 OpFoldResult sizeW =
3474 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3476 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3477 int64_t inputRank = getInputOperandRank();
3478 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3479 auto inputSlice = tensor::ExtractSliceOp::create(
3480 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3481 tiledOperands.emplace_back(inputSlice);
// Output slice, taken at resultOffsets/resultSizes with oneAttr strides.
3483 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3488 int64_t outputRank = getOutputOperandRank();
3489 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3490 auto outputSlice = tensor::ExtractSliceOp::create(
3491 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3492 tiledOperands.emplace_back(outputSlice);
// The cloned op's single result type is the type of the output slice.
3494 SmallVector<Type> resultTypes;
3495 resultTypes.push_back(tiledOperands[1].
getType());
3496 Operation *tiledOp =
3497 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3499 return TilingResult{
3502 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
// Verifier for WinogradOutputTransformOp: checks the alpha extents of the
// `value` operand against the F(m, r) configuration and derives the expected
// output shape.
// NOTE(review): `m`, `r`, the `else` lines and the condition guarding the
// final error are on lines not part of this listing.
3509LogicalResult WinogradOutputTransformOp::verify() {
3510 auto valueType = cast<ShapedType>(getValue().
getType());
3511 ArrayRef<int64_t> valueShape = valueType.getShape();
3512 int64_t valueH = valueShape[getValueAlphaHDim()];
3513 int64_t valueW = valueShape[getValueAlphaWDim()];
3514 int64_t valueTileH = valueShape[getValueTileHDim()];
3515 int64_t valueTileW = valueShape[getValueTileWDim()];
3516 WinogradConv2DFmr fmr = getFmr();
// A transform along H (W) is treated as active when the alpha extent is not 1.
3519 bool leftTransform = valueH != 1;
3520 bool rightTransform = valueW != 1;
3522 int64_t outputRank = getOutputOperandRank();
// Seeded with valueH; the H, W, N and F entries are assigned below.
3523 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
// Height: dynamic if either the alpha extent or the tile count is dynamic.
3524 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3525 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
// Static case: the alpha extent must be m + r - 1 (or 1 without a transform);
// the expected height is then m (or 1) times the tile count.
3527 if (valueH != (leftTransform ? m + r - 1 : 1))
3528 return emitOpError(
"expect input height equals to input tile size");
3529 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
// Same checks for the width axis.
3531 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3532 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3534 if (valueW != (rightTransform ? m + r - 1 : 1))
3535 return emitOpError(
"expect input width equals to input tile size");
3536 expectedOutputShape[getOutputWDim()] =
3537 (rightTransform ? m : 1) * valueTileW;
// N and F extents are copied from the value operand.
3539 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3540 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3542 auto outputType = cast<ShapedType>(getOutput().
getType());
3543 ArrayRef<int64_t> outputShape = outputType.getShape();
3545 return emitOpError(
"the output shape is not expected");
// Builds the iteration domain: one Range per dimension of the `value` operand.
// NOTE(review): `zeroAttr` and `oneAttr` are defined on lines not part of this
// listing -- presumably index attributes 0 and 1; confirm.
3551WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3552 Location loc = getLoc();
3555 Value value = getValue();
3556 int64_t valueRank = getValueOperandRank();
3557 SmallVector<Range> loopBounds(valueRank);
// Each range uses zeroAttr as offset, the value's size along `dim` (via
// getDimValue) as size, and oneAttr as stride.
3558 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3559 loopBounds[dim].offset = zeroAttr;
3561 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3562 loopBounds[dim].stride = oneAttr;
// Returns one iterator type per dimension of the `value` operand; every entry
// is `parallel`.
3567SmallVector<utils::IteratorType>
3568WinogradOutputTransformOp::getLoopIteratorTypes() {
3569 int64_t valueRank = getValueOperandRank();
3570 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3571 utils::IteratorType::parallel);
3572 return iteratorTypes;
// Appends to `resultOffsets` / `resultSizes` the four-entry offsets and sizes
// of the output tile for the iteration-space tile given by `offsets`/`sizes`.
// NOTE(review): `affineMap`, the `mapped*` values, `oneAttr` and the return
// statement are on lines not part of this listing.
3575LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3576 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3577 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3578 SmallVector<OpFoldResult> &resultSizes) {
3579 WinogradConv2DFmr fmr = getFmr();
3583 Location loc = getLoc();
3585 auto identityAffineMap =
3590 ShapedType valueType = getValueOperandType();
3591 ArrayRef<int64_t> valueShape = valueType.getShape();
// Read the alpha extents through the dimension accessors, as verify() and
// getTiledImplementation() do, instead of hard-coded indices.
3592 int64_t valueH = valueShape[getValueAlphaHDim()];
3593 int64_t valueW = valueShape[getValueAlphaWDim()];
// Tile offsets go through `affineMap` only when the matching alpha extent is
// not 1; tile sizes are always mapped through `affineMap`.
3595 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3596 offsets[getValueTileHDim()]);
3598 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3599 offsets[getValueTileWDim()]);
3601 builder, loc, affineMap, sizes[getValueTileHDim()]);
3603 builder, loc, affineMap, sizes[getValueTileWDim()]);
3606 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3607 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
// With an alpha extent of 1 the result size along that axis is oneAttr.
3608 OpFoldResult sizeH =
3609 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3610 OpFoldResult sizeW =
3611 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
// Result entries are ordered N, H, W, F.
3613 resultOffsets.append(
3614 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3616 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
// Produces the tiled version of this op: extracts a slice of the `value`
// operand and a slice of the output, clones the op onto those slices, and
// returns them in a TilingResult.
// NOTE(review): `zeroAttr`, `oneAttr`, `alphaHAttr`, `alphaWAttr` and the call
// that fills resultOffsets/resultSizes are on lines not part of this listing.
3626FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3627 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3628 ArrayRef<OpFoldResult> sizes) {
3631 Location loc = getLoc();
3632 SmallVector<Value> tiledOperands;
3633 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3635 ShapedType valueType = getValueOperandType();
3636 ArrayRef<int64_t> valueShape = valueType.getShape();
3637 int64_t alphaH = valueShape[getValueAlphaHDim()];
3638 int64_t alphaW = valueShape[getValueAlphaWDim()];
// Value slice: the first two entries start at zeroAttr and use the alpha
// attributes as sizes; the tile, N and F entries follow the requested tile.
3642 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3643 offsets[getValueTileWDim()], offsets[getValueNDim()],
3644 offsets[getValueFDim()]});
3645 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3646 sizes[getValueTileWDim()], sizes[getValueNDim()],
3647 sizes[getValueFDim()]});
3648 int64_t valueRank = getValueOperandRank();
3649 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3650 auto valueSlice = tensor::ExtractSliceOp::create(
3651 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3652 tiledOperands.emplace_back(valueSlice);
// Output slice, taken at resultOffsets/resultSizes with oneAttr strides.
3654 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3659 int64_t outputRank = getOutputOperandRank();
3660 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3661 auto outputSlice = tensor::ExtractSliceOp::create(
3662 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3663 tiledOperands.emplace_back(outputSlice);
// The cloned op's single result type is the type of the output slice.
3665 SmallVector<Type> resultTypes;
3666 resultTypes.push_back(tiledOperands[1].
getType());
3667 Operation *tiledOp =
3668 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3670 return TilingResult{
3673 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3687 llvm::set_union(explicitSet, defaultSet);
3688 return explicitSet == defaultSet;
3708 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3710 auto opIndexingMap = opIndexingMaps[opIndex];
3711 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3714 return matmulOp->emitOpError()
3715 <<
"Unexpected dim expression in map result.";
3718 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3719 return matmulOp->emitOpError()
3720 <<
"Invalid broadcast requested, should be (d2).";
3729template <
typename OpTy>
3732 AffineMap defaultIndexingMap,
bool isLHS) {
3733 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3734 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3735 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3738 return batchVariantMatmulOp->emitOpError()
3739 <<
"Unexpected result dim expression (outside the set of default "
3744 return batchVariantMatmulOp->emitOpError()
3745 <<
"no. of result dim expressions exceeds 3.";
3747 auto hasValidBatchDim = [](
AffineMap map) {
3754 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3755 return batchVariantMatmulOp->emitOpError()
3756 <<
"Invalid broadcast requested.";
3757 }
else if (!hasValidBatchDim(opIndexingMap)) {
3758 return batchVariantMatmulOp->emitOpError()
3759 <<
"Invalid batch dimension expression.";
3767template <
typename OpTy>
3770 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3771 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3772 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3773 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3776 return batchVariantMatmulOp->emitOpError()
3777 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3780 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3782 return batchVariantMatmulOp->emitOpError()
3783 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3787 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3788 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3789 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3790 outputMap.getResult(1).isFunctionOfDim(1) &&
3791 outputMap.getResult(2).isFunctionOfDim(2)
3792 : outputMap.getResult(0).isFunctionOfDim(1) &&
3793 outputMap.getResult(1).isFunctionOfDim(2);
3796 if (!areValidOutputResultDim(opIndexingMap)) {
3797 return batchVariantMatmulOp->emitOpError()
3798 <<
"Invalid output map result dimension.";
3807template <
typename OpTy>
3812 batchVariantMatmulOp.getIndexingMapsArray();
3814 batchVariantMatmulOp.getDefaultIndexingMaps(
3815 batchVariantMatmulOp->getContext());
3817 if (opIndexingMaps.size() != 3)
3818 return batchVariantMatmulOp->emitOpError()
3819 <<
"Indexing_map attribute must have 3 affine maps.";
3821 auto opIndexingMap = opIndexingMaps[opIndex];
3822 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3830 defaultIndexingMap, opIndex == 0)))
3840 if (m == 2 && r == 3)
3841 return WinogradConv2DFmr::F_2_3;
3842 if (m == 4 && r == 3)
3843 return WinogradConv2DFmr::F_4_3;
3844 if (m == 2 && r == 5)
3845 return WinogradConv2DFmr::F_2_5;
3846 return std::nullopt;
// Switch over the supported WinogradConv2DFmr enumerators.
// NOTE(review): the enclosing function header and each case's body are on
// lines not part of this listing.
3851 case WinogradConv2DFmr::F_2_3:
3853 case WinogradConv2DFmr::F_4_3:
3855 case WinogradConv2DFmr::F_2_5:
// Reached only for an enumerator not listed above; diagnostic spelling fixed
// ("Unkown" -> "Unknown").
3858 llvm_unreachable(
"Unknown WinogradConv2DFmr");
3865static FailureOr<SmallVector<SmallVector<int64_t>>>
3868 for (
auto map : maps) {
3869 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3873 for (
auto result : attr.getAffineMap().getResults()) {
3874 auto dim = dyn_cast<AffineDimExpr>(
result);
3877 pos.push_back(dim.getPosition());
3879 positions.push_back(pos);
3892 return indexingMaps;
// Returns whether `attr` holds the default matmul indexing maps.
// NOTE(review): the null check on `maps`, the early returns and the
// computation of `positions` are on lines not part of this listing.
3895bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3896 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3899 if (maps.size() != 3)
// Default result-dim positions: first map (0, 2), second map (2, 1), third
// map (0, 1).
3904 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3905 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3906 (*positions)[2] == SmallVector<int64_t>{0, 1};
// Returns the three iterator types of a matmul: two parallel dimensions
// followed by one reduction dimension.
3909SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3910 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3911 utils::IteratorType::parallel,
3912 utils::IteratorType::reduction};
3915unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3917std::string MatmulOp::getLibraryCallName() {
3921bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
// Returns true when the op's indexing maps differ from `defaultMaps`.
// NOTE(review): the initializer of `defaultMaps` is on a line not part of this
// listing -- presumably the default indexing maps; confirm.
3925bool MatmulOp::hasUserDefinedMaps() {
3926 SmallVector<AffineMap, 3> defaultMaps =
3928 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3929 return defaultMaps != explicitMaps;
// Builds the body of a MatmulOp region in `block`: multiplies value1 by
// value2, adds block argument 2, and yields the sum.
// NOTE(review): the remaining parameters, the argument-count check, the
// definitions of `value1`/`value2` and the early returns are on lines not part
// of this listing.
3934void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3935 ArrayRef<NamedAttribute> attrs,
3938 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3943 "MatmulOp regionBuilder expects 3 args");
3944 RegionBuilderHelper helper(
b, block);
3945 SmallVector<Value> yields;
// Cast kind defaults to signed; the attribute named "cast" is looked up in
// `attrs` below.
3947 TypeFn castVal = TypeFn::cast_signed;
3948 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3949 return attr.
getName() ==
"cast";
3951 if (castIter != attrs.end()) {
3952 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Product of the two values; a null result from any step is checked next.
3960 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
3961 if (!value1 || !value2 || !value3)
3963 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3967 yields.push_back(value4);
3968 helper.yieldOutputs(yields);
3978bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3979 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3980 AffineExpr expr = bcastMap.
getResult(0);
3990 ArrayAttr arrayAttr;
3994 if (llvm::any_of(arrayAttr,
3995 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3997 <<
"element of indexing_maps array is not an affine_map";
4004 if (failed(indexingMapsAttr))
4007 if (*indexingMapsAttr ==
nullptr) {
4008 auto indexingMapAttrs = llvm::map_to_vector(
4009 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4014 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4016 MatmulOp::getRegionBuilder());
// Custom printer for MatmulOp. The indexing maps are printed only when they
// differ from the default ones.
// NOTE(review): the call that consumes `elidedAttrs` is on lines not part of
// this listing.
4019void MatmulOp::print(OpAsmPrinter &p) {
// Default maps wrapped as attributes so they can be compared with the op's
// `indexing_maps` attribute.
4020 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4021 MatmulOp::getDefaultIndexingMaps(
getContext()),
4022 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4023 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4024 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4026 std::array<StringRef, 3> elidedAttrs = {
4027 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4033LogicalResult MatmulOp::verify() {
4035 if (!hasUserDefinedMaps())
4038 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
4045LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4049void MatmulOp::getEffects(
4050 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4052 if (hasPureTensorSemantics())
4061SmallVector<AffineMap>
4062MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4063 AffineExpr d0, d1, d2;
4069 return {mapLHS, mapRHS, mapOut};
4073 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4076 if (maps.size() != 3)
4079 if (failed(positions))
4091 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4099 build(builder, state, inputs, outputs, attributes);
4100 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4101 assert(res &&
"builder didn't return the right type");
4111 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4120 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4121 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4122 assert(res &&
"builder didn't return the right type");
4132 result.addAttribute(
"cast", cast);
4134 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4143 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4144 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4145 assert(res &&
"builder didn't return the right type");
4150 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4152 op->
getAttr(
"indexing_maps"));
4156MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4163 return {mapLHS, mapRHS, mapOut};
4167 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4170 if (maps.size() != 3)
4173 if (failed(positions))
4185 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4193 build(builder, state, inputs, outputs, attributes);
4194 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4195 assert(res &&
"builder didn't return the right type");
4205 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4214 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4215 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4216 assert(res &&
"builder didn't return the right type");
4226 result.addAttribute(
"cast", cast);
4228 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4237 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4238 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4239 assert(res &&
"builder didn't return the right type");
4244 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4246 op->
getAttr(
"indexing_maps"));
4250BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4257 return {mapLHS, mapRHS, mapOut};
4261 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4264 if (maps.size() != 3)
4267 if (failed(positions))
4278 BatchMatmulOp::getRegionBuilder(),
4279 getDefaultIndexingMaps(builder));
4287 build(builder, state, inputs, outputs, attributes);
4288 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4289 assert(res &&
"builder didn't return the right type");
4298 BatchMatmulOp::getRegionBuilder(),
4299 getDefaultIndexingMaps(builder));
4308 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4309 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4310 assert(res &&
"builder didn't return the right type");
4318 result.addAttribute(
"cast", cast);
4320 BatchMatmulOp::getRegionBuilder(),
4321 getDefaultIndexingMaps(builder));
4330 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4331 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4332 assert(res &&
"builder didn't return the right type");
4337 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4339 op->
getAttr(
"indexing_maps"));
4343BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4350 return {mapLHS, mapRHS, mapOut};
4354 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4357 if (maps.size() != 3)
4360 if (failed(positions))
4371 BatchMatmulOp::getRegionBuilder(),
4372 getDefaultIndexingMaps(builder));
4380 build(builder, state, inputs, outputs, attributes);
4381 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4382 assert(res &&
"builder didn't return the right type");
4391 BatchMatmulOp::getRegionBuilder(),
4392 getDefaultIndexingMaps(builder));
4401 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4402 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4403 assert(res &&
"builder didn't return the right type");
4411 result.addAttribute(
"cast", cast);
4413 BatchMatmulOp::getRegionBuilder(),
4414 getDefaultIndexingMaps(builder));
4423 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4424 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4425 assert(res &&
"builder didn't return the right type");
4430 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4432 op->
getAttr(
"indexing_maps"));
4440 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4451 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4452 assert(dimExpr &&
"affine_map is a projected permutation");
4453 dimsInOutput[dimExpr.getPosition()] =
true;
4457 for (
auto dimOccursInOutput : dimsInOutput)
4458 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4459 : utils::IteratorType::reduction);
4461 return iteratorTypes;
4464unsigned ContractOp::getNumRegionArgs() {
return 3; }
// Builds the body of a ContractOp region in `block`: casts block arguments 0
// and 1 to `outType`, multiplies them, and yields `result`.
// NOTE(review): the remaining parameters, the argument-count check, the
// definitions of `outType` and `result`, and the early returns are on lines
// not part of this listing.
4467void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4468 ArrayRef<NamedAttribute> attrs,
4471 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4476 "ContractOp regionBuilder expects 3 args");
4477 RegionBuilderHelper helper(
b, block);
// Cast kind defaults to signed; the attribute named "cast" is looked up in
// `attrs` below.
4479 TypeFn castSignedness = TypeFn::cast_signed;
4480 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4481 return attr.
getName() ==
"cast";
4483 if (castIter != attrs.end()) {
4484 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Both inputs are converted to outType before being multiplied.
4490 Value lhsAtOutType =
4491 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4492 Value rhsAtOutType =
4493 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4494 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4496 if (!productAtOutType)
4502 helper.yieldOutputs({
result});
4505ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4507 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4509 "expected 'indexing_maps' attribute");
4510 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4516void ContractOp::print(OpAsmPrinter &p) {
4517 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4519 p, getOperation(), getInputs(), getOutputs(),
4520 {
"indexing_maps",
"operandSegmentSizes"});
// Verifier for ContractOp: checks every indexing map against its operand and
// counts, per iteration-space dimension, how often the dimension appears in
// the input maps and in the output map.
// NOTE(review): several conditions, assignments and returns are on lines not
// part of this listing; the comments below describe only what is visible.
4523LogicalResult ContractOp::verify() {
// -1 means "not yet known"; the first inspected map initializes it.
4524 int iterationSpaceDims = -1;
// Per-dimension occurrence counts across input maps and the output map.
4529 SmallVector<size_t> inOccurrences;
4530 SmallVector<size_t> outOccurrences;
// Validates one (map, operand type) pair and updates the occurrence counts.
4533 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4534 bool isInput) -> LogicalResult {
4537 return emitError(
"provided affine_map is not a projected permutation");
4540 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4542 return emitError(
"ranks of shaped operand and results of corresponding "
4543 "affine_map differ");
4545 return emitError(
"affine_map specifies shaped access while operand has "
// First map seen: size both count vectors to the iteration space. Later maps
// must have the same number of dims.
4550 if (iterationSpaceDims == -1) {
4552 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4553 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4554 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4555 return emitError(
"iteration spaces of provided affine_maps differ");
// Record which dimension each map result refers to.
4559 for (AffineExpr affineExpr : affineMap.
getResults()) {
4560 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4562 llvm_unreachable(
"affine_map is a projected permutation");
4565 inOccurrences[affineDimExpr.getPosition()] += 1;
4567 outOccurrences[affineDimExpr.getPosition()] += 1;
// The first two operands are checked as inputs, the third as the output.
4573 for (
auto &&[affineMap, operandType, isInput] :
4574 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4575 SmallVector<bool>{
true,
true,
false})) {
4576 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4580 bool hasContractingDim =
false;
4581 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4582 size_t inOccCount = inOccurrences[dimIndex];
4583 size_t outOccCount = outOccurrences[dimIndex];
// Contracting dim: used by both inputs and absent from the output.
4586 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4588 if (inOccCount == 0 && outOccCount == 0)
4589 return emitError() <<
"iteration space dim at index " << dimIndex
4590 <<
" not used to access any operand";
// A dim used by exactly one input must appear exactly once in the output.
4601 if (inOccCount == 1 && outOccCount != 1)
4603 <<
"iteration space dim at index " << dimIndex
4604 <<
" is neither a contracting dim nor of parallel iteration type";
4607 if (!hasContractingDim)
4608 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4613LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4617void ContractOp::getEffects(
4618 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4620 if (hasPureTensorSemantics())
// Returns the three default indexing maps of BatchMatmulOp over a 4-d
// iteration space: (d0, d1, d3), (d0, d3, d2) and (d0, d1, d2).
// NOTE(review): the statement binding d0..d3 is on a line not part of this
// listing.
4632SmallVector<AffineMap>
4633BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4634 AffineExpr d0, d1, d2, d3;
4635 SmallVector<AffineMap> indexingMaps;
4637 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4638 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4639 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4640 return indexingMaps;
// Returns whether `attr` holds the default batch-matmul indexing maps.
// NOTE(review): the null check on `maps`, the early returns and the
// computation of `positions` are on lines not part of this listing.
4643bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4644 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4647 if (maps.size() != 3)
// Default result-dim positions: (0, 1, 3), (0, 3, 2) and (0, 1, 2).
4652 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4653 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4654 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
// Returns the four iterator types of a batch matmul: three parallel
// dimensions followed by one reduction dimension.
4657SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4658 return SmallVector<utils::IteratorType>{
4659 utils::IteratorType::parallel, utils::IteratorType::parallel,
4660 utils::IteratorType::parallel, utils::IteratorType::reduction};
4663unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4665std::string BatchMatmulOp::getLibraryCallName() {
4671bool BatchMatmulOp::hasUserDefinedMaps() {
4672 SmallVector<AffineMap, 3> defaultMaps =
4674 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4675 return defaultMaps != explicitMaps;
// Checks whether `bcastMap` is an acceptable broadcast map for the LHS
// (`isLHS` true) or RHS operand of a batch matmul.
// NOTE(review): the assert condition, the branch on the number of results,
// the LHS alternative of the two-result check and the return are on lines not
// part of this listing.
4685bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4687 "Expected less than 3 result dim expr.");
4688 bool isValid =
false;
// Iteration-space positions: batchPos = 0, mPos = 1, nPos = 2, kPos = 3.
4689 enum Indices { batchPos, mPos, nPos, kPos };
4691 AffineExpr expr = bcastMap.
getResult(0);
4694 AffineExpr expr0 = bcastMap.
getResult(0);
4695 AffineExpr expr1 = bcastMap.
getResult(1);
// Visible alternative: the results are either (batch, k) or (k, n).
4700 : ((expr0.isFunctionOfDim(batchPos) &&
4701 expr1.isFunctionOfDim(kPos)) ||
4702 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
// Builds the body of a BatchMatmulOp region in `block`: casts block arguments
// 0 and 1 to `toType`, multiplies them, adds block argument 2, and yields the
// sum.
// NOTE(review): the remaining parameters, the argument-count check, the
// definitions of `toType` and `mulVal`, and the early returns are on lines
// not part of this listing.
4707void BatchMatmulOp::regionBuilder(
4708 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4711 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4716 "BatchMatmulOp regionBuilder expects 3 args");
4717 RegionBuilderHelper helper(
b, block);
4718 SmallVector<Value> yields;
// Cast kind defaults to signed; the attribute named "cast" is looked up in
// `attrs` below.
4720 TypeFn castVal = TypeFn::cast_signed;
4721 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4722 return attr.
getName() ==
"cast";
4724 if (castIter != attrs.end()) {
4725 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Both inputs are converted to toType before being multiplied.
4730 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4731 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4733 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
4734 if (!castValA || !castValB || !mulVal)
4736 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
4740 yields.push_back(addVal);
4741 helper.yieldOutputs(yields);
// Custom parser for BatchMatmulOp: collects the indexing-map attributes, falls
// back to the default maps when none were collected, and then delegates to
// parseNamedStructuredOp.
// NOTE(review): the code that reads the attributes from `parser` is on lines
// not part of this listing.
4744ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4745 SmallVector<Attribute, 3> indexingMapsAttr;
// Every collected element must be an AffineMapAttr.
4757 if (!isa<AffineMapAttr>(mapAttr)) {
4759 "expected affine map attribute");
4761 indexingMapsAttr.push_back(mapAttr);
// Nothing collected: use the default indexing maps.
4771 if (indexingMapsAttr.empty()) {
4772 indexingMapsAttr = llvm::map_to_vector(
4773 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4774 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4776 result.addAttribute(
"indexing_maps",
4779 return ::parseNamedStructuredOp(parser,
result,
4780 BatchMatmulOp::getNumRegionArgs(),
4781 BatchMatmulOp::getRegionBuilder());
// Custom printer for BatchMatmulOp. The indexing maps are printed only when
// they differ from the default ones.
// NOTE(review): the call that consumes `elidedAttrs` is on lines not part of
// this listing.
4784void BatchMatmulOp::print(OpAsmPrinter &p) {
// Default maps wrapped as attributes so they can be compared with the op's
// `indexing_maps` attribute.
4785 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4786 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4787 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4788 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4789 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4791 std::array<StringRef, 3> elidedAttrs = {
4792 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4798LogicalResult BatchMatmulOp::verify() {
4801 if (!hasUserDefinedMaps())
4804 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4811LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4812 SmallVectorImpl<OpFoldResult> &) {
4816void BatchMatmulOp::getEffects(
4817 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4819 if (hasPureTensorSemantics())
4833struct ArityGroupAndKind {
4835 ElementwiseArityGroup arityGroup;
4841 TernaryFn ternaryFn;
// Converts an ElementwiseArityGroup enumerator to its underlying value as an
// unsigned integer.
4845unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4846 return static_cast<unsigned>(arityGroup);
4851 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4852 constexpr int lastBinary =
4853 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4854 constexpr int lastTernary =
4855 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4857 int val =
static_cast<int>(kind);
4858 ArityGroupAndKind
result;
4860 if (val < lastUnary) {
4861 result.arityGroup = ElementwiseArityGroup::Unary;
4862 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4865 if (val < lastBinary) {
4866 result.arityGroup = ElementwiseArityGroup::Binary;
4867 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4870 if (val >= lastTernary) {
4871 llvm_unreachable(
"unhandled ElementwiseFn");
4873 result.arityGroup = ElementwiseArityGroup::Ternary;
4874 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4879 auto rank = getResultRank();
4884ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4890ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4893 mlir::linalg::ElementwiseKind elemwiseKindVal;
4898 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4899 if (!elemwiseKindAttr)
4901 "expected ElementwiseKind attribute");
4902 elemwiseKindVal = elemwiseKindAttr.getValue();
4905 "expected operation 'kind' attribute");
4908 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4911 SmallVector<Attribute, 3> indexingMapsAttr;
4921 if (!isa<AffineMapAttr>(mapAttr))
4923 "expected affine map attribute");
4924 indexingMapsAttr.push_back(mapAttr);
4935 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4937 ElementwiseOp::getRegionBuilder())) {
4939 "unable to parse elemwise op");
4943 if (indexingMapsAttr.empty()) {
4946 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4947 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4950 "return type needs to be shaped type");
4951 auto numDims = shapedType.getRank();
4952 indexingMapsAttr = llvm::map_to_vector(
4953 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4955 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4958 result.addAttribute(
"indexing_maps",
4963void ElementwiseOp::print(OpAsmPrinter &p) {
4966 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4970 unsigned numDims = getResultRank();
4972 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4973 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4975 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4977 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4978 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
// Builds the body of an ElementwiseOp region in `block`: reads the "kind"
// attribute, dispatches on its arity group, and yields `result`.
// NOTE(review): the remaining parameters, the computation of `groupAndKind`,
// the per-arity builder calls and the definition of `result` are on lines not
// part of this listing.
4986void ElementwiseOp::regionBuilder(
4987 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4989 ElementwiseKind elemwiseKind;
// Take the elementwise kind from the attribute named "kind".
4990 for (
auto attr : attrs) {
4991 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4992 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4993 assert(kindAttr &&
"op kind attribute incorrectly set");
4994 elemwiseKind = kindAttr.getValue();
5000 auto arityGroup = groupAndKind.arityGroup;
5001 auto kind = groupAndKind.kind;
// The expected argument count is the arity group's numeric value plus one.
5003 getArityGroupAsUInt(arityGroup) + 1 ) {
5004 emitError() <<
"Elementwise regionBuilder expects "
5005 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
5010 getArityGroupAsUInt(arityGroup) + 1
5011 &&
"Elementwise regionBuilder number of block args mismatch");
5013 RegionBuilderHelper helper(
b, block);
5014 SmallVector<Value> yields;
// One branch per arity group; any other group trips the assert below.
5017 if (arityGroup == ElementwiseArityGroup::Unary) {
5020 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
5024 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
5029 assert(
false &&
"found unhandled category in elemwise");
5032 yields.push_back(
result);
5033 helper.yieldOutputs(yields);
5036LogicalResult ElementwiseOp::fold(FoldAdaptor,
5037 SmallVectorImpl<OpFoldResult> &) {
5041void ElementwiseOp::getEffects(
5042 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5044 if (hasPureTensorSemantics())
5057template <
typename OpTy,
typename>
5060 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5061 ? packOrUnPack.getDestType()
5062 : packOrUnPack.getSourceType();
5063 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5064 ? packOrUnPack.getSourceType()
5065 : packOrUnPack.getDestType();
5067 packedType.getShape().take_front(unpackedType.getRank()));
5068 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5089 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5091 .take_back(mixedTiles.size()),
5093 int64_t dimSize = std::get<0>(it);
5094 if (dimSize == ShapedType::kDynamic) {
5095 newMixedTileSizes.push_back(std::get<1>(it));
5098 newMixedTileSizes.push_back(rewriter.
getIndexAttr(dimSize));
5101 return newMixedTileSizes;
5104template <
typename OpTy>
5108 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5109 "applies to only pack or unpack operations");
5110 int64_t destRank = op.getDestRank();
5112 for (
auto dim : llvm::seq<int64_t>(0, destRank))
5113 reifiedReturnShapes[0][dim] =
5118template <
typename OpTy>
5120 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5121 "applies to only pack or unpack operations");
5125 assert(tiles.size() == dimsToTile.size() &&
5126 "tiles must match indices of dimension to block");
5128 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5129 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5130 return dimAndTileMapping;
5133template <
typename OpTy>
5135 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5136 "applies to only pack or unpack operations");
5139 unsigned dynamicValIndex = 0;
5140 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5141 if (ShapedType::isStatic(staticTile))
5144 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5146 return mixedInnerTiles;
5149template <
typename OpTy>
5151 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5152 "applies to only pack or unpack operations");
5165 size_t dimsPosSize = dimsPos.size();
5166 if (dimsPosSize > rank)
5169 if (dimsPosSize != uniqued.size())
5171 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5172 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5176template <
typename OpTy>
5178 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5179 "applies to only pack or unpack operations");
5180 Operation *op = packOrUnPack.getOperation();
5190 if (!packOrUnPack.getSourceType().hasRank() ||
5191 !packOrUnPack.getDestType().hasRank())
5192 return op->
emitError(
"expected both source and destination to have rank");
5195 if (!packOrUnPack.hasPureBufferSemantics() &&
5196 !packOrUnPack.hasPureTensorSemantics())
5197 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
5198 const unsigned numResults = packOrUnPack.getNumResults();
5199 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5200 return op->
emitError(
"expected 1 result, got ") << numResults;
5201 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5202 return op->
emitError(
"expected 0 results, got ") << numResults;
5206 if (hasZeros(mixedTiles))
5207 return op->
emitError(
"invalid zero tile factor");
5210 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5211 ? packOrUnPack.getSourceType()
5212 : packOrUnPack.getDestType();
5213 size_t unpackedRank = unpackedType.getRank();
5217 return op->
emitError(
"invalid inner_dims_pos vector");
5219 return op->
emitError(
"invalid outer_dims_perm vector");
5220 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5221 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5225 if (mixedTiles.size() > unpackedRank) {
5226 return op->
emitError(
"tiling factors must be less than or equal to the "
5227 "input rank for pack or output rank for unpack");
5229 if (mixedTiles.size() != innerDimsPos.size()) {
5231 "tiling factors must equal the number of dimensions to tile");
5234 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5235 ? packOrUnPack.getDestType()
5236 : packOrUnPack.getSourceType();
5237 size_t packedRank = packedType.getRank();
5239 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5240 if (expectedPackedRank != packedRank) {
5242 "packed rank != (unpacked rank + num tiling factors), got ")
5243 << packedRank <<
" != " << expectedPackedRank;
5250 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5251 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5252 for (
auto it : llvm::enumerate(llvm::zip(
5253 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5254 int64_t dimSize = std::get<0>(it.value());
5256 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5257 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5258 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5259 if (dimSize != staticTileSize)
5261 "mismatch in inner tile sizes specified and shaped of "
5262 "tiled dimension in the packed type at index ")
5263 << it.index() <<
": got " << dimSize <<
" != " << staticTileSize;
5264 }
else if (!ShapedType::isDynamic(dimSize)) {
5265 return op->
emitError(
"mismatch in inner tile sizes specified at index ")
5266 << it.index() <<
": got static shape " << dimSize
5267 <<
" but dynamic tile size";
5272 auto elementType = unpackedType.getElementType();
5273 Type expectedType, actualType;
5274 if (packOrUnPack.hasPureTensorSemantics()) {
5275 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5276 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5278 expectedType = MemRefType::get(expectedPackedShape, elementType);
5279 actualType = MemRefType::get(packedType.getShape(), elementType);
5282 << expectedType <<
" for the packed domain value, got "
5295struct PackOrUnPackTransposeResult {
5302template <
typename OpTy>
5303static PackOrUnPackTransposeResult
5307 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5308 "applies to only pack or unpack operations");
5309 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5310 "some permutation must be non-empty");
5311 PackOrUnPackTransposeResult metadata;
5312 metadata.innerDimsPos =
5314 metadata.innerTiles =
5316 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5317 ? packOrUnPackOp.getSourceRank()
5318 : packOrUnPackOp.getDestRank();
5319 metadata.outerDimsPerm =
5320 packOrUnPackOp.getOuterDimsPerm().empty()
5321 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5323 if (!innerPermutation.empty()) {
5324 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5326 "invalid inner permutation");
5330 if (!outerPermutation.empty()) {
5331 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5333 "invalid outer permutation");
5344 if (!getResults().empty())
5345 setNameFn(getResult(),
"pack");
5355 Type sourceType, destType, resultType;
5372 SmallVector<int64_t> outerDimsPermVec;
5375 if (parser.parseInteger(value))
5377 outerDimsPermVec.push_back(value);
5387 SmallVector<int64_t> innerDimsPosVec;
5390 if (parser.parseInteger(value))
5392 innerDimsPosVec.push_back(value);
5404 for (
auto val : staticTilesAttr.
asArrayRef())
5405 staticTiles.push_back(val);
5422 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5425 "pack/unpack requires '->' and destination type");
5429 resultType = destType;
5435 if (!paddingValue.empty() &&
5440 if (!dynamicTiles.empty() &&
5445 result.addAttribute(
"static_inner_tiles",
5447 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5449 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5451 SmallVector<int32_t> segmentSizes = {
5452 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5453 static_cast<int32_t
>(dynamicTiles.size())};
5454 result.addAttribute(
"operandSegmentSizes",
5458 result.addTypes(resultType);
5463void PackOp::print(OpAsmPrinter &p) {
5464 p <<
" " << getSource();
5466 if (getPaddingValue()) {
5467 p <<
" padding_value(" << getPaddingValue() <<
" : "
5468 << getPaddingValue().getType() <<
")";
5471 if (!getOuterDimsPerm().empty()) {
5472 p <<
" outer_dims_perm = [";
5473 llvm::interleaveComma(getOuterDimsPerm(), p);
5477 p <<
" inner_dims_pos = [";
5478 llvm::interleaveComma(getInnerDimsPos(), p);
5481 p <<
" inner_tiles = ";
5484 p <<
" into " << getDest();
5487 {
"static_inner_tiles",
"inner_dims_pos",
5488 "outer_dims_perm",
"operandSegmentSizes"});
5490 p <<
" : " << getSource().getType();
5491 p <<
" -> " << getDest().getType();
5494void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5495 Value dest, ArrayRef<int64_t> innerDimsPos,
5496 ArrayRef<OpFoldResult> innerTiles,
5497 std::optional<Value> paddingValue,
5498 ArrayRef<int64_t> outerDimsPerm) {
5499 assert(innerDimsPos.size() == innerTiles.size() &&
5500 "number of tile sizes specified must match the specified number of "
5501 "original dimensions to be tiled");
5502 SmallVector<int64_t> staticTileSizes;
5503 SmallVector<Value> dynamicTileSizes;
5505 build(builder, state, dest.
getType(), source, dest,
5506 paddingValue ? *paddingValue :
nullptr,
5507 outerDimsPerm.empty() ?
nullptr
5514PackOp::reifyResultShapes(OpBuilder &builder,
5523SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5527SmallVector<int64_t> PackOp::getStaticTiles() {
// Returns the leading entries of the destination type's shape, one entry per
// dimension of the source type (i.e. the first `inputRank` dimensions, where
// `inputRank` is the rank of the source type). The returned ArrayRef is a
// slice of the destination type's shape; no copy is made here.
5531ArrayRef<int64_t> PackOp::getAllOuterDims() {
5532 ShapedType inputType = getSourceType();
5533 int64_t inputRank = inputType.getRank();
// Keep only the first `inputRank` dimensions of the destination shape.
5534 return getDestType().getShape().take_front(inputRank);
5537SmallVector<int64_t> PackOp::getTiledOuterDims() {
5538 auto innerDimsPos = getInnerDimsPos();
5539 SmallVector<int64_t> outerDims(getAllOuterDims());
5540 SmallVector<int64_t> res;
5543 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5545 if (!outerDimPermInv.empty())
5549 for (
auto index : innerDimsPos)
5550 res.push_back(outerDims[index]);
5555bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5556 ArrayRef<int64_t> innerDimsPos,
5557 ArrayRef<int64_t> outputShape,
5558 ArrayRef<int64_t> outerDimsPerm,
5559 ArrayRef<OpFoldResult> innerTiles) {
5560 SmallVector<int64_t> outputTileSizes(
5561 outputShape.take_front(inputShape.size()));
5562 if (!outerDimsPerm.empty()) {
5563 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5564 "expected output and outer_dims_perm to have same size");
5568 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5569 if (ShapedType::isDynamic(inputShape[pos]))
5572 if (!constantTile) {
5573 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5574 (inputShape[pos] % outputTileSizes[pos] != 0))
5577 assert(*constantTile != 0 &&
"static tile size can't be zero");
5578 if (inputShape[pos] % (*constantTile) != 0) {
5586bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5587 ArrayRef<int64_t> innerDimsPos,
5588 ArrayRef<int64_t> outputShape,
5589 ArrayRef<int64_t> outerDimsPerm,
5590 ArrayRef<OpFoldResult> innerTiles) {
5591 SmallVector<int64_t> outputTileSizes(
5592 outputShape.take_front(inputShape.size()));
5593 if (!outerDimsPerm.empty()) {
5594 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5595 "expected output and outer_dims_perm to have same size");
5599 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5600 if (ShapedType::isDynamic(inputShape[pos]) ||
5601 ShapedType::isDynamic(outputTileSizes[pos]))
5606 assert(*constantTile != 0 &&
"static tile size can't be zero");
5607 if (inputShape[pos] % (*constantTile) != 0)
5613LogicalResult PackOp::verify() {
5620 auto paddingValue = getPaddingValue();
5624 << getSourceType().getElementType()
5625 <<
" but got: " << paddingValue.getType();
5628 if (!paddingValue &&
5629 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5630 getDestType().
getShape(), getOuterDimsPerm(),
5633 "invalid tile factor or output size provided. Only full tiles are "
5634 "supported when padding_value is not set");
5641static SmallVector<int64_t>
5644 for (
auto o : ofrs) {
5646 if (llvm::dyn_cast_if_present<Value>(o))
5647 result.push_back(ShapedType::kDynamic);
5659 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5660 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5662 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5663 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5666 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5667 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5671 if (!outerDimsPerm.empty())
5675 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5679SmallVector<OpFoldResult> PackOp::getResultShape(
5680 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5681 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5682 ArrayRef<int64_t> outerDimsPerm) {
5683 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5687 AffineExpr ceilDivExpr = s0.
ceilDiv(s1);
5688 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5690 builder, loc, ceilDivExpr,
5691 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5693 if (!outerDimsPerm.empty())
5695 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5697 SmallVector<int64_t> resultTypeShape =
5700 innerDimsPos, outerDimsPerm);
5706 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5707 if (ShapedType::isStatic(resultTypeShape[i]))
// Builds the ranked tensor type of a packed value.
//
// The shape is computed by `inferPackedShape` from the source shape together
// with `innerTileSizes`, `innerDimsPos` and `outerDimsPerm`; the element type
// is taken from `sourceType`. Only the shape and element type are passed to
// `RankedTensorType::get` here.
5716RankedTensorType PackOp::inferPackedTensorType(
5717 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5718 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5719 SmallVector<int64_t> resultShape = inferPackedShape(
5720 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5721 return RankedTensorType::get(resultShape, sourceType.getElementType());
// Builds the memref type of a packed value.
//
// Mirrors `inferPackedTensorType`: the shape comes from `inferPackedShape`
// applied to the source shape, `innerTileSizes`, `innerDimsPos` and
// `outerDimsPerm`, and the element type is taken from `sourceType`. Only the
// shape and element type are passed to `MemRefType::get` here.
5724MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5725 ArrayRef<int64_t> innerTileSizes,
5726 ArrayRef<int64_t> innerDimsPos,
5727 ArrayRef<int64_t> outerDimsPerm) {
5728 SmallVector<int64_t> resultShape = inferPackedShape(
5729 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5730 return MemRefType::get(resultShape, sourceType.getElementType());
5733Value PackOp::createDestinationTensor(OpBuilder &
b, Location loc, Value source,
5734 ArrayRef<OpFoldResult> innerTileSizes,
5735 ArrayRef<int64_t> innerDimsPos,
5736 ArrayRef<int64_t> outerDimsPerm) {
5737 AffineExpr dim0, dim1;
5739 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5744 SmallVector<OpFoldResult> mixedSizes;
5745 for (
auto [index, value] : llvm::enumerate(
5746 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5747 if (ShapedType::isDynamic(value))
5748 mixedSizes.push_back(
5749 tensor::DimOp::create(
b, loc, source, index).getResult());
5751 mixedSizes.push_back(
b.getIndexAttr(value));
5753 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5754 int64_t dimPos = std::get<0>(it);
5755 OpFoldResult tileSize = std::get<1>(it);
5756 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5758 if (!outerDimsPerm.empty())
5761 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5762 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5763 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
// Creates a new PackOp that packs the same source as this op, using the
// inner_dims_pos / inner tiles / outer_dims_perm stored in `metadata`.
//
// NOTE(review): the statement that declares `metadata` is not visible in this
// listing; it is evidently computed from `*this`, `innerPermutation` and
// `outerPermutation` -- confirm against the full source.
5766PackOp PackOp::createTransposedClone(OpBuilder &
b, Location loc,
5767 ArrayRef<int64_t> innerPermutation,
5768 ArrayRef<int64_t> outerPermutation) {
5770 *
this, innerPermutation, outerPermutation);
// A fresh destination is created for the clone from the source value and the
// values held in `metadata`.
5771 Value transposedDest =
5772 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5773 metadata.innerDimsPos, metadata.outerDimsPerm);
// The clone reuses this op's source and padding value.
5774 return PackOp::create(
b, loc, getSource(), transposedDest,
5775 metadata.innerDimsPos, metadata.innerTiles,
5776 getPaddingValue(), metadata.outerDimsPerm);
5779template <
typename OpTy>
5784 if (op.hasPureTensorSemantics())
5787 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5788 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5791 if (&opOperand == &op.getSourceMutable()) {
5795 }
else if (&opOperand == &op.getDestMutable()) {
5806void PackOp::getEffects(
5812void UnPackOp::getEffects(
5819template <
typename OpTy>
5821 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5822 "applies to only pack or unpack operations");
5823 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5825 : op.getSourceType();
5827 for (
auto [dimDest,
tile] : llvm::zip(
5828 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5830 if (!constTileSize || ShapedType::isDynamic(dimDest))
5837 if (!hasPureTensorSemantics())
5839 if (getPaddingValue())
5854 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5856 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5868 auto packTiles = packOp.getMixedTiles();
5869 auto unPackTiles = unPackOp.getMixedTiles();
5870 if (packTiles.size() != unPackTiles.size())
5872 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5881 auto srcType = op.getSourceType();
5882 if (llvm::any_of(op.getInnerDimsPos(),
5883 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5885 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5887 return !PackOp::requirePaddingValue(
5888 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5889 op.getOuterDimsPerm(), op.getMixedTiles());
5896 bool changeNeeded =
false;
5897 srcShape.assign(packOp.getSourceType().getShape().begin(),
5898 packOp.getSourceType().getShape().end());
5899 destShape.assign(packOp.getDestType().getShape().begin(),
5900 packOp.getDestType().getShape().end());
5901 llvm::SmallSetVector<int64_t, 4> innerDims;
5902 innerDims.insert_range(packOp.getInnerDimsPos());
5904 if (!packOp.getOuterDimsPerm().empty())
5906 int srcRank = packOp.getSourceRank();
5907 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5908 if (innerDims.contains(i))
5912 if (!inverseOuterDimsPerm.empty())
5913 destPos = inverseOuterDimsPerm[srcPos];
5914 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5915 ShapedType::isDynamic(destShape[destPos])) {
5918 int64_t size = srcShape[srcPos];
5919 if (ShapedType::isDynamic(size))
5920 size = destShape[destPos];
5921 srcShape[srcPos] = size;
5922 destShape[destPos] = size;
5923 changeNeeded =
true;
5925 return changeNeeded;
5928LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5930 if (!packOp.hasPureTensorSemantics())
5934 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5935 if (unPackOp.getSourceType() == packOp.getDestType() &&
5936 !packOp.getPaddingValue() &&
5939 rewriter.
replaceOp(packOp, unPackOp.getSource());
5947 packOp.getPaddingValueMutable().clear();
5953 SmallVector<int64_t> srcShape, destShape;
5955 Location loc = packOp.getLoc();
5956 Value source = packOp.getSource();
5957 if (srcShape != packOp.getSourceType().getShape()) {
5958 auto newSrcType = packOp.getSourceType().clone(srcShape);
5960 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5962 Value dest = packOp.getDest();
5963 ShapedType originalResultType = packOp.getDestType();
5964 bool needUpdateDestType = (destShape != originalResultType.getShape());
5965 if (needUpdateDestType) {
5966 auto newDestType = packOp.getDestType().clone(destShape);
5968 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5971 packOp.getSourceMutable().assign(source);
5972 packOp.getDestMutable().assign(dest);
5973 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5976 if (needUpdateDestType) {
5978 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5979 packOp.getResult());
5988template <
typename PackOrUnpackOp>
5990 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5991 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5992 "Function meant for pack/unpack");
5997 int64_t numPackedDims = innerDimsPos.size();
5998 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5999 if (orderedDims != innerDimsPos) {
6005 int64_t packedRank = packedTensorType.getRank();
6015 return llvm::all_of(
6016 llvm::seq<int64_t>(0, packedRank - numPackedDims),
6017 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
6020bool PackOp::isLikePad() {
6021 auto packedTensorType =
6022 llvm::cast<ShapedType>((*this)->getResultTypes().front());
6026::mlir::LogicalResult
6027PackOp::fold(FoldAdaptor adaptor,
6029 if (!hasPureTensorSemantics())
6031 std::optional<Attribute> paddingValue;
6032 if (
auto pad = adaptor.getPaddingValue())
6034 if (
OpFoldResult reshapedSource = reshapeConstantSource(
6035 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6036 cast<TensorType>(getDestType()), paddingValue)) {
6037 results.push_back(reshapedSource);
6063 if (!op.hasPureTensorSemantics())
6084 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6085 op.getInnerDimsPos(), newMixedTileSizes,
6086 op.getPaddingValue(), op.getOuterDimsPerm());
6087 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6090 Value oldResult = op.getResult();
6091 Value newResult = newOp.getResult();
6094 ? tensor::CastOp::create(rewriter, op->getLoc(),
6095 oldResult.
getType(), newResult)
6108void UnPackOp::getAsmResultNames(
6110 if (!getResults().empty())
6111 setNameFn(getResult(),
"unpack");
6120 Type sourceType, destType, resultType;
6132 if (parser.parseInteger(value))
6134 outerDimsPermVec.push_back(value);
6144 SmallVector<int64_t> innerDimsPosVec;
6147 if (parser.parseInteger(value))
6149 innerDimsPosVec.push_back(value);
6161 for (
auto val : staticTilesAttr.
asArrayRef())
6162 staticTiles.push_back(val);
6179 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6182 "pack/unpack requires '->' and destination type");
6186 resultType = destType;
6192 if (!dynamicTiles.empty() &&
6197 result.addAttribute(
"static_inner_tiles",
6199 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6201 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6203 SmallVector<int32_t> segmentSizes = {
6204 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6205 result.addAttribute(
"operandSegmentSizes",
6209 result.addTypes(resultType);
6214void UnPackOp::print(OpAsmPrinter &p) {
6215 p <<
" " << getSource();
6217 if (!getOuterDimsPerm().empty()) {
6218 p <<
" outer_dims_perm = [";
6219 llvm::interleaveComma(getOuterDimsPerm(), p);
6223 p <<
" inner_dims_pos = [";
6224 llvm::interleaveComma(getInnerDimsPos(), p);
6227 p <<
" inner_tiles = ";
6230 p <<
" into " << getDest();
6233 {
"static_inner_tiles",
"inner_dims_pos",
6234 "outer_dims_perm",
"operandSegmentSizes"});
6236 p <<
" : " << getSource().getType();
6237 p <<
" -> " << getDest().getType();
6241UnPackOp::reifyResultShapes(OpBuilder &builder,
6250SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6254SmallVector<int64_t> UnPackOp::getStaticTiles() {
// Returns the leading entries of the source type's shape, one entry per
// dimension of the destination type (i.e. the first `destRank` dimensions,
// where `destRank` is the rank of the destination type). The returned
// ArrayRef is a slice of the source type's shape; no copy is made here.
6258ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6259 ShapedType destType = getDestType();
6260 int64_t destRank = destType.getRank();
// Keep only the first `destRank` dimensions of the source shape.
6261 return getSourceType().getShape().take_front(destRank);
6264SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6265 auto innerDimsPos = getInnerDimsPos();
6266 SmallVector<int64_t> outerDims(getAllOuterDims());
6267 SmallVector<int64_t> res;
6270 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6272 if (!outerDimPermInv.empty())
6276 for (
auto index : innerDimsPos)
6277 res.push_back(outerDims[index]);
6282LogicalResult UnPackOp::verify() {
6287 if (!hasPureTensorSemantics())
6296void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6297 Value dest, ArrayRef<int64_t> innerDimsPos,
6298 ArrayRef<OpFoldResult> innerTiles,
6299 ArrayRef<int64_t> outerDimsPerm) {
6300 assert(innerDimsPos.size() == innerTiles.size() &&
6301 "number of tile sizes specified must match the specified number of "
6302 "original dimensions to be tiled");
6303 SmallVector<int64_t> staticTileSizes;
6304 SmallVector<Value> dynamicTileSizes;
6306 build(builder, state, dest.
getType(), source, dest,
6307 outerDimsPerm.empty() ?
nullptr
6313Value UnPackOp::createDestinationTensor(OpBuilder &
b, Location loc,
6315 ArrayRef<OpFoldResult> innerTileSizes,
6316 ArrayRef<int64_t> innerDimsPos,
6317 ArrayRef<int64_t> outerDimsPerm) {
6318 AffineExpr sym0, sym1;
6320 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6324 SmallVector<OpFoldResult> mixedSizes;
6325 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
6327 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6328 if (srcType.isDynamicDim(i))
6329 mixedSizes.push_back(
6330 tensor::DimOp::create(
b, loc, source, i).getResult());
6332 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6334 if (!outerDimsPerm.empty()) {
6339 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6340 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6342 auto elemType = srcType.getElementType();
6343 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
// Creates a new UnPackOp that unpacks `transposedSource` into this op's
// existing destination, using the inner_dims_pos / inner tiles /
// outer_dims_perm stored in `metadata`.
//
// NOTE(review): the statement that declares `metadata` is not visible in this
// listing; it is evidently computed from `*this`, `innerPermutation` and
// `outerPermutation` -- confirm against the full source.
6346UnPackOp UnPackOp::createTransposedClone(OpBuilder &
b, Location loc,
6347 Value transposedSource,
6348 ArrayRef<int64_t> innerPermutation,
6349 ArrayRef<int64_t> outerPermutation) {
6351 *
this, innerPermutation, outerPermutation);
// Unlike the pack counterpart, no new destination is created: the clone
// writes into the value returned by getDest().
6352 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6353 metadata.innerDimsPos, metadata.innerTiles,
6354 metadata.outerDimsPerm);
6361 bool changeNeeded =
false;
6362 srcShape.assign(op.getSourceType().getShape().begin(),
6363 op.getSourceType().getShape().end());
6364 destShape.assign(op.getDestType().getShape().begin(),
6365 op.getDestType().getShape().end());
6366 llvm::SmallSetVector<int64_t, 4> innerDims;
6367 innerDims.insert_range(op.getInnerDimsPos());
6369 if (!op.getOuterDimsPerm().empty())
6371 int destRank = op.getDestRank();
6372 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6373 if (innerDims.contains(i))
6377 if (!inverseOuterDimsPerm.empty())
6378 srcPos = inverseOuterDimsPerm[destPos];
6379 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6380 ShapedType::isDynamic(destShape[destPos])) {
6383 int64_t size = srcShape[srcPos];
6384 if (ShapedType::isDynamic(size))
6385 size = destShape[destPos];
6386 srcShape[srcPos] = size;
6387 destShape[destPos] = size;
6388 changeNeeded =
true;
6390 return changeNeeded;
6393LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6396 if (!unPackOp.hasPureTensorSemantics())
6400 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6401 if (packOp.getSourceType() != unPackOp.getDestType())
6403 if (packOp.getPaddingValue() ||
6407 rewriter.
replaceOp(unPackOp, packOp.getSource());
6411 if (
auto dstStyleOp =
6412 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6413 auto destValue = cast<OpResult>(unPackOp.getDest());
6414 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6416 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6420 if (unPackOp->hasOneUse()) {
6421 auto extractSliceUser =
6422 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6423 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6424 OpBuilder::InsertionGuard g(rewriter);
6426 auto newDest = tensor::ExtractSliceOp::create(
6427 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6428 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6429 extractSliceUser.getMixedStrides());
6431 unPackOp.setDpsInitOperand(0, newDest);
6432 unPackOp.getResult().setType(newDest.
getType());
6434 rewriter.
replaceOp(extractSliceUser, unPackOp);
6440 SmallVector<int64_t> srcShape, destShape;
6442 Location loc = unPackOp.getLoc();
6443 Value source = unPackOp.getSource();
6444 if (srcShape != unPackOp.getSourceType().getShape()) {
6445 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6446 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6447 unPackOp.getSource());
6449 Value dest = unPackOp.getDest();
6450 if (destShape != unPackOp.getDestType().getShape()) {
6451 auto newDestType = unPackOp.getDestType().clone(destShape);
6452 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6453 unPackOp.getDest());
6455 UnPackOp newOp = UnPackOp::create(
6456 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6457 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6459 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
6466bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6468 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6473 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6474 SmallVector<int64_t> outerShapeWithoutTranspose =
6476 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(),
false);
6477 for (
auto [pos, tileSize] :
6478 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6479 areOuterDimsTiled[pos] =
true;
6480 if (unpackedTypeAfterFold.isDynamicDim(pos))
6482 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6484 if (ShapedType::isDynamic(tileSize))
6486 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6487 unpackedTypeAfterFold.getDimSize(pos);
6488 if (paddingSize >= tileSize)
6492 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6493 if (areOuterDimsTiled[pos])
6495 int64_t dim = outerShapeWithoutTranspose[pos];
6496 if (ShapedType::isDynamic(dim))
6498 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6504bool UnPackOp::isLikeUnPad() {
6505 ShapedType packedTensorType = getSourceType();
6509::mlir::LogicalResult
6510UnPackOp::fold(FoldAdaptor adaptor,
6511 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6513 if (!hasPureTensorSemantics())
6516 if (OpFoldResult reshapedSource = reshapeConstantSource(
6517 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6518 cast<TensorType>(getResult().
getType()))) {
6519 results.push_back(reshapedSource);
6545 if (!op.hasPureTensorSemantics())
6554 Value sourceTensor = newOperands[0];
6558 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6564 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6565 newOperands[1], op.getInnerDimsPos(),
6566 newMixedTileSizes, op.getOuterDimsPerm());
6567 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6570 Value oldResult = op.getResult();
6571 Value newResult = newOp.getResult();
6574 ? tensor::CastOp::create(rewriter, op->getLoc(),
6575 oldResult.
getType(), newResult)
6589 utils::IteratorType::reduction, utils::IteratorType::parallel,
6590 utils::IteratorType::parallel, utils::IteratorType::reduction};
6593SmallVector<AffineMap>
6594BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6595 AffineExpr d0, d1, d2, d3;
6596 SmallVector<AffineMap> indexingMaps;
6598 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6599 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6601 return indexingMaps;
6604bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6605 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6608 if (maps.size() != 3)
6613 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6614 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6615 (*positions)[2] == SmallVector<int64_t>{1, 2};
/// Returns the fixed region-argument count for this op: always 3. The parser
/// forwards this value as `numRegionArgs` when parsing the named structured
/// op, and the region builder's diagnostics likewise expect 3 args.
6617unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6619std::string BatchReduceMatmulOp::getLibraryCallName() {
6625bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6626 SmallVector<AffineMap, 3> defaultMaps =
6628 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6629 return defaultMaps != explicitMaps;
6639bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6642 "Expected less than 3 result dim expr.");
6643 bool isValid =
false;
6644 enum Indices { batchPos, mPos, nPos, kPos };
6646 AffineExpr expr = bcastMap.
getResult(0);
6649 AffineExpr expr0 = bcastMap.
getResult(0);
6650 AffineExpr expr1 = bcastMap.
getResult(1);
6655 : ((expr0.isFunctionOfDim(batchPos) &&
6656 expr1.isFunctionOfDim(kPos)) ||
6657 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6662void BatchReduceMatmulOp::regionBuilder(
6663 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
6666 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6671 "BatchReduceMatmulOp regionBuilder expects 3 args");
6672 RegionBuilderHelper helper(
b, block);
6673 SmallVector<Value> yields;
6677 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6679 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6681 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
6682 if (!castValA || !castValB || !mulVal)
6685 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6688 yields.push_back(addVal);
6689 helper.yieldOutputs(yields);
6692ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6693 OperationState &
result) {
6694 SmallVector<Attribute, 3> indexingMapsAttr;
6705 if (!isa<AffineMapAttr>(mapAttr)) {
6707 "expected affine map attribute");
6709 indexingMapsAttr.push_back(mapAttr);
6719 if (indexingMapsAttr.empty()) {
6720 indexingMapsAttr = llvm::map_to_vector(
6721 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6722 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6724 result.addAttribute(
"indexing_maps",
6726 return ::parseNamedStructuredOp(parser,
result,
6727 BatchReduceMatmulOp::getNumRegionArgs(),
6728 BatchReduceMatmulOp::getRegionBuilder());
6731void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6732 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6733 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6734 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
6736 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6737 p <<
" indexing_maps = [";
6738 llvm::interleaveComma(getIndexingMaps(), p,
6743 SmallVector<StringRef, 3> elidedAttrs = {
6744 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6750LogicalResult BatchReduceMatmulOp::verify() {
6753 if (!hasUserDefinedMaps())
6756 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6762LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6763 SmallVectorImpl<OpFoldResult> &) {
6766void BatchReduceMatmulOp::getEffects(
6767 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6769 if (hasPureTensorSemantics())
6785void LinalgDialect::getCanonicalizationPatterns(
6794 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static std::optional< TypedAttr > getScalarConstantAttrFromDenseSplat(Value input)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, ArrayRef< OpFoldResult > mixedTiles)
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name-mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Rewrite a broadcast of a dense splat constant into a dense splat constant of the broadcast output sha...
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Rewrite a transpose of a dense splat constant into a dense splat constant of the transposed output sh...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override