#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}
OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      // ...
    }
  }
  // ...
  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if none were supplied.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);
  // ...
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));
  // ...
}
static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
                                    bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  // ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible with
    // operation syntax that mix the inherent attributes and the discardable
    // ones in the same dictionary.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }
  // ...
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // ...
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));
  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}
class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}
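  // Note: this helper is used by the ODS/YAML-generated region builders of
  // named structured ops. It materializes the scalar payload (unary, binary,
  // ternary and type-conversion functions, constants, linalg.index, and the
  // terminating linalg.yield) at the end of `block`.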
  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }
  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }
  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }
  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, /*isUnsignedCast=*/false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, /*isUnsignedCast=*/true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }
  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }
  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      // ...
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }
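  // Note on semantics: for integer widening, an unsigned cast zero-extends
  // while a signed cast sign-extends; the same flag selects between
  // signed/unsigned int<->float conversions in convertScalarToDtype.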
  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());
    return success();
  }
};

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
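// E.g. `linalg.copy ins(%t : tensor<8xf32>) outs(%t : tensor<8xf32>)` is a
// self copy: with buffer semantics it is erased outright, with tensor
// semantics it is replaced by its input value.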
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // The pad value must match the fill value for the fold to apply.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint; otherwise the
      // fold would clobber previously inserted values.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If any dimension has a dynamic offset/size/stride, conservatively
        // assume overlap.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // The first destination must come from a linalg.fill whose value matches
    // the pad's constant padding value.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // New offsets are the old offsets plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract is a linalg.fill result.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Replace the tensor.extract with the scalar used to fill the tensor.
    Value extractedScalar = fillOp.getInputs()[0];
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Fold pack(fill) into a single fill on the packed destination.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }
      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
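// Taken together, these patterns pull producers and consumers of a
// linalg.fill (pad, pack, reshape, transpose, copy, insert_slice, concat,
// extract) into or through the fill, so the value is filled once on the
// final layout instead of being materialized and then rearranged.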
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }
  // ...
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"", /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"", /*libraryCall=*/"", bodyBuild, attributes);
}
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTypes());
}
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  // ...
  std::unique_ptr<Region> region = std::make_unique<Region>();
  // ...
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  SmallVector<Type, 1> outputTensorsTypes;
  // ...
  result.addTypes(outputTensorsTypes);
  return success();
}
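// For reference, the parsed/printed form looks like (example IR):
//   %res = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %0 = arith.addf %a, %a : f32
//     linalg.yield %0 : f32
//   } -> tensor<8xf32>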
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable; for buffer semantics we
  // conservatively refuse to speculate.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // ...
}
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal for the op to be an identity.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // The body must consist of a single yield.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // shape can be yielded for a statically shaped result.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion and dense tensor casting.
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType)) {
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        } else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();
  // ...
  return success();
}
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, {init}, bodyBuild);
}
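// For example (short-form syntax), an elementwise add:
//   %mapped = linalg.map { arith.addf }
//       ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//       outs(%init : tensor<8xf32>)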
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false, bool mapInit = true) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst is enabled, the init block argument is passed as the first
  // payload operand.
  if (initFirst) {
    if (mapInit)
      payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end() - int(!mapInit)};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  YieldOp::create(b, result.location, payloadOp->getResults());
}
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs, ArrayRef(result.operands),
                           /*initFirst=*/false, /*mapInit=*/true);
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true))
      return failure();
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
static bool canUseShortForm(Block *body, bool initFirst = false,
                            bool mapInit = true) {
  // `initFirst == true` implies that the init argument is mapped.
  if (initFirst && !mapInit)
    return false;
  // The body must contain exactly one payload op plus the terminating yield,
  // and the payload operands must line up with the block arguments.
  // ...
  for (const auto &[operand, bbArg] :
       llvm::zip(payload.getOperands(), body->getArguments())) {
    if (bbArg != operand)
      return false;
  }
  // ...
  for (const auto &[operand, bbArg] :
       llvm::zip(payload.getOperands(),
                 body->getArguments().take_back(payload.getNumOperands()))) {
    if (bbArg != operand)
      return false;
  }
  // The yield must forward exactly the payload op's single result.
  return yieldOp.getNumOperands() == 1 &&
         yieldOp.getOperand(0).getDefiningOp() &&
         yieldOp.getOperand(0).getDefiningOp() == &payload;
}
  // Elide the fastmath attribute from the short form when it carries the
  // default value.
  std::string attrToElide;
  SmallVector<StringRef> elidedAttrs;
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  // ...

void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  // ...
  if (!useShortForm) {
    // Print the region in long form with explicit block arguments.
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    // ...
  }
  // ...
}
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` matches the arity of the `mapper` region.
  if (getInputs().size() + 1 != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() + 1 << " and "
                         << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
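// For example (short-form syntax), a sum over the innermost dimension:
//   %reduced = linalg.reduce { arith.addf }
//       ins(%input : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>)
//       dimensions = [1]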
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true))
      return failure();

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}

void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  // ...
  if (!useShortForm) {
    // Print the region in long form with explicit block arguments.
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    // ...
  }
  // ...
}
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       linalg::YieldOp::create(b, loc, args[0]);
                     });
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init,
        builder.getDenseI64ArrayAttr(permutation), attributes);
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();
  // ...
  return success();
}

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}

void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }
  return success();
}
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}

void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Only the tensor type is supported.
  if (!isa<TensorType>(getInput().getType()))
    return failure();

  // Single-dimension transpose.
  if (getPermutation().empty()) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  if (isIdentityPermutation(getPermutation())) {
    result.push_back(getInput());
    return success();
  }

  return failure();
}
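// For example, the printed form of a rank-2 transpose:
//   %transposed = linalg.transpose
//       ins(%input : tensor<16x64xf32>)
//       outs(%init : tensor<64x16xf32>)
//       permutation = [1, 0]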
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
      return failure();
    ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();
    SmallVector<int64_t> foldedPerms;
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);

    rewriter.replaceOpWithNewOp<TransposeOp>(
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
        foldedPerms);
    return success();
  }
};
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    SmallVector<OpFoldResult> dims;
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
                           ->getResult(0));
      } else {
        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
                                        broadcastInputTy.getDimSize(i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(dims, resultPerms);
    Value transposeInit = tensor::EmptyOp::create(
        rewriter, transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
    Value transposeResult =
        TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
                            transposeInit, resultPerms)
            ->getResult(0);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
    return success();
  }
};
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
}
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();
  // ...
  return success();
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}

void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}

void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
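// For example, broadcasting a vector along a new trailing dimension:
//   %broadcasted = linalg.broadcast
//       ins(%input : tensor<16xf32>)
//       outs(%init : tensor<16x64xf32>)
//       dimensions = [1]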
struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
    if (!defBroadcastOp)
      return failure();
    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    Value init = broadcastOp.getInit();
    int64_t initRank = cast<ShapedType>(init.getType()).getRank();
    // Mapping from input dims to init dims.
    SmallVector<int64_t> dimMap;
    SmallVector<int64_t> foldedDims(dimensions);
    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
      if (!llvm::is_contained(dimensions, dim))
        dimMap.push_back(dim);
    }
    for (auto dim : defDimensions)
      foldedDims.push_back(dimMap[dim]);

    llvm::sort(foldedDims);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
    return success();
  }
};

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
}
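// E.g. broadcast(broadcast(%x, dimensions = [0]), dimensions = [2]) folds
// into a single broadcast of %x with dimensions [0, 2].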
void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}

ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
LogicalResult linalg::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    return verifyYield(*this, linalgOp);

  return emitOpError("expected parent op with LinalgOp interface");
}
LogicalResult IndexOp::verify() {
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  if (!linalgOp)
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}

OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
  // Bail out if `linalg.index` does not have a proper parent yet at this
  // point, e.g., when calling `createOrFold` during IR construction.
  if (!linalgOp)
    return OpFoldResult{};

  // Index of unit dims is always 0.
  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
  uint64_t dim = getDim();
  assert(dim < loopBounds.size() && "Dim is out of bounds");
  if (loopBounds[dim] == 1)
    return IntegerAttr::get(IndexType::get(getContext()), 0);

  return OpFoldResult{};
}
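// E.g. inside a linalg.generic whose static loop range for dim 0 is 1, a
// `linalg.index 0` can only ever produce 0, so it folds to that constant.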
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
static SmallVector<AffineExpr, 4> makeAffineDimExprs(unsigned num,
                                                     unsigned &startIdx,
                                                     MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    return appendMangledType(ss, vec.getElementType());
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  llvm::replace(name, '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  name.pop_back();
  return name;
}
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean "dead";
      // keep tensors and only consider memrefs.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // The cast may be in a conditionally reachable region; folding across
    // blocks can generate invalid code, so only fold ops in the same block.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast goes
    // from a more dynamic shape to a less dynamic shape, and can keep
    // propagating up the def chain through other cast-foldable producers.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = tensor::CastOp::create(
        rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    return success();
  }
};
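// The net effect: a consumer tensor.cast that refines a linalg result type is
// absorbed by re-creating the linalg op on a cast init, then casting the new
// result back to the old type for any remaining uses.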
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result
    // of a `tensor.cast` operation and the source of the cast has a static
    // shape, use that shape instead.
    Operation *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If a dimension of the source shape is statically known, map the affine
    // dim expression to that size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // The dimension is dynamic and a static size is known for its affine dim
    // expression, so use that size.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Cast the operand to the new static type.
    Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each affine dim expression, check if the size is known. If known,
    // add it to the map.
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the op already has all the static information, nothing to do.
    if (!changeNeeded)
      return failure();

    // Clone the op with the new operands and cast results back as needed.
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);
    return success();
  }
};
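// E.g. if one operand is tensor<16x?xf32> and another is tensor<?x32xf32>
// and both index the same loop dimensions, each operand can be tensor.cast
// to tensor<16x32xf32> and the op re-created with fully static types.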
LogicalResult SoftmaxOp::verify() {
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();

  ArrayRef<int64_t> inputShape = inputType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(inputShape, outputShape)))
    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");

  return success();
}
SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  Operation *inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (!inputSlice)
    return emitOpError("failed to compute input slice");

  tiledOperands.emplace_back(inputSlice->getResult(0));
  Operation *outputSlice =
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  if (!outputSlice)
    return emitOpError("failed to compute output slice");
  tiledOperands.emplace_back(outputSlice->getResult(0));

  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
LogicalResult SoftmaxOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  return failure();
}

LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> shapes;
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      // Static dim: return an IntegerAttr.
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
    } else {
      // Dynamic dim: return a Value.
      // ...
    }
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
void SoftmaxOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(),
                         &getOperation()->getOpOperand(index), /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
                                    int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  MLIRContext *ctxt = builder.getContext();
  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
  SmallVector<AffineExpr, 2> affineExprs;
  for (int i = 0; i < inputRank; i++) {
    if (i != dim)
      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
  }
  auto reductionMap =
      AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs, ctxt);
  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
  return std::make_tuple(iteratorTypes, indexingMaps);
}
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = linalg::GenericOp::create(
      builder, loc, output.getType(), input, output, indexingMaps,
      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = T::create(b, loc, args[0], args[1]);
        linalg::YieldOp::create(b, loc, result);
      });
  return genericOp.getResult(0);
}
/// Produce a linalg generic that computes the second step of the softmax
/// decomposition: res = exp(input - max), where max is broadcast along `dim`.
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
                              Value max, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the affine map for the output argument.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = linalg::GenericOp::create(
      builder, loc, input.getType(), ValueRange{input, max}, output,
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
        Value result = math::ExpOp::create(b, loc, diff);
        linalg::YieldOp::create(b, loc, result);
      });
  return genericOp.getResult(0);
}
// Build a linalg.generic computing numerator / denominator elementwise,
// broadcasting `denominator` along `dim`.
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output argument.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = linalg::GenericOp::create(
      builder, loc, numerator.getType(), ValueRange{numerator, denominator},
      output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
        linalg::YieldOp::create(b, loc, result);
      });
  return genericOp.getResult(0);
}
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/false);
  Value neutralForMaxFInit =
      linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max = reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit,
                                       reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
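// Illustrative decomposition for linalg.softmax with dimension = 1 on
// tensor<2x16xf32>:
//   %max = max-reduce(%input)             -> tensor<2xf32>    (step 1)
//   %num = exp(%input - broadcast(%max))  -> tensor<2x16xf32> (step 2)
//   %den = add-reduce(%num)               -> tensor<2xf32>    (step 3)
//   %res = %num / broadcast(%den)         -> tensor<2x16xf32> (step 4)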
LogicalResult WinogradFilterTransformOp::verify() {
  auto filterType = cast<ShapedType>(getFilter().getType());
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  if (filterH != r && filterH != 1)
    return emitOpError("expect filter height either equals to r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expect filter width either equals to r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expect either filter height or width equals to r");

  SmallVector<int64_t> expectedOutputShape;
  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(filterShape[getFilterFDim()]);

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("the output shape is not expected");
  return success();
}
SmallVector<Range>
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  SmallVector<Range> loopBounds(filterRank);
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = tensor::ExtractSliceOp::create(
      builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  int64_t tileSize = m + r - 1;

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("the output shape is not expected");
  return success();
}
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, C
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];

  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  int64_t alpha = m + r - 1;
  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

  return success();
}
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("the output shape is not expected");
  return success();
}
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, F
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});

  return success();
}
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = tensor::ExtractSliceOp::create(
      builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
/// Returns true if the result expressions of `explictMap` are the same set as
/// those of `defaultMap`.
static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
  auto explicitRange = explictMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}

/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the MatmulOp `matmulOp` for the operand at `opIndex`.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (opIndexingMap.getNumResults() < defaultIndexingMap.getNumResults()) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
  }
  return success();
}
// Check general validity of an input indexing map of a
// BatchMatmulOp/BatchReduceMatmulOp.
template <typename OpTy>
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
                                     AffineMap opIndexingMap,
                                     AffineMap defaultIndexingMap, bool isLHS) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  // Check that the result dims are valid.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return batchVariantMatmulOp->emitOpError()
           << "Unexpected result dim expression (outside the set of default "
              "result dims).";

  // Check for a valid number of result dims of input maps.
  if (opIndexingMap.getNumResults() > 3)
    return batchVariantMatmulOp->emitOpError()
           << "no. of result dim expressions exceeds 3.";

  auto hasValidBatchDim = [](AffineMap map) {
    AffineExpr batchDim = map.getResult(0);
    return batchDim.isFunctionOfDim(0);
  };

  // Check if the requested broadcast is valid.
  if (opIndexingMap.getNumResults() < defaultIndexingMap.getNumResults()) {
    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
      return batchVariantMatmulOp->emitOpError()
             << "Invalid broadcast requested.";
  } else if (!hasValidBatchDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid batch dimension expression.";
  }
  return success();
}
/// Check general validity of the output indexing map of a
/// BatchMatmulOp/BatchReduceMatmulOp.
template <typename OpTy>
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
                                     AffineMap opIndexingMap) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 3) {
    return batchVariantMatmulOp->emitOpError()
           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
           << ").";
  }
  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 2) {
    return batchVariantMatmulOp->emitOpError()
           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
           << ").";
  }

  auto areValidOutputResultDim = [&](AffineMap outputMap) {
    return isa<BatchMatmulOp>(batchVariantMatmulOp)
               ? outputMap.getResult(0).isFunctionOfDim(0) &&
                     outputMap.getResult(1).isFunctionOfDim(1) &&
                     outputMap.getResult(2).isFunctionOfDim(2)
               : outputMap.getResult(0).isFunctionOfDim(1) &&
                     outputMap.getResult(1).isFunctionOfDim(2);
  };

  if (!areValidOutputResultDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid output map result dimension.";
  }
  return success();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing maps of a BatchMatmulOp/BatchReduceMatmulOp for the operand at
/// `opIndex`.
template <typename OpTy>
static LogicalResult
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
                                         unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps =
      batchVariantMatmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      batchVariantMatmulOp.getDefaultIndexingMaps(
          batchVariantMatmulOp->getContext());

  if (opIndexingMaps.size() != 3)
    return batchVariantMatmulOp->emitOpError()
           << "Indexing_map attribute must have 3 affine maps.";

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];

  if (opIndex == 2 &&
      failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
    return failure();

  if (opIndex != 2 &&
      failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
                             defaultIndexingMap, opIndex == 0)))
    return failure();

  return success();
}
std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
  if (m == 2 && r == 3)
    return WinogradConv2DFmr::F_2_3;
  if (m == 4 && r == 3)
    return WinogradConv2DFmr::F_4_3;
  if (m == 2 && r == 5)
    return WinogradConv2DFmr::F_2_5;
  return std::nullopt;
}

std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
  switch (fmr) {
  case WinogradConv2DFmr::F_2_3:
    return {2, 3};
  case WinogradConv2DFmr::F_4_3:
    return {4, 3};
  case WinogradConv2DFmr::F_2_5:
    return {2, 5};
  }
  llvm_unreachable("unknown WinogradConv2DFmr");
}
/// Converts the given `maps` (an ArrayAttr of AffineMapAttr) to the positions
/// of the affine dim expressions in each map's results. Fails if any map is
/// not an AffineMapAttr or any result is not a pure AffineDimExpr.
static FailureOr<SmallVector<SmallVector<int64_t>>>
getAffineResultPositions(ArrayAttr maps) {
  SmallVector<SmallVector<int64_t>> positions;
  for (auto map : maps) {
    AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
    if (!attr)
      return failure();
    SmallVector<int64_t> pos;
    for (auto result : attr.getAffineMap().getResults()) {
      auto dim = dyn_cast<AffineDimExpr>(result);
      if (!dim)
        return failure();
      pos.push_back(dim.getPosition());
    }
    positions.push_back(pos);
  }
  return positions;
}

SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}
bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  auto positions = getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
         (*positions)[1] == SmallVector<int64_t>{2, 1} &&
         (*positions)[2] == SmallVector<int64_t>{0, 1};
}
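// The default (non-transposed, non-broadcast) matmul maps compared against
// here are:
//   lhs: (d0, d1, d2) -> (d0, d2)
//   rhs: (d0, d1, d2) -> (d2, d1)
//   out: (d0, d1, d2) -> (d0, d1)
// i.e. exactly the result positions {0, 2}, {2, 1} and {0, 1} above.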
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Returns true if the user-defined indexing maps differ from the default
/// maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
/// Implements the block region builder for MatmulOp. This is called by
/// MatmulOp::fillStructuredOpRegion.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value3)
    return;
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
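// The region built above corresponds to (illustrative, for f32 operands
// where the casts fold away):
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32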
/// Returns true if the given bcastMap is a valid broadcast map: it must have
/// exactly one result which is the reduction dim (d2).
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr expr = bcastMap.getResult(0);
  // Invalid if the common (contraction) dimension of the matmul is not used.
  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
/// Parses an optional `indexing_maps = [...]` attribute; returns a null
/// ArrayAttr if the keyword is absent.
static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
    return ArrayAttr{nullptr}; // Success, but no indexing maps found.
  ArrayAttr arrayAttr;
  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
    return failure();
  if (llvm::any_of(arrayAttr,
                   [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";
  return arrayAttr;
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}
/// Verify the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of a pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
SmallVector<AffineMap>
MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
  AffineExpr d0, d1, d2;
  MLIRContext *context = builder.getContext();
  bindDims(context, d0, d1, d2);
  AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
  AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
  AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
  return {mapLHS, mapRHS, mapOut};
}

bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  auto positions = getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
         (*positions)[1] == SmallVector<int64_t>{2, 1} &&
         (*positions)[2] == SmallVector<int64_t>{0, 1};
}

void MatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes) {
  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
}

MatmulTransposeAOp
MatmulTransposeAOp::create(OpBuilder &builder, Location location,
                           ValueRange inputs, ValueRange outputs,
                           ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void MatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTensorTypes, ValueRange inputs,
                               ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes) {
  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
}

MatmulTransposeAOp
MatmulTransposeAOp::create(OpBuilder &builder, Location location,
                           TypeRange resultTensorTypes, ValueRange inputs,
                           ValueRange outputs,
                           ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void MatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTensorTypes, ValueRange inputs,
                               ValueRange outputs, Attribute cast,
                               ArrayRef<NamedAttribute> attributes) {
  result.addAttribute("cast", cast);
  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
}

MatmulTransposeAOp MatmulTransposeAOp::create(
    OpBuilder &builder, Location location, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, Attribute cast,
    ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

bool MatmulTransposeAOp::classof(Operation *op) {
  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
         MatmulTransposeAOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
}
SmallVector<AffineMap>
MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
  AffineExpr d0, d1, d2;
  MLIRContext *context = builder.getContext();
  bindDims(context, d0, d1, d2);
  AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
  AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
  AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
  return {mapLHS, mapRHS, mapOut};
}

bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  auto positions = getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
         (*positions)[1] == SmallVector<int64_t>{1, 2} &&
         (*positions)[2] == SmallVector<int64_t>{0, 1};
}

void MatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes) {
  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
}

MatmulTransposeBOp
MatmulTransposeBOp::create(OpBuilder &builder, Location location,
                           ValueRange inputs, ValueRange outputs,
                           ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void MatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTensorTypes, ValueRange inputs,
                               ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes) {
  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
}

MatmulTransposeBOp
MatmulTransposeBOp::create(OpBuilder &builder, Location location,
                           TypeRange resultTensorTypes, ValueRange inputs,
                           ValueRange outputs,
                           ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void MatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTensorTypes, ValueRange inputs,
                               ValueRange outputs, Attribute cast,
                               ArrayRef<NamedAttribute> attributes) {
  result.addAttribute("cast", cast);
  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
}

MatmulTransposeBOp MatmulTransposeBOp::create(
    OpBuilder &builder, Location location, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, Attribute cast,
    ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

bool MatmulTransposeBOp::classof(Operation *op) {
  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
         MatmulTransposeBOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
}
SmallVector<AffineMap>
BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
  AffineExpr d0, d1, d2, d3;
  MLIRContext *context = builder.getContext();
  bindDims(context, d0, d1, d2, d3);
  AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
  AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
  AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
  return {mapLHS, mapRHS, mapOut};
}

bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  auto positions = getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
         (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
         (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
}

void BatchMatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
                                    ValueRange inputs, ValueRange outputs,
                                    ArrayRef<NamedAttribute> attributes) {
  buildBatchMatmulOp(builder, result, std::nullopt, inputs, outputs,
                     attributes, BatchMatmulOp::getRegionBuilder(),
                     getDefaultIndexingMaps(builder));
}

BatchMatmulTransposeAOp
BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
                                ValueRange inputs, ValueRange outputs,
                                ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void BatchMatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
                                    TypeRange resultTensorTypes,
                                    ValueRange inputs, ValueRange outputs,
                                    ArrayRef<NamedAttribute> attributes) {
  buildBatchMatmulOp(builder, result, resultTensorTypes, inputs, outputs,
                     attributes, BatchMatmulOp::getRegionBuilder(),
                     getDefaultIndexingMaps(builder));
}

BatchMatmulTransposeAOp BatchMatmulTransposeAOp::create(
    OpBuilder &builder, Location location, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs,
    ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void BatchMatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
                                    TypeRange resultTensorTypes,
                                    ValueRange inputs, ValueRange outputs,
                                    Attribute cast,
                                    ArrayRef<NamedAttribute> attributes) {
  result.addAttribute("cast", cast);
  buildBatchMatmulOp(builder, result, resultTensorTypes, inputs, outputs,
                     attributes, BatchMatmulOp::getRegionBuilder(),
                     getDefaultIndexingMaps(builder));
}

BatchMatmulTransposeAOp BatchMatmulTransposeAOp::create(
    OpBuilder &builder, Location location, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, Attribute cast,
    ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

bool BatchMatmulTransposeAOp::classof(Operation *op) {
  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
         BatchMatmulTransposeAOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
}
SmallVector<AffineMap>
BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
  AffineExpr d0, d1, d2, d3;
  MLIRContext *context = builder.getContext();
  bindDims(context, d0, d1, d2, d3);
  AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
  AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
  AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
  return {mapLHS, mapRHS, mapOut};
}

bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  auto positions = getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
         (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
         (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
}

void BatchMatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
                                    ValueRange inputs, ValueRange outputs,
                                    ArrayRef<NamedAttribute> attributes) {
  buildBatchMatmulOp(builder, result, std::nullopt, inputs, outputs,
                     attributes, BatchMatmulOp::getRegionBuilder(),
                     getDefaultIndexingMaps(builder));
}

BatchMatmulTransposeBOp
BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
                                ValueRange inputs, ValueRange outputs,
                                ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void BatchMatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
                                    TypeRange resultTensorTypes,
                                    ValueRange inputs, ValueRange outputs,
                                    ArrayRef<NamedAttribute> attributes) {
  buildBatchMatmulOp(builder, result, resultTensorTypes, inputs, outputs,
                     attributes, BatchMatmulOp::getRegionBuilder(),
                     getDefaultIndexingMaps(builder));
}

BatchMatmulTransposeBOp BatchMatmulTransposeBOp::create(
    OpBuilder &builder, Location location, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs,
    ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

void BatchMatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
                                    TypeRange resultTensorTypes,
                                    ValueRange inputs, ValueRange outputs,
                                    Attribute cast,
                                    ArrayRef<NamedAttribute> attributes) {
  result.addAttribute("cast", cast);
  buildBatchMatmulOp(builder, result, resultTensorTypes, inputs, outputs,
                     attributes, BatchMatmulOp::getRegionBuilder(),
                     getDefaultIndexingMaps(builder));
}

BatchMatmulTransposeBOp BatchMatmulTransposeBOp::create(
    OpBuilder &builder, Location location, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, Attribute cast,
    ArrayRef<NamedAttribute> attributes) {
  OperationState state(location, getOperationName());
  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");
  return res;
}

bool BatchMatmulTransposeBOp::classof(Operation *op) {
  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
         BatchMatmulTransposeBOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
}
SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
  // results are all dim exprs, and the output map's dims cover the parallel
  // iterators; any dim missing from the output map is a reduction dim.
  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  SmallVector<utils::IteratorType> iteratorTypes;
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
                                              : utils::IteratorType::reduction);

  return iteratorTypes;
}

unsigned ContractOp::getNumRegionArgs() { return 3; }
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                               ArrayRef<NamedAttribute> attrs,
                               function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "ContractOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }

  // TODO: Support fields with operators besides mult & add.
  Type outType = block.getArgument(2).getType();
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
                                                rhsAtOutType, emitError);
  if (!productAtOutType)
    return;
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType, emitError);
  if (!result)
    return;
  helper.yieldOutputs({result});
}
ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);

  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
                                getRegionBuilder());
}
void ContractOp::print(OpAsmPrinter &p) {
  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  printNamedStructuredOp(
      p, getOperation(), getInputs(), getOutputs(),
      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
}
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iteration space dims to #occurrences in the inputs' and output's
  // maps: e.g., inOccurrences[0] will hold the number of times dim 0 is used
  // to access an input operand (so at most 2) and outOccurrences[1] will
  // indicate whether dim 1 occurs in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // For each operand's affine_map and type, check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // Check above handles error messaging.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims;
       dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // A dim occurring in both inputs but not the output is contracting.
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: A dim occurring for only one input operand and not for the output
    //     is disallowed.
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}
LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
SmallVector<AffineMap>
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2, d3);
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
  return indexingMaps;
}
bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  auto positions = getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
         (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
         (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
}
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::parallel, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

/// Returns true if the user-defined indexing maps differ from the default
/// maps.
bool BatchMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
/// Returns true if the given broadcast map `bcastMap` is valid for this op.
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(0);
    AffineExpr expr1 = bcastMap.getResult(1);
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
                  expr0.isFunctionOfDim(mPos)) &&
                 expr1.isFunctionOfDim(kPos))
              : ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
  }
  return isValid;
}
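// E.g. for the LHS (isLHS = true), (d0, d1, d2, d3) -> (d3) broadcasts
// along both batch and m, and (d0, d1, d2, d3) -> (d0, d3) broadcasts along
// m only; both keep the contraction dim d3 and are accepted.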
void BatchMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  auto toType = block.getArgument(2).getType();
  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}
ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  return ::parseNamedStructuredOp(parser, result,
                                  BatchMatmulOp::getNumRegionArgs(),
                                  BatchMatmulOp::getRegionBuilder());
}
void BatchMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
}
/// Verify the user-defined indexing maps.
LogicalResult BatchMatmulOp::verify() {
  // Verification of a pure batch_matmul is handled by
  // verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ..}.
  ElementwiseArityGroup arityGroup;

  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}

static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int lastBinary =
      static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int lastTernary =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  int val = static_cast<int>(kind);
  ArityGroupAndKind result;

  if (val < lastUnary) {
    result.arityGroup = ElementwiseArityGroup::Unary;
    result.kind.unaryFn = static_cast<UnaryFn>(val);
    return result;
  }
  if (val < lastBinary) {
    result.arityGroup = ElementwiseArityGroup::Binary;
    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
    return result;
  }
  if (val >= lastTernary) {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  result.arityGroup = ElementwiseArityGroup::Ternary;
  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
  return result;
}
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
  auto rank = getResultRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {
  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
  return SmallVector<AffineMap>(numMaps, map);
}
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // Expect e.g. `kind = #linalg.elementwise_kind<add>`.
  Attribute attr;
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  if (parser.parseKeyword("kind") || parser.parseEqual())
    return failure();

  if (succeeded(parser.parseAttribute(attr))) {
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  }
  result.addAttribute(
      "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));

  // Parse optional `indexing_maps`.
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr))
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      indexingMapsAttr.push_back(mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }

  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
  auto numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    // Infer numDims of the indexing maps from the output type, which is the
    // last operand.
    auto resultType = result.operands[result.operands.size() - 1].getType();
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }

  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return success();
}
void ElementwiseOp::print(OpAsmPrinter &p) {
  p << " kind=";
  p.printAttribute(getKindAttr());
  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
                                           "indexing_maps"};
  unsigned arity =
      getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
  unsigned numDims = getResultRank();

  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
      ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
                                            getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}
void ElementwiseOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
    if (attr.getName() == b.getStringAttr("kind")) {
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
      break;
    }
  }

  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;
  if (emitError && block.getNumArguments() !=
                       getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
    emitError() << "Elementwise regionBuilder expects "
                << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() ==
             getArityGroupAsUInt(arityGroup) + 1 /*output*/
         && "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;
  Value result;

  if (arityGroup == ElementwiseArityGroup::Unary) {
    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
                                  block.getArgument(1));
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                   block.getArgument(1), block.getArgument(2));
  } else {
    assert(false && "found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);
}
LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
// Returns the packed outer shape (the leading, tiled loop dims of the packed
// type) with any outer_dims_perm undone. (Helper name assumed; the
// declaration with the defaulted second template parameter is elided above.)
template <typename OpTy, typename>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
                                    ? packOrUnPack.getDestType()
                                    : packOrUnPack.getSourceType();
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  SmallVector<int64_t> result(
      packedType.getShape().take_front(unpackedType.getRank()));
  if (!packOrUnPack.getOuterDimsPerm().empty()) {
    applyPermutationToVector(
        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
  }
  return result;
}

/// Returns the mixed tile sizes adjusted for a new packed type: tile sizes
/// that correspond to now-static result dims are turned into constants.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, update the dynamic mixed tile
    // size to static.
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant.
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile) == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
    }
  }

  return newMixedTileSizes;
}
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind the dimension `i` with the tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
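// E.g. with inner_dims_pos = [0, 2] and inner_tiles = [32, 8], the returned
// mapping is {0 -> 32, 2 -> 8}.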
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (ShapedType::isStatic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}
template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}

/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) it contains duplicates;
/// b) at least one dimension is out of bound (a `dimPos` must be >= 0 and <
///    rank);
/// c) the number of elements in `dimsPos` is > than `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  DenseSet<int64_t> uniqued(dimsPos.begin(), dimsPos.end());
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(tiles, isZeroInteger);
  };

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and they must match the number of
  // `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require the packed rank to match the unpacked rank plus the number of
  // blocking factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify that the result shape matches the expected packed type and that
  // the trailing (tile) dims agree with the static tile sizes.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
                                   packedType.getShape()))) {
    return op->emitError("expected ")
           << expectedPackedType << " for the packed domain value, got "
           << packedType;
  }
  return success();
}
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}

void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> PackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}
SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  SmallVector<int64_t> outerDims = llvm::to_vector(getAllOuterDims());
  SmallVector<int64_t> res;

  // Recover the original order of the outer dims.
  SmallVector<int64_t> outerDimPermInv =
      invertPermutationVector(getOuterDimsPerm());
  if (!outerDimPermInv.empty())
    applyPermutationToVector(outerDims, outerDimPermInv);

  // Collect the outer dims corresponding to the tiled inner dims.
  for (auto index : innerDimsPos)
    res.push_back(outerDims[index]);

  return res;
}
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}
// A stricter variant of the check above that conservatively skips any dim
// with a dynamic size or dynamic tile. (Helper name assumed; the enclosing
// declaration is elided here.)
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outputShape,
                                      ArrayRef<int64_t> outerDimsPerm,
                                      ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]) ||
        ShapedType::isDynamic(outputTileSizes[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile)
      continue;
    if (inputShape[pos] % (*constantTile) != 0)
      return true;
  }
  return false;
}
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify the padding value, and bail out if the tile does not divide the
  // dimension fully. With dynamic tile factors or dimensions, a partial tile
  // is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
/// all Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}
/// Helper for PackOp::{getResultShape,inferPackedType}. Having a shared
/// helper ensures the two implementations cannot diverge.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }

  // Swap tile loops if outer_dims_perm is available.
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);

  // Append the inner tile dimensions.
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}
SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  SmallVector<int64_t> resultTypeShape = getPackOpResultTypeShape(
      asShapeWithAnyValueAsDynamic(sourceDims),
      asShapeWithAnyValueAsDynamic(innerTileSizes), innerDimsPos,
      outerDimsPerm);

  // Fix up `resultDims` so that an entry is a Value if and only if the result
  // type shape says the dim is dynamic: callers may use
  // dispatchIndexOpFoldResults on the result and rely on the exact number of
  // dynamic dims.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (ShapedType::isStatic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }

  return resultDims;
}
/// Get the expected packed type based on source type, tile factors, position
/// of the inner tiles and permutation of the outer tiled loops.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return PackOp::create(b, loc, getSource(), transposedDest,
                        metadata.innerDimsPos, metadata.innerTiles,
                        getPaddingValue(), metadata.outerDimsPerm);
}
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;

  // The verifier already rejects operations that are not statically valid.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
// Returns true if the `packOp` and `unPackOp` carry the same inner dims
// positions and outer dims permutation (treating a missing permutation as the
// identity).
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // The outer dims permutation is optional; to compare an unbalanced
  // pack-unpack pair, treat no permutation as equal to the identity.
  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
         isIdentityPermutation(unPackOp.getOuterDimsPerm());
}

// Returns true if the pack and unpack ops use the same tile sizes.
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}
/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() == packOp.getDestType() &&
        !packOp.getPaddingValue() &&
        hasSameInnerOuterAttribute(packOp, unPackOp) &&
        haveSameTiles(packOp, unPackOp)) {
      rewriter.replaceOp(packOp, unPackOp.getSource());
      return success();
    }
  }

  // Fold the optional padding_value operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  // At this point we know that we are taking numPackedDims outer dimensions
  // and pushing them all the way as the inner most dimensions. What's left on
  // the outer most dimensions is, in this case, padding. Check that the dims
  // we drop are all 1s.
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}
bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}
    PackOp newOp =
        PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
                       op.getInnerDimsPos(), newMixedTileSizes,
                       op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  SmallVector<int64_t> outerDims = llvm::to_vector(getAllOuterDims());
  SmallVector<int64_t> res;

  // Recover the original order of the outer dims.
  SmallVector<int64_t> outerDimPermInv =
      invertPermutationVector(getOuterDimsPerm());
  if (!outerDimPermInv.empty())
    applyPermutationToVector(outerDims, outerDimPermInv);

  // Collect the outer dims corresponding to the tiled inner dims.
  for (auto index : innerDimsPos)
    res.push_back(outerDims[index]);

  return res;
}
LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}

void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}
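// Sketch of the size computation above (illustrative numbers): unpacking
// tensor<4x8x8x32xf32> with inner_dims_pos = [0, 1] and inner_tiles = [8, 32]
// keeps the leading 4x8 outer sizes and multiplies each by its tile size,
// yielding a tensor.empty of type tensor<32x256xf32>.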
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return UnPackOp::create(b, loc, transposedSource, getDest(),
                          metadata.innerDimsPos, metadata.innerTiles,
                          metadata.outerDimsPerm);
}
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      // Nothing to infer if both sides agree on being static or dynamic.
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  // unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(
        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  // extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
      auto newDest = tensor::ExtractSliceOp::create(
          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }
  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = tensor::CastOp::create(rewriter, loc, newSrcType,
                                      unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                    unPackOp.getDest());
    }
    Value newOp = UnPackOp::create(
        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }
  return failure();
}
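// The first rewrite above mirrors the pack(unpack(x)) fold in
// PackOp::canonicalize: when the packing metadata matches and no padding
// value is involved, linalg.unpack(linalg.pack(%x)) simplifies to %x.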
bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
  // Rank-reduced folding is not supported.
  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
    return false;
  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
    return false;
  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
  SmallVector<int64_t> outerShapeWithoutTranspose =
      getPackedOuterShapeWithoutTransposition(*this);
  for (auto [pos, tileSize] :
       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
    if (unpackedTypeAfterFold.isDynamicDim(pos))
      return false;
    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
      return false;
    if (ShapedType::isDynamic(tileSize))
      return false;
    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                          unpackedTypeAfterFold.getDimSize(pos);
    if (paddingSize >= tileSize)
      return false;
  }
  return true;
}
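// Worked example for the padding check above (illustrative numbers): with an
// outer size of 4 and a tile size of 8 the packed data covers 32 elements;
// slicing the unpacked dim down to 30 leaves paddingSize = 4 * 8 - 30 = 2 < 8,
// so the slice only trims padding in the last tile and can be folded.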
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
    Value sourceTensor = newOperands[0];

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    // Clone op.
    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
                                      newOperands[1], op.getInnerDimsPos(),
                                      newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op, casting back if folding the source cast refined the type.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                                     oldResult.getType(),
                                                     newResult)
                            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::reduction, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

SmallVector<AffineMap>
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2, d3);
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
  return indexingMaps;
}
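// The three maps built above correspond to (batch, m, n, k) = (d0, d1, d2, d3):
//   LHS: (d0, d1, d2, d3) -> (d0, d1, d3)
//   RHS: (d0, d1, d2, d3) -> (d0, d3, d2)
//   OUT: (d0, d1, d2, d3) -> (d1, d2)
// d0 (batch) and d3 (k) are the reduction dims, matching
// getIteratorTypesArray().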
bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (!maps)
    return false;
  if (maps.size() != 3)
    return false;
  FailureOr<SmallVector<SmallVector<int64_t>>> positions =
      getAffineResultPositions(maps);
  if (failed(positions))
    return false;
  // Default maps: (d0, d1, d3), (d0, d3, d2) and (d1, d2).
  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
         (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
         (*positions)[2] == SmallVector<int64_t>{1, 2};
}
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool BatchReduceMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(0);
    AffineExpr expr1 = bcastMap.getResult(1);
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(mPos) && expr1.isFunctionOfDim(kPos)))
              : ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
  }
  return isValid;
}
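// For example (illustrative): an RHS map (d0, d1, d2, d3) -> (d3, d2) is a
// valid broadcast (k x n with the batch dim broadcast), whereas
// (d0, d1, d2, d3) -> (d2, d3) is rejected because its result exprs match
// neither the accepted (batch, k) nor the (k, n) combination.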
void BatchReduceMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  auto toType = block.getArgument(2).getType();
  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}
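// The generated body computes, in pseudo-IR (illustrative):
//   %a = cast_signed(%arg0)   // cast to the accumulator element type
//   %b = cast_signed(%arg1)
//   %m = mul(%a, %b)
//   %s = add(%arg2, %m)
//   linalg.yield %s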
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual() || parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr))
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      indexingMapsAttr.push_back(mapAttr);
    } while (succeeded(parser.parseOptionalComma()));
    if (parser.parseRSquare())
      return failure();
  }
  // Use the default indexing maps when none were supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return ::parseNamedStructuredOp(parser, result,
                                  BatchReduceMatmulOp::getNumRegionArgs(),
                                  BatchReduceMatmulOp::getRegionBuilder());
}
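// The textual form this accepts looks like, e.g. (illustrative):
//
//   linalg.batch_reduce_matmul
//       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, ...]
//       ins(%A, %B : tensor<...>, tensor<...>) outs(%C : tensor<...>)
//
// When the indexing_maps clause is absent, the default maps are attached.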
void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
}
LogicalResult BatchReduceMatmulOp::verify() {
  // Verification of pure batch_reduce_matmul is handled by
  // verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
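// materializeConstant lets the folding machinery turn attributes produced by
// folds back into IR; for Linalg these are delegated to arith.constant.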