#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
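// The helpers below abstract over tensor vs. memref sources: getDimValue
// queries a (possibly dynamic) dimension and getSlice extracts a slice,
// each dispatching on the shaped type of the source value.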
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
            return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
          .Case<MemRefType>([&](MemRefType type) -> Operation * {
            return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
  llvm_unreachable("Expected MemRefType or TensorType");
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
  opBuilder.createBlock(&region, {}, argTypes, argLocs);
  regionBuilder(b, *body, attrs);
    std::optional<TypeRange> resultTensorTypes,
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);
  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
      "operandSegmentSizes",
                                    static_cast<int32_t>(outputs.size())}));
  Region &region = *state.addRegion();
                         state.attributes.getAttrs(), regionBuilder);
    std::optional<TypeRange> resultTensorTypes,
    indexingMapsAttrVal = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(b.getContext()),
    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
                       attributes, regionBuilder);
    std::optional<TypeRange> resultTensorTypes,
    indexingMapsAttrVal =
    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
                       attributes, regionBuilder);
    std::optional<TypeRange> resultTensorTypes,
    indexingMapsAttrVal =
    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
                       attributes, regionBuilder);
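// The build() overloads above all attach a default "indexing_maps" attribute
// before delegating to the generic structured-op builder. For plain matmul
// the defaults are, roughly (illustrative, from the usual matmul contract):
//   affine_map<(d0, d1, d2) -> (d0, d2)>   // LHS
//   affine_map<(d0, d1, d2) -> (d2, d1)>   // RHS
//   affine_map<(d0, d1, d2) -> (d0, d1)>   // output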
                          bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
  if (addOperandSegmentSizes) {
    attrs.append("operandSegmentSizes",
                 {static_cast<int32_t>(inputsOperands.size()),
                  static_cast<int32_t>(outputsOperands.size())}));
        {static_cast<int32_t>(inputsOperands.size()),
         static_cast<int32_t>(outputsOperands.size())}));
  std::optional<RegisteredOperationName> info =
  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
        return parser.emitError(attrsLoc)
               << "'" << result.name.getStringRef() << "' op ";
  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
                                 unsigned numRegionArgs,
  result.addTypes(outputTensorsTypes);
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (resultTypes.empty())
class RegionBuilderHelper {
      : builder(builder), block(block) {}
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    builder.setInsertionPointToEnd(&block);
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
      return builder.create<math::LogOp>(arg.getLoc(), arg);
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    llvm_unreachable("unsupported unary function");
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    builder.setInsertionPointToEnd(&block);
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    llvm_unreachable("unsupported binary function");
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    builder.setInsertionPointToEnd(&block);
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    llvm_unreachable("unsupported ternary function");
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    llvm_unreachable("unsupported type conversion function");
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  Value constant(const std::string &value) {
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  Value index(int64_t dim) {
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  Type getIntegerType(unsigned width) {
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  LogicalResult matchAndRewrite(CopyOp copyOp,
    if (copyOp.getInputs() != copyOp.getOutputs())
    if (copyOp.hasPureBufferSemantics())
    rewriter.replaceOp(copyOp, copyOp.getInputs());
  results.add<EraseSelfCopy>(context);
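// EraseSelfCopy drops copies that are no-ops because source and destination
// coincide, e.g. (illustrative IR):
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
// is erased; the tensor-semantics form is replaced by its input value.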
template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 reshapeOp.getReassociation());
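// FoldFillWithTensorReshape swaps a reshape past a fill so the fill happens
// on the reshaped init instead, e.g. (illustrative):
//   %f = linalg.fill ins(%cst) outs(%init)
//   %r = tensor.expand_shape %f [...]
// becomes
//   %e = tensor.expand_shape %init [...]
//   %r = linalg.fill ins(%cst) outs(%e)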
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
          padOp, "failed to reify tensor.pad op result shape");
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
      firstDest = prevOp.getDest();
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    RankedTensorType srcPadType = srcPadOp.getSourceType();
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
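// The disjointness walk above uses closed-interval arithmetic per dimension:
// a slice at offset o with size s and stride t covers [o, o + (s-1)*t], and
// two insertions are provably disjoint in a dimension when one interval ends
// before the other begins.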
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    Value extractedScalar = fillOp.getInputs()[0];
    rewriter.replaceOp(extractOp, extractedScalar);
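// FoldFillWithTensorExtract short-circuits a scalar read from a filled
// tensor, e.g. (illustrative):
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<4x4xf32>)
//   %x = tensor.extract %f[%i, %j]
// folds to just %cst.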
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (auto paddingValue = packOp.getPaddingValue())
  Value packOpDest = packOp.getDest();
  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    rewriter.replaceOp(packOp, fillOp.value().result());
  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
                                            copyOp.getOutputs());
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
                                                  fillOp.getOutputs());
  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (fillVal != firstFillVal)
      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
          concatOp, "not all operands are defined by a compatible fill op");
    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
      blockArgLocs.push_back(v.getLoc());
  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
void GenericOp::getAsmBlockArgumentNames(Region &region,
  for (Value v : getRegionInputArgs())
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
void GenericOp::build(
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
                inputs, outputs, bodyBuild);
void GenericOp::build(
    StringRef libraryCall,
  build(builder, result, resultTensorTypes, inputs, outputs,
          return IteratorTypeAttr::get(builder.getContext(), iter);
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
void GenericOp::build(
    StringRef libraryCall,
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
void GenericOp::build(
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        "", bodyBuild, attributes);
void GenericOp::build(
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        "", bodyBuild, attributes);
  auto genericAttrNames = linalgTraitAttrNames();
  genericAttrNamesSet.insert_range(genericAttrNames);
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
  if (!genericAttrs.empty()) {
    p << genericDictAttr;
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());
  bool hasExtraAttrs = false;
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
  if (hasExtraAttrs) {
  if (!getRegion().empty()) {
  DictionaryAttr dictAttr;
                         dictAttr.getValue().end());
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
          << "unexpected iterator_type (" << s << ")";
    iteratorTypeAttrs.push_back(
  std::unique_ptr<Region> region = std::make_unique<Region>();
  result.addTypes(outputTensorsTypes);
                                  LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
    effects.emplace_back(
  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
void GenericOp::getEffects(
  if (!linalgOp.hasPureTensorSemantics())
template <typename OpTy>
  LogicalResult matchAndRewrite(OpTy linalgOp,
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
            linalgOp, "expected single input and output to be the same value");
      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
                                           "cannot fold fill-like op");
    if (!linalgOp.hasPureTensorSemantics()) {
          linalgOp, "mixed semantics is not supported yet");
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      if (returnType != resultType) {
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
      returnedArgs.push_back(returnedArg);
    if (returnedArgs.size() != linalgOp->getNumResults())
    rewriter.replaceOp(linalgOp, returnedArgs);
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
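// EraseIdentityLinalgOp removes element-wise ops whose body just forwards a
// block argument, e.g. (illustrative):
//   %r = linalg.generic {...} ins(%t) outs(%init) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }
// is replaced by %t, with a cast or sparse convert inserted when the result
// type differs from the operand type.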
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
void MapOp::getAsmBlockArgumentNames(Region &region,
  for (Value v : getRegionInputArgs())
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
  build(builder, result, TypeRange{}, inputs, init);
  if (llvm::isa<RankedTensorType>(initType))
                inputs, {}, bodyBuild);
                                 bool initFirst = false) {
  for (auto &operand : operands) {
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
  std::optional<OperationName> payloadOpName;
    if (failed(operationName))
    payloadOpName = operationName.value();
  if (payloadOpName.has_value()) {
  for (const auto &[operand, bbArg] :
    if (bbArg != operand)
  for (const auto &[operand, bbArg] :
    if (bbArg != operand)
  std::string attrToElide;
  for (const auto &attr : payloadOp->getAttrs()) {
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
  Block *mapper = getBody();
      [&](auto arg) { p.printRegionArgument(arg); });
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                         << getInputs().size() << " and " << blockArgs.size();
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
  auto outputShape = getInit().getType().getShape();
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
  int64_t rank = getInit().getType().getRank();
ArrayAttr MapOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
void MapOp::getEffects(
void ReduceOp::getAsmBlockArgumentNames(Region &region,
  for (Value v : getRegionInputArgs())
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
void ReduceOp::getAsmResultNames(
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
void ReduceOp::build(
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  for (Value init : inits) {
    if (llvm::isa<RankedTensorType>(initType))
                inputs, inits, bodyBuild);
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
ArrayAttr ReduceOp::getIndexingMaps() {
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
void ReduceOp::getEffects(
                                       StringRef attributeName) {
  std::optional<OperationName> payloadOpName;
    if (failed(operationName))
    payloadOpName = operationName.value();
  if (payloadOpName.has_value()) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
  Block *mapper = getBody();
      [&](auto arg) { p.printRegionArgument(arg); });
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << " is not equal to the shape at input-index 0.";
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << " is not equal to the shape at output-index 0.";
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    dimensionsToReduce.insert(dimension);
  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";
  Block *block = getBody();
    return emitOpError()
           << "mismatching number of operands and block arguments";
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
  b.create<linalg::YieldOp>(loc, args[0]);
  if (llvm::isa<RankedTensorType>(initType))
void TransposeOp::getAsmResultNames(
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
    return emitOpError("permutation is not valid");
  auto inputType = getInput().getType();
  auto initType = getInit().getType();
  int64_t rank = inputType.getRank();
  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();
  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;
  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();
  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];
    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
  int64_t rank = getInit().getType().getRank();
ArrayAttr TransposeOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
          llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
void TransposeOp::getEffects(
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
  if (!isa<TensorType>(getInput().getType()))
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    result.push_back(getInput());
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
                                       broadcastInputTy.getDimSize(i)));
    Value transposeInit = rewriter.create<tensor::EmptyOp>(
        transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());
    Value transposeResult =
            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
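// Folding transpose-of-transpose composes the two permutations directly:
// foldedPerms[i] = defPerms[perms[i]], so e.g. applying [1, 0] after [1, 0]
// yields the identity permutation, and the pair collapses to a single
// linalg.transpose on the original input.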
  if (llvm::isa<RankedTensorType>(initType))
void BroadcastOp::getAsmResultNames(
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
  auto inputType = getInput().getType();
  auto initType = getInit().getType();
  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();
  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();
  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  int64_t rank = getInit().getType().getRank();
ArrayAttr BroadcastOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
void BroadcastOp::getEffects(
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
  for (OpOperand &opOperand : op->getOpOperands()) {
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    if (isa<MemRefType, RankedTensorType>(elementType))
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");
  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
  return emitOpError("expected parent op with LinalgOp interface");
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
  uint64_t dim = getDim();
  assert(dim < loopBounds.size() && "Dim is out of bounds");
  if (loopBounds[dim] == 1)
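// IndexOp::fold exploits static loop bounds: when the enclosing loop at
// `dim` has extent 1, the only value linalg.index can produce is 0, so the
// op folds to a constant 0 of index type.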
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
  for (unsigned i = 0; i < num; ++i)
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    for (auto size : memref.getShape())
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
  assert(isa<LinalgOp>(op));
  std::string fun = "";
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
  llvm::replace(name, '.', '_');
  llvm::raw_string_ostream ss(name);
    return std::string();
  LogicalResult matchAndRewrite(LinalgOp op,
    for (OpOperand &opOperand : op->getOpOperands()) {
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  LogicalResult matchAndRewrite(tensor::CastOp castOp,
    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (castOp->getBlock() != linalgOp->getBlock())
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
                                         linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    results[resultNumber] = castBack;
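// FoldTensorCastConsumerOp pushes a shape-refining tensor.cast on a linalg
// result into the op itself: the corresponding init operand is cast to the
// more static type, the op is cloned with the refined result type, and a
// cast back to the original type is kept for the remaining users.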
    if (linalgOp.isScalar(&opOperand))
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
    if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
      Value castSource = castOp.getSource();
      auto castSourceType =
          llvm::dyn_cast<RankedTensorType>(castSource.getType());
      if (castSourceType && castSourceType.hasStaticShape())
        sourceShape = castSourceType.getShape();
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
static void createNewOperandWithStaticSizes(
    bool &changeNeeded) {
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
                                   sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    newOperands[index] = newOperand;
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
  LogicalResult matchAndRewrite(LinalgOp linalgOp,
    if (!linalgOp.hasPureTensorSemantics())
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
    rewriter.replaceOp(linalgOp, replacements);
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();
    return emitOpError("incompatible output shape");
  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");
  int64_t operandRank = getInputOperandRank();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
FailureOr<TilingResult>
  int64_t rank = getInputOperandRank();
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
    return emitOpError("failed to compute input slice");
  tiledOperands.emplace_back(inputSlice->getResult(0));
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
    return emitOpError("failed to compute output slice");
  tiledOperands.emplace_back(outputSlice->getResult(0));
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
  Location loc = getOperation()->getLoc();
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
  reifiedReturnShapes.emplace_back(std::move(shapes));
void SoftmaxOp::getEffects(
    if (!llvm::isa<MemRefType>(operand.getType()))
                         &getOperation()->getOpOperand(index), 0,
  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
                                        int64_t dim, bool allParallel = false) {
                                                 utils::IteratorType::parallel);
    iteratorTypes[dim] = utils::IteratorType::reduction;
  for (int i = 0; i < inputRank; i++) {
  return std::make_tuple(iteratorTypes, indexingMaps);
template <typename T>
  auto inputType = cast<ShapedType>(input.getType());
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
  auto inputType = cast<ShapedType>(input.getType());
  int64_t inputRank = inputShape.size();
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
        Value result = b.create<math::ExpOp>(loc, diff);
        b.create<linalg::YieldOp>(loc, result);
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  int64_t inputRank = inputShape.size();
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      indexingMaps, iteratorTypes,
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
                                                elementType, b, loc,
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
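// decomposeOperation expands softmax into four numerically-stable steps:
//   1. max = reduce<maxnumf>(input) along the reduction dim
//   2. numerator = exp(input - max)
//   3. denominator = reduce<addf>(numerator) along the same dim
//   4. result = numerator / denominator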
  auto filterType = cast<ShapedType>(getFilter().getType());
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  if (filterH != r && filterH != 1)
    return emitOpError("expect filter height either equals to r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expect filter width either equals to r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expect either filter height or width equals to r");
  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
  auto outputType = cast<ShapedType>(getOutput().getType());
    return emitOpError("the output shape is not expected");
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;
WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
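// The Winograd transforms implement F(m, r): an output tile of size m
// computed from an r x r filter. The transformed tile size (alpha) is
// m + r - 1, which is why the verifiers and tiling code below repeatedly
// compute alpha = m + r - 1 and substitute 1 along dimensions that are not
// transformed.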
3045 ShapedType filterType = getFilterOperandType();
3047 int64_t filterH = filterShape[getFilterHDim()];
3048 int64_t filterW = filterShape[getFilterWDim()];
3051 int64_t alpha = m + r - 1;
3052 int64_t alphaH = filterH != 1 ? alpha : 1;
3053 int64_t alphaW = filterW != 1 ? alpha : 1;
3057 resultOffsets.append(
3058 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3060 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3076 ShapedType filterType = getFilterOperandType();
3078 int64_t filterH = filterShape[getFilterHDim()];
3079 int64_t filterW = filterShape[getFilterWDim()];
3085 sliceOffsets.append(
3086 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3087 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3088 sizes[getFilterCDim()]});
3089 int64_t filterRank = getFilterOperandRank();
3092 auto filterSlice = builder.
create<tensor::ExtractSliceOp>(
3093 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3094 tiledOperands.emplace_back(filterSlice);
3101 int64_t outputRank = getOutputOperandRank();
3103 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3104 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3105 tiledOperands.emplace_back(outputSlice);
3108 resultTypes.push_back(tiledOperands[1].
getType());
3110 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3123 auto inputType = cast<ShapedType>(getInput().
getType());
3125 int64_t inputH = inputShape[getInputHDim()];
3126 int64_t inputW = inputShape[getInputWDim()];
3129 int64_t tileSize = m + r - 1;
3131 auto outputType = cast<ShapedType>(getOutput().
getType());
3133 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3134 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3137 if (ShapedType::isDynamic(inputH)) {
3138 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3139 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3141 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3142 expectedOutputShape[getOutputTileHDim()] =
3143 leftTransform ? (inputH - (r - 1)) / m : inputH;
3145 if (ShapedType::isDynamic(inputW)) {
3146 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3147 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3149 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3150 expectedOutputShape[getOutputTileWDim()] =
3151 rightTransform ? (inputW - (r - 1)) / m : inputW;
3153 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3154 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3157 return emitOpError(
"the output shape is not expected");
3163 WinogradInputTransformOp::getIterationDomain(
OpBuilder &builder) {
3167 Value output = getOutput();
3168 int64_t outputRank = getOutputOperandRank();
3170 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3171 loopBounds[dim].offset = zeroAttr;
3173 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3174 loopBounds[dim].stride = oneAttr;
3180 WinogradInputTransformOp::getLoopIteratorTypes() {
3181 int64_t outputRank = getOutputOperandRank();
3183 utils::IteratorType::parallel);
3184 return iteratorTypes;
3192 ShapedType outputType = getOutputOperandType();
3194 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3195 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3199 int64_t alpha = m + r - 1;
3200 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3201 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3206 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3207 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3208 offsets[getOutputCDim()]});
3209 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3210 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3211 sizes[getOutputCDim()]});
3222 FailureOr<TilingResult>
3230 ShapedType outputType = getOutputOperandType();
3232 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3233 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3237 auto identityAffineMap =
3239 auto offsetAffineMap =
3242 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3243 offsets[getOutputTileHDim()]);
3245 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3246 offsets[getOutputTileWDim()]);
3250 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3252 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3259 sliceOffsets.append(
3260 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3266 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3267 int64_t inputRank = getInputOperandRank();
3269 auto inputSlice = builder.
create<tensor::ExtractSliceOp>(
3270 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3271 tiledOperands.emplace_back(inputSlice);
3278 int64_t outputRank = getOutputOperandRank();
3280 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3281 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3282 tiledOperands.emplace_back(outputSlice);
3285 resultTypes.push_back(tiledOperands[1].
getType());
3287 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3300 auto valueType = cast<ShapedType>(getValue().
getType());
3302 int64_t valueH = valueShape[getValueAlphaHDim()];
3303 int64_t valueW = valueShape[getValueAlphaWDim()];
3304 int64_t valueTileH = valueShape[getValueTileHDim()];
3305 int64_t valueTileW = valueShape[getValueTileWDim()];
3308 bool leftTransform = valueH != 1;
3309 bool rightTransform = valueW != 1;
3311 int64_t outputRank = getOutputOperandRank();
3313 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3314 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3316 if (valueH != (leftTransform ? m + r - 1 : 1))
3317 return emitOpError(
"expect input height equals to input tile size");
3318 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3320 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3321 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3323 if (valueW != (rightTransform ? m + r - 1 : 1))
3324 return emitOpError(
"expect input width equals to input tile size");
3325 expectedOutputShape[getOutputWDim()] =
3326 (rightTransform ? m : 1) * valueTileW;
3328 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3329 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3331 auto outputType = cast<ShapedType>(getOutput().
getType());
3334 return emitOpError(
"the output shape is not expected");
3340 WinogradOutputTransformOp::getIterationDomain(
OpBuilder &builder) {
3344 Value value = getValue();
3345 int64_t valueRank = getValueOperandRank();
3347 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3348 loopBounds[dim].offset = zeroAttr;
3350 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3351 loopBounds[dim].stride = oneAttr;
3357 WinogradOutputTransformOp::getLoopIteratorTypes() {
3358 int64_t valueRank = getValueOperandRank();
3360 utils::IteratorType::parallel);
3361 return iteratorTypes;
3372 auto identityAffineMap =
3377 ShapedType valueType = getValueOperandType();
3379 int64_t valueH = valueShape[0];
3380 int64_t valueW = valueShape[1];
3382 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3383 offsets[getValueTileHDim()]);
3385 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3386 offsets[getValueTileWDim()]);
3388 builder, loc, affineMap, sizes[getValueTileHDim()]);
3390 builder, loc, affineMap, sizes[getValueTileWDim()]);
3400 resultOffsets.append(
3401 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3403 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3422 ShapedType valueType = getValueOperandType();
3424 int64_t alphaH = valueShape[getValueAlphaHDim()];
3425 int64_t alphaW = valueShape[getValueAlphaWDim()];
3429 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3430 offsets[getValueTileWDim()], offsets[getValueNDim()],
3431 offsets[getValueFDim()]});
3432 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3433 sizes[getValueTileWDim()], sizes[getValueNDim()],
3434 sizes[getValueFDim()]});
3435 int64_t valueRank = getValueOperandRank();
3437 auto valueSlice = builder.
create<tensor::ExtractSliceOp>(
3438 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3439 tiledOperands.emplace_back(valueSlice);
3446 int64_t outputRank = getOutputOperandRank();
3448 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3449 loc, getOutput(), resultOffsets, resultSizes, strides);
3450 tiledOperands.emplace_back(outputSlice);
3453 resultTypes.push_back(tiledOperands[1].
getType());
3455 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3474 llvm::set_union(explicitSet, defaultSet);
3475 return explicitSet == defaultSet;
3495 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3497 auto opIndexingMap = opIndexingMaps[opIndex];
3498 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3501 return matmulOp->emitOpError()
3502 <<
"Unexpected dim expression in map result.";
3505 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3506 return matmulOp->emitOpError()
3507 <<
"Invalid broadcast requested, should be (d2).";
3516 template <
typename OpTy>
3519 AffineMap defaultIndexingMap,
bool isLHS) {
3520 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3521 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3522 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3525 return batchVariantMatmulOp->emitOpError()
3526 <<
"Unexpected result dim expression (outside the set of default "
3531 return batchVariantMatmulOp->emitOpError()
3532 <<
"no. of result dim expressions exceeds 3.";
3534 auto hasValidBatchDim = [](
AffineMap map) {
3541 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3542 return batchVariantMatmulOp->emitOpError()
3543 <<
"Invalid broadcast requested.";
3544 }
else if (!hasValidBatchDim(opIndexingMap)) {
3545 return batchVariantMatmulOp->emitOpError()
3546 <<
"Invalid batch dimension expression.";
3554 template <
typename OpTy>
3557 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3558 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3559 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3560 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3563 return batchVariantMatmulOp->emitOpError()
3564 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3567 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3569 return batchVariantMatmulOp->emitOpError()
3570 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3574 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3575 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3576 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3577 outputMap.getResult(1).isFunctionOfDim(1) &&
3578 outputMap.getResult(2).isFunctionOfDim(2)
3579 : outputMap.getResult(0).isFunctionOfDim(1) &&
3580 outputMap.getResult(1).isFunctionOfDim(2);
3583 if (!areValidOutputResultDim(opIndexingMap)) {
3584 return batchVariantMatmulOp->emitOpError()
3585 <<
"Invalid output map result dimension.";
template <typename OpTy>
static LogicalResult
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
                                         unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps =
      batchVariantMatmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      batchVariantMatmulOp.getDefaultIndexingMaps(
          batchVariantMatmulOp->getContext());
  // ...
  if (opIndexingMaps.size() != 3)
    return batchVariantMatmulOp->emitOpError()
           << "Indexing_map attribute must have 3 affine maps.";
  // ...
  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // ...
  if (failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
                             defaultIndexingMap, opIndex == 0)))
  // ...
// (MatmulOp::getDefaultIndexingMaps)
  return indexingMaps;
// ...
// (MatmulOp::getIteratorTypesArray)
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
// ...
unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
// ...
// (MatmulOp::regionBuilder)
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(b, block);
  // ...
  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }
  // ...
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
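  // Illustration (added): for f32 operands and the default cast_signed
  // semantics, the payload block built by this regionBuilder is equivalent to:
  //
  //   ^bb0(%a: f32, %b: f32, %c: f32):
  //     %0 = arith.mulf %a, %b : f32
  //     %1 = arith.addf %c, %0 : f32
  //     linalg.yield %1 : f32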
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  // ...
}
// ...
// (parseIndexingMapsAttr)
  ArrayAttr arrayAttr;
  // ...
  if (llvm::any_of(arrayAttr, [](auto elt) {
        return !dyn_cast<AffineMapAttr>(elt);
      }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";
// (MatmulOp::parse)
  if (failed(indexingMapsAttr))
    return failure();
  // ...
  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }
  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
// ...
// (MatmulOp::print)
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  // ...
  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
// ...
// (MatmulOp::verify)
  if (!hasUserDefinedMaps())
    return success();
  // ...
  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
// ...
// (ContractOp::getIteratorTypesArray)
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // ...
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }
  // ...
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
                                              : utils::IteratorType::reduction);
  // ...
  return iteratorTypes;
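  // Worked example (added): with three iteration dims and an output map
  // (d0, d1, d2) -> (d0, d1), dimsInOutput is {true, true, false}, so the
  // loop above produces [parallel, parallel, reduction] -- d2 appears only on
  // the inputs and is therefore a reduction dimension.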
unsigned ContractOp::getNumRegionArgs() { return 3; }
// ...
// (ContractOp::regionBuilder)
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  // ...
  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }
  // ...
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType =
      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType);
  helper.yieldOutputs({result});
// ...
// (ContractOp::parse)
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);
// ...
// (ContractOp::print)
  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         {"indexing_maps", "operandSegmentSizes"});
// (ContractOp::verify)
  int iterationSpaceDims = -1;
  // ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ...
      return emitError("provided affine_map is not a projected permutation");
    // ...
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      // ...
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }
    // ...
    if (iterationSpaceDims == -1) {
      // ...
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }
    // ...
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");
      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    // ...
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(), /*...*/)) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure();
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];
    // ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
    // ...
    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";
    // ...
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");
void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
// ...
SmallVector<AffineMap>
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  // ...
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
  return indexingMaps;
}
// ...
// (BatchMatmulOp::getIteratorTypesArray)
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::parallel, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
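  // Added note: read against tensor shapes, the three default maps above
  // encode the usual batch-matmul contract with d0 = batch, d1 = M, d2 = N,
  // d3 = K:
  //   LHS (d0, d1, d3) : tensor<BxMxKxf32>
  //   RHS (d0, d3, d2) : tensor<BxKxNxf32>
  //   OUT (d0, d1, d2) : tensor<BxMxNxf32>
  // which matches the [parallel, parallel, parallel, reduction] iterator
  // types.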
unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool BatchMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
// ...
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                              bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  // ...
}
// ...
// (BatchMatmulOp::regionBuilder)
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(b, block);
  // ...
  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }
  // ...
  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
// (BatchMatmulOp::parse)
  if (!isa<AffineMapAttr>(mapAttr)) {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected affine map attribute");
  }
  indexingMapsAttr.push_back(mapAttr);
  // ...
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  // ...
  return parseNamedStructuredOp(parser, result,
                                BatchMatmulOp::getNumRegionArgs(),
                                BatchMatmulOp::getRegionBuilder());
// ...
// (BatchMatmulOp::print)
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  // ...
  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
// ...
// (BatchMatmulOp::verify)
  if (!hasUserDefinedMaps())
    return success();
  // ...
  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    // ...
  }
// ...
LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
struct ArityGroupAndKind {
  // ...
};

unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  // ...
}

static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int lastBinary =
      static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int lastTernary =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  int val = static_cast<int>(kind);
  ArityGroupAndKind result;

  if (val < lastUnary) {
    result.arityGroup = ElementwiseArityGroup::Unary;
    result.kind.unaryFn = static_cast<UnaryFn>(val);
    return result;
  }
  if (val < lastBinary) {
    result.arityGroup = ElementwiseArityGroup::Binary;
    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
    return result;
  }
  if (val >= lastTernary) {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  result.arityGroup = ElementwiseArityGroup::Ternary;
  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
  return result;
}
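// Worked example (added; the limit values are hypothetical): if LastUnary
// were 8 and LastBinary 20, kind values 0..7 would decode as UnaryFn(val),
// values 8..19 as BinaryFn(val - 8), and remaining values below LastTernary
// as TernaryFn(val - 20); each subtraction rebases the value onto its own
// enum, exactly as the three branches above do.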
// (ElementwiseOp::getIteratorTypesArray)
  auto rank = getResultRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
// ...
SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {
  // ...
}
// ...
// (ElementwiseOp::parse)
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  // ...
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  // ...
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  // ...
    if (!isa<AffineMapAttr>(mapAttr))
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    indexingMapsAttr.push_back(mapAttr);
  // ...
  unsigned numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }
  // ...
  if (indexingMapsAttr.empty()) {
    // ...
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
// ...
// (ElementwiseOp::print)
  unsigned numDims = getResultRank();
  // ...
      ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
                                            getContext()),
  // ...
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
// (ElementwiseOp::regionBuilder)
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
    if (attr.getName() == "kind") {
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
      break;
    }
  }
  // ...
  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;
  assert(/*...*/
         && "Elementwise regionBuilder number of block args mismatch");
  // ...
  RegionBuilderHelper helper(b, block);
  // ...
  if (arityGroup == ElementwiseArityGroup::Unary) {
    // ...
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    // ...
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    // ...
  } else {
    assert(false && "found unhandled category in elemwise");
  }
  // ...
  yields.push_back(result);
  helper.yieldOutputs(yields);
// ...
LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
// (getNewMixedTileSizes)
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }
    // ...
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant attribute: keep it.
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile) == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(/*...*/);
    }
  }
  // ...
  return newMixedTileSizes;
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  // ...
  reifiedReturnShapes[0] =
      // ...
}

template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // ...
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // ...
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}

template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // ...
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      // ...
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // ...
}

static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  // ...
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}

static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();
  // ...
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");
  // ...
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  // ...
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");
  // ...
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // ...
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }
  // ...
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          }))
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  // ...
struct PackOrUnPackTransposeResult {
  // ...
};

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    // ...
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    // ...
  }
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  // ...
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr, /*...*/);
}
// ...
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
// ...
  auto packedShape = getDestType().getShape();
  // ...
    res.push_back(packedShape[index]);
// ...
// (PackOp::requirePaddingValue)
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  // ...
           "expected output and outer_dims_perm to have same size");
  // ...
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  // ...
// (PackOp::verify)
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }
  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
// (asShapeWithAnyValueAsDynamic)
  for (auto o : ofrs) {
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    // ...
  }
// ...
// (getPackOpResultTypeShape)
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  // ...
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
// ...
// (PackOp::getResultShape, OpFoldResult variant)
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  // ...
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
  // ...
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
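  // Worked example (added): for a source shape [128, 256] with
  // inner_dims_pos = [0, 1] and inner_tiles = [8, 32], the loop above
  // rewrites the tiled dims to ceilDiv(128, 8) = 16 and ceilDiv(256, 32) = 8,
  // and the tile sizes are then appended, giving the packed shape
  // [16, 8, 8, 32].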
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         /*...*/) {
  // ...
}

// (PackOp::createDestinationTensor)
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          b.create<tensor::DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
// ...
// (PackOp::createTransposedClone)
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  // ...
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    // ...
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  // ...
}
// ...
  if (getPaddingValue())
    // ...
// ...
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // ...
}

static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    // ...
  }
  return true;
}

static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) -> x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Drop the padding value when it is provably unnecessary.
  // ...
    packOp.getPaddingValueMutable().clear();
  // ...

  // Insert tensor.cast ops when static shape inference tightens the types.
  // ...
  Location loc = packOp.getLoc();
  Value source = packOp.getSource();
  if (srcShape != packOp.getSourceType().getShape()) {
    auto newSrcType = packOp.getSourceType().clone(srcShape);
    source =
        rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
  }
  Value dest = packOp.getDest();
  RankedTensorType originalResultType = packOp.getDestType();
  bool needUpdateDestType = (destShape != originalResultType.getShape());
  if (needUpdateDestType) {
    auto newDestType = packOp.getDestType().clone(destShape);
    dest =
        rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
  }
  rewriter.modifyOpInPlace(packOp, [&] {
    packOp.getSourceMutable().assign(source);
    packOp.getDestMutable().assign(dest);
    packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
  });
  // Cast the result back to the original type if it changed.
  if (needUpdateDestType) {
    // ...
    auto castOp =
        rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
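  // Illustration (added): the pack(unpack(x)) branch above rewrites
  //
  //   %u = linalg.unpack %x inner_dims_pos = [0] inner_tiles = [8]
  //        into %e : tensor<4x8xf32> -> tensor<32xf32>
  //   %p = linalg.pack %u inner_dims_pos = [0] inner_tiles = [8]
  //        into %d : tensor<32xf32> -> tensor<4x8xf32>
  //
  // into a direct use of %x, provided there is no padding value and both ops
  // agree on inner/outer dim attributes and tile sizes.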
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // ...
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  // ...
  int64_t packedRank = packedTensorType.getRank();
  // ...
  // All remaining outer dims must be 1 for this to behave like a pad/unpad.
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}
// ...
// (fold tensor.cast producers/consumers into the pack)
  PackOp newOp = rewriter.create<PackOp>(
      op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
      newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
  // ...
  // Cast the result back to the original type if folding changed it.
  Value oldResult = op.getResult();
  Value newResult = newOp.getResult();
  Value replacement = (newResult.getType() != oldResult.getType())
                          ? rewriter.create<tensor::CastOp>(
                                op->getLoc(), oldResult.getType(), newResult)
                          : newResult;
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}
// ...
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
// ...
  auto packedShape = getSourceType().getShape();
  // ...
    res.push_back(packedShape[index]);
// ...
// (UnPackOp::build)
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  // ...
  build(builder, state, dest.getType(), source, dest, /*...*/);
// ...
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  // ...
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}

static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // Fold unpack(pack(x)) -> x.
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  // Fold an unpack whose destination is produced by a destination-style op.
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  // Fold a trailing tensor.extract_slice into the unpack's destination.
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser &&
        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
        extractSliceUser.getSourceType().getRank() ==
            extractSliceUser.getResultType().getRank()) {
      // ...
      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
          unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      // ...
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      // ...
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }
  // Insert tensor.cast ops when static shape inference tightens the types.
  // ...
  Value source = unPackOp.getSource();
  if (srcShape != unPackOp.getSourceType().getShape()) {
    auto newSrcType = unPackOp.getSourceType().clone(srcShape);
    source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                             unPackOp.getSource());
  }
  Value dest = unPackOp.getDest();
  if (destShape != unPackOp.getDestType().getShape()) {
    auto newDestType = unPackOp.getDestType().clone(destShape);
    dest =
        rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
  }
  Value newOp = rewriter.create<UnPackOp>(
      loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
      unPackOp.getOuterDimsPerm());
  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      unPackOp, unPackOp.getResult().getType(), newOp);
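  // Added sketch of the extract_slice rewrite above (shapes hypothetical):
  // when the unpack result's only user is a zero-offset, unit-stride,
  // rank-preserving tensor.extract_slice, the slice moves onto the
  // destination:
  //
  //   %y = linalg.unpack %x ... into %d : ... -> tensor<32xf32>
  //   %s = tensor.extract_slice %y[0] [30] [1]
  //        : tensor<32xf32> to tensor<30xf32>
  //
  // becomes an unpack directly into the sliced destination:
  //
  //   %d2 = tensor.extract_slice %d[0] [30] [1]
  //         : tensor<32xf32> to tensor<30xf32>
  //   %s = linalg.unpack %x ... into %d2 : ... -> tensor<30xf32>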
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
// ...
// (fold tensor.cast producers/consumers into the unpack)
  Value sourceTensor = newOperands[0];
  // ...
  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
      rewriter, sourceTensor.getType(), op.getMixedTiles());
  // ...
  UnPackOp newOp = rewriter.create<UnPackOp>(
      op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
      newMixedTileSizes, op.getOuterDimsPerm());
  // ...
  // Cast the result back to the original type if folding changed it.
  Value oldResult = op.getResult();
  Value newResult = newOp.getResult();
  Value replacement = (newResult.getType() != oldResult.getType())
                          ? rewriter.create<tensor::CastOp>(
                                op->getLoc(), oldResult.getType(), newResult)
                          : newResult;
// (BatchReduceMatmulOp::getIteratorTypesArray)
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::reduction, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
// ...
SmallVector<AffineMap>
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  // ...
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  // ...
  return indexingMaps;
}
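// Added note: only the two input maps are pushed here; with d0 = batch,
// d1 = M, d2 = N, d3 = K they read LHS (d0, d1, d3) and RHS (d0, d3, d2).
// The output map (elided in this listing) must be rank-2 over (d1, d2) per
// verifyOutputMap above, which is why both d0 and d3 appear as reduction
// iterator types.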
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool BatchReduceMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
// ...
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  // ...
}
// ...
// (BatchReduceMatmulOp::regionBuilder)
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(b, block);
  // ...
  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
// ...
// (BatchReduceMatmulOp::parse)
  if (!isa<AffineMapAttr>(mapAttr)) {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected affine map attribute");
  }
  indexingMapsAttr.push_back(mapAttr);
  // ...
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  // ...
  return parseNamedStructuredOp(parser, result,
                                BatchReduceMatmulOp::getNumRegionArgs(),
                                BatchReduceMatmulOp::getRegionBuilder());
// ...
// (BatchReduceMatmulOp::print)
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
  // ...
  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
// ...
// (BatchReduceMatmulOp::verify)
  if (!hasUserDefinedMaps())
    return success();
  // ...
  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    // ...
  }
// ...
LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  // ...
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ElementwiseArityGroup arityGroup
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Kind
An enumeration of the kinds of predicates.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e. ...).
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
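Usage sketch, assuming the tensor dialect's declaration of getMixedSizes: static extents come back as IntegerAttrs, dynamic ones as freshly built tensor.dim values.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
using namespace mlir;

static SmallVector<OpFoldResult> getSizes(OpBuilder &b, Location loc,
                                          Value tensorValue) {
  // One OpFoldResult per dimension of `tensorValue`.
  return tensor::getMixedSizes(b, loc, tensorValue);
}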
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
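Usage sketch; the wrapper equalsConstant is hypothetical, the call inside it is the helper above.

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include <optional>
using namespace mlir;

// True iff `ofr` is statically known to equal `expected`, whether it carries
// an IntegerAttr or a Value defined by a constant op.
static bool equalsConstant(OpFoldResult ofr, int64_t expected) {
  std::optional<int64_t> cst = getConstantIntValue(ofr);
  return cst && *cst == expected;
}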
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)].
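Usage sketch (bindSymbols below works the same way for symbol expressions); the matmul-style map is only an example.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;

static AffineMap makeMatmulOutputMap(MLIRContext *ctx) {
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k); // m -> d0, n -> d1, k -> d2
  // (d0, d1, d2) -> (d0, d1), i.e. the C(m, n) access of a matmul.
  return AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {m, n}, ctx);
}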
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected...
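Worked sketch on a concrete permutation map; the function name is illustrative.

#include "mlir/IR/AffineMap.h"
using namespace mlir;

static AffineMap invertedExample(MLIRContext *ctx) {
  // perm: (d0, d1, d2) -> (d2, d0, d1)
  AffineMap perm =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
  // Inverse: (d0, d1, d2) -> (d1, d2, d0).
  return inversePermutation(perm);
}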
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
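Usage sketch; the wrapper name is hypothetical, and the arith utils header is an assumption about where this helper is commonly declared.

#include "mlir/Dialect/Arith/Utils/Utils.h"
using namespace mlir;

static Value materializeIndex(OpBuilder &b, Location loc, OpFoldResult ofr) {
  // Creates an arith.constant index op only when `ofr` holds an attribute;
  // an existing Value is returned as-is.
  return getValueOrCreateConstantIndexOp(b, loc, ofr);
}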
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the detail classes directly.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
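Worked sketch of the permutation-vector helpers listed above; the values are chosen purely for illustration.

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include <cassert>
using namespace mlir;

static void permutationDemo() {
  SmallVector<int64_t> perm = {2, 0, 1};
  assert(isPermutationVector(perm));
  // result[i] = input[perm[i]], so {10, 20, 30} becomes {30, 10, 20}.
  SmallVector<int64_t> permuted = applyPermutation<int64_t>({10, 20, 30}, perm);
  // inv == {1, 2, 0}; applying it to `permuted` restores the original order.
  SmallVector<int64_t> inv = invertPermutationVector(perm);
  SmallVector<int64_t> restored = applyPermutation<int64_t>(permuted, inv);
  (void)restored;
}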
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
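To make the pattern entries above concrete, here is a hedged sketch with the usual OpRewritePattern shape. It is a condensed analogue of the "fold transpose with transpose" pattern, not the file's actual implementation; the struct name is invented.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;

struct SimplifyDoubleTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getInput().getDefiningOp<linalg::TransposeOp>();
    if (!producer)
      return failure();
    // Compose the two permutations: composed[i] = producerPerm[opPerm[i]].
    SmallVector<int64_t> composed =
        applyPermutation(producer.getPermutation(), op.getPermutation());
    // Tensor semantics assumed: exactly one result to replace.
    if (!isIdentityPermutation(composed) || op->getNumResults() != 1)
      return failure();
    rewriter.replaceOp(op, producer.getInput());
    return success();
  }
};

As usual, such a pattern would be added to a RewritePatternSet and run by a rewrite driver.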
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
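Sketch of how these OperationState pieces compose when building an op generically; the op name "mydialect.myop" and the attribute are illustrative assumptions, not anything this file defines.

#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"
using namespace mlir;

static Operation *buildGenericOp(OpBuilder &builder, Location loc, Value lhs,
                                 Value rhs, Type resultType) {
  OperationState state(loc, "mydialect.myop"); // hypothetical op name
  state.addOperands({lhs, rhs});
  state.addTypes({resultType});
  state.addAttribute("some_attr", builder.getI64IntegerAttr(42));
  state.addRegion(); // ops with a body get their region attached here
  return builder.create(state);
}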
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static than...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override