MLIR 23.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
33#include "mlir/IR/Matchers.h"
40
41#include "llvm/ADT/DenseMap.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SetOperations.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/Support/InterleavedRange.h"
50#include "llvm/Support/LogicalResult.h"
51#include "llvm/Support/MathExtras.h"
52#include "llvm/Support/raw_ostream.h"
53#include <cassert>
54#include <optional>
55
56using namespace mlir;
57using namespace mlir::linalg;
58
59/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
/// A static dimension is returned as an index attribute and creates no op;
/// only a dynamic dimension materializes a dim op at `loc`.
// NOTE(review): the declaration line (name and leading parameters) and the
// line opening the type switch are missing from this listing.
61 int64_t dim) {
62 auto type = cast<ShapedType>(v.getType());
// Static extent: fold directly to a constant index attribute.
63 if (!type.isDynamicDim(dim))
64 return builder.getIndexAttr(type.getDimSize(dim));
65
// Dynamic extent: pick the dim op that matches the kind of shaped type. Only
// ranked tensors and memrefs have a case here.
66 return getAsOpFoldResult(
68 .Case([&](RankedTensorType t) -> Value {
69 return tensor::DimOp::create(builder, loc, v, dim);
70 })
71 .Case([&](MemRefType t) -> Value {
72 return memref::DimOp::create(builder, loc, v, dim);
73 }));
74}
75
76/// Returns a memref.subview or a tensor.extract_slice based on the type of the
77/// `source`.
/// The slice is described by `offsets`, `sizes` and `strides`. A source that
/// is neither a ranked tensor nor a memref yields a null operation pointer.
// NOTE(review): the declaration lines (name and leading parameters) and the
// line opening the type switch are missing from this listing.
81 ArrayRef<OpFoldResult> strides) {
83 .Case([&](RankedTensorType t) -> Operation * {
84 return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
85 strides);
86 })
87 .Case([&](MemRefType type) -> Operation * {
88 return memref::SubViewOp::create(b, loc, source, offsets, sizes,
89 strides);
90 })
// Fallback for unsupported source types: callers must check for nullptr.
91 .Default([&](Type t) -> Operation * { return nullptr; });
92}
93
// Returns the splat element of `splatAttr` as a TypedAttr, or std::nullopt
// when `splatAttr` is unset or is not a splat.
// NOTE(review): the line carrying this function's name and parameters, and the
// statement that populates `splatAttr`, are missing from this listing.
94static std::optional<TypedAttr>
96 DenseElementsAttr splatAttr;
// Bail out when no attribute was bound or its elements are not all identical.
98 if (!splatAttr || !splatAttr.isSplat())
99 return std::nullopt;
100
101 return splatAttr.getSplatValue<TypedAttr>();
102}
103
104//===----------------------------------------------------------------------===//
105// Helper functions
106//===----------------------------------------------------------------------===//
107
// Creates (or folds) the dim op matching the type of `source` for dimension
// `dim`: memref.dim for ranked/unranked memrefs, tensor.dim for
// ranked/unranked tensors. Any other type reaches llvm_unreachable.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
109 int64_t dim) {
110 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
111 return b.createOrFold<memref::DimOp>(loc, source, dim);
112 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
113 return b.createOrFold<tensor::DimOp>(loc, source, dim);
114 llvm_unreachable("Expected MemRefType or TensorType");
115}
116
// Returns the size of dimension `dim` of `source`: an index attribute when the
// type is ranked and the dimension is static, otherwise whatever
// createOrFoldDimOp produces.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
118 int64_t dim) {
119 auto shapedType = llvm::cast<ShapedType>(source.getType());
// Unranked types are checked first so isDynamicDim is only queried on ranked
// types.
120 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
121 return createOrFoldDimOp(b, loc, source, dim);
122 return b.getIndexAttr(shapedType.getDimSize(dim));
123}
124
125//===----------------------------------------------------------------------===//
126// Support for named Linalg ops defined in ods-gen.
127//===----------------------------------------------------------------------===//
128
132
133/// Fills the region of a structured operation using the provided
134/// `regionBuilder`. The method is used by both named structured ops created by
135/// ods-gen and by manually defined C++ ops. It is called by both builders and
136/// parsers and creates a block with arguments corresponding to the elemental
137/// types of `inputTypes` and `outputTypes`.
///
/// Block arguments are ordered inputs first, then outputs. Memref and ranked
/// tensor types contribute their element type; any other type is used as-is.
// NOTE(review): the parameter lines declaring `attrs` and `emitError`, and the
// declaration of `argLocs`, are missing from this listing.
138static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
139 TypeRange inputTypes, TypeRange outputTypes,
142 RegionBuilderFn regionBuilder) {
143 SmallVector<Type, 8> argTypes;
145 for (auto containers : {inputTypes, outputTypes}) {
146 for (auto t : containers) {
147 argTypes.push_back(
148 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
149
150 // TODO: Pass in a proper location here.
151 argLocs.push_back(opBuilder.getUnknownLoc());
152 }
153 }
154
155 // RAII guard for the builder's insertion point.
156 OpBuilder::InsertionGuard guard(opBuilder);
157 Block *body =
158 opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
159
// The payload is emitted from the start of the new block, with every op
// carrying an unknown location.
160 opBuilder.setInsertionPointToStart(body);
161 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
162 regionBuilder(b, *body, attrs, emitError);
163
164 // indexing_maps is an auto-generated method.
165
166 // iterator_types is an auto-generated method.
167}
168
169/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
170/// The result types are derived automatically if `resultTensorTypes` is none.
171/// The body of the operation is filled using `regionBuilder`. All ods-gen
172/// created structured operations use the method to implement their builders.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
174 std::optional<TypeRange> resultTensorTypes,
175 ValueRange inputs, ValueRange outputs,
176 ArrayRef<NamedAttribute> attributes,
177 RegionBuilderFn regionBuilder) {
178 // Derive the result types if needed.
// When derived, only ranked-tensor outputs contribute a result type.
179 SmallVector<Type> derivedResultTypes =
180 resultTensorTypes.value_or(TypeRange());
181 if (!resultTensorTypes)
182 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
183 llvm::IsaPred<RankedTensorType>);
184
// Operand order is inputs followed by outputs.
185 state.addOperands(inputs);
186 state.addOperands(outputs);
187 state.addTypes(derivedResultTypes);
188
189 state.addAttributes(attributes);
// operandSegmentSizes records {number of inputs, number of outputs}.
190 state.addAttribute(
191 "operandSegmentSizes",
192 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
193 static_cast<int32_t>(outputs.size())}));
194
195 // Create and fill the region of the structured operation.
// No diagnostic callback is supplied on this path (emitError is empty).
196 Region &region = *state.addRegion();
197 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
198 state.attributes.getAttrs(), /*emitError=*/{},
199 regionBuilder);
200}
201
// Builder variant that installs `defaultIndexingMaps` as the `indexing_maps`
// attribute when `attributes` carries no entry of that name, then forwards to
// buildStructuredOp.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
203 std::optional<TypeRange> resultTensorTypes,
204 ValueRange inputs, ValueRange outputs,
205 ArrayRef<NamedAttribute> attributes,
206 RegionBuilderFn regionBuilder,
207 ArrayRef<AffineMap> defaultIndexingMaps) {
208 // If indexing maps are not provided, apply the default ones.
209 if (none_of(attributes, [](NamedAttribute attr) {
210 return attr.getName() == "indexing_maps";
211 })) {
// Wrap each affine map in an AffineMapAttr so the list can form an ArrayAttr.
212 SmallVector<Attribute, 3> indexingMapsAttrVal;
213 indexingMapsAttrVal = llvm::map_to_vector(
214 defaultIndexingMaps,
215 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
216 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
217 }
218 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
219 attributes, regionBuilder);
220}
221
// Builder variant that installs `defaultIndexingMaps` as the `indexing_maps`
// attribute when `attributes` carries no entry of that name, then forwards to
// buildStructuredOp.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
223 std::optional<TypeRange> resultTensorTypes,
224 ValueRange inputs, ValueRange outputs,
225 ArrayRef<NamedAttribute> attributes,
226 RegionBuilderFn regionBuilder,
227 ArrayRef<AffineMap> defaultIndexingMaps) {
228 // If indexing maps are not provided, apply the default ones.
229 if (none_of(attributes, [](NamedAttribute attr) {
230 return attr.getName() == "indexing_maps";
231 })) {
// Wrap each affine map in an AffineMapAttr so the list can form an ArrayAttr.
232 SmallVector<Attribute, 4> indexingMapsAttrVal;
233 indexingMapsAttrVal = llvm::map_to_vector(
234 defaultIndexingMaps,
235 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
236 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
237 }
238 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
239 attributes, regionBuilder);
240}
241
// Builder variant that always sets the `indexing_maps` attribute from
// `indexingMaps` and then forwards to buildStructuredOp.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
// NOTE(review): unlike the sibling builders visible in this file, there is no
// check for an `indexing_maps` entry already present in `attributes` — confirm
// that callers never pass one.
243 std::optional<TypeRange> resultTensorTypes,
244 ValueRange inputs, ValueRange outputs,
245 ArrayRef<NamedAttribute> attributes,
246 RegionBuilderFn regionBuilder,
247 ArrayRef<AffineMap> indexingMaps) {
248 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
249 SmallVector<Attribute, 4> indexingMapsAttrVal;
250 indexingMapsAttrVal =
251 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
252 return AffineMapAttr::get(map);
253 });
254 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
255 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
256 attributes, regionBuilder);
257}
258
259/// Common parsing used for both named structured ops created by ods-gen and by
260/// manually defined C++ ops. Does not handle regions.
///
/// Parses, in order: an optional `<properties>` attribute, an optional
/// attribute dictionary, an optional `ins(...)` clause and an optional
/// `outs(...)` clause. Operands are resolved into `result.operands` with the
/// inputs first. Returns failure as soon as any step fails.
// NOTE(review): the declaration line (name and leading parameters), the
// declaration of the operand lists and two builder-call lines are missing
// from this listing.
261static ParseResult
263 SmallVectorImpl<Type> &inputTypes,
264 SmallVectorImpl<Type> &outputTypes,
265 bool addOperandSegmentSizes = true) {
266 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
268 outputsOperands;
269
// Optional properties in angle brackets.
270 if (succeeded(parser.parseOptionalLess())) {
271 if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
272 return failure();
273 }
// Remember where the attribute dictionary starts; the inherent-attribute
// verification at the end reports errors at this location.
274 attrsLoc = parser.getCurrentLocation();
275 if (parser.parseOptionalAttrDict(result.attributes))
276 return failure();
277
278 if (succeeded(parser.parseOptionalKeyword("ins"))) {
279 if (parser.parseLParen())
280 return failure();
281
282 inputsOperandsLoc = parser.getCurrentLocation();
283 if (parser.parseOperandList(inputsOperands) ||
284 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
285 return failure();
286 }
287
288 if (succeeded(parser.parseOptionalKeyword("outs"))) {
289 outputsOperandsLoc = parser.getCurrentLocation();
290 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
291 parser.parseColonTypeList(outputTypes) || parser.parseRParen())
292 return failure();
293 }
294
// Inputs are resolved before outputs, fixing the operand order.
295 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
296 result.operands) ||
297 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
298 result.operands))
299 return failure();
300
301 if (addOperandSegmentSizes) {
302 // This is a bit complex because we're trying to be backward compatible with
303 // operation syntax that mix the inherent attributes and the discardable
304 // ones in the same dictionary. If the properties are used, we append the
305 // operandSegmentSizes there directly. Otherwise we append it to the
306 // discardable attributes dictionary where it is handled by the generic
307 // Operation::create(...) method.
308 if (result.propertiesAttr) {
309 NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
310 attrs.append("operandSegmentSizes",
312 {static_cast<int32_t>(inputsOperands.size()),
313 static_cast<int32_t>(outputsOperands.size())}));
314 result.propertiesAttr = attrs.getDictionary(parser.getContext());
315 } else {
316 result.addAttribute("operandSegmentSizes",
318 {static_cast<int32_t>(inputsOperands.size()),
319 static_cast<int32_t>(outputsOperands.size())}));
320 }
321 }
// Without explicit properties, inherent attributes were parsed as part of the
// attribute dictionary; verify them here when the op is registered.
322 if (!result.propertiesAttr) {
323 std::optional<RegisteredOperationName> info =
324 result.name.getRegisteredInfo();
325 if (info) {
326 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
327 return parser.emitError(attrsLoc)
328 << "'" << result.name.getStringRef() << "' op ";
329 })))
330 return failure();
331 }
332 }
333 return success();
334}
335
// Prints the ` ins(<operands> : <types>)` and ` outs(<operands> : <types>)`
// clauses. A clause is omitted entirely when its operand range is empty.
// NOTE(review): the declaration line (name and leading parameters) is missing
// from this listing.
337 ValueRange outputs) {
338 if (!inputs.empty())
339 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
340 if (!outputs.empty())
341 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
342}
343
344//===----------------------------------------------------------------------===//
345// Specific parsing and printing for named structured ops created by ods-gen.
346//===----------------------------------------------------------------------===//
347
// Builds the implicit region of a named structured op while parsing.
// Emits a parser error when `numRegionArgs` differs from the total number of
// input and output types. Otherwise the region is filled through
// `regionBuilder`; a diagnostic raised by the builder is reported at `loc` and
// turns the result into failure.
// NOTE(review): the declaration line and the line naming the function invoked
// with the argument list below are missing from this listing.
349 OpAsmParser &parser, Region &region, unsigned numRegionArgs,
350 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
351 RegionBuilderFn regionBuilder, SMLoc loc) {
352 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
353 return parser.emitError(
354 parser.getCurrentLocation(),
355 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
356 "region expects {0} args, got {1}",
357 numRegionArgs, inputTypes.size() + outputTypes.size()));
358 }
359
360 OpBuilder opBuilder(parser.getContext());
361 ParseResult result = success();
363 opBuilder, region, inputTypes, outputTypes, attrs,
// The callback records the failure so it can be returned after the region
// builder has finished.
364 [&]() {
365 result = failure();
366 return parser.emitError(loc);
367 },
368 regionBuilder);
369 return result;
370}
371
// Parses the optional `-> type-list` of a named structured op into
// `resultTypes`.
// NOTE(review): the line carrying the function name and the parser parameter
// is missing from this listing.
372static ParseResult
374 SmallVectorImpl<Type> &resultTypes) {
375 if (parser.parseOptionalArrowTypeList(resultTypes))
376 return failure();
377 return success();
378}
379
// Parses a named structured op: the common ins/outs parts, an optional
// trailing attribute dictionary and optional result types. The region is not
// parsed; it is built with `regionBuilder`.
// NOTE(review): the parameter line declaring `result` is missing from this
// listing.
380static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
382 unsigned numRegionArgs,
383 RegionBuilderFn regionBuilder) {
384 // TODO: Enable when ods-gen supports captures.
385 SmallVector<Type, 1> inputTypes, outputTypes;
// Captured before any token is consumed; region-builder diagnostics are
// reported at this location.
386 SMLoc loc = parser.getCurrentLocation();
387 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
388 return failure();
389
390 // Parse optional attributes.
391 if (parser.parseOptionalAttrDict(result.attributes))
392 return failure();
393
394 // TODO: consider merging results parsing into region parsing.
395 // Need to wait for declarative assembly resolution to decide.
396 SmallVector<Type, 1> outputTensorsTypes;
397 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
398 return failure();
399 result.addTypes(outputTensorsTypes);
400
// The region is attached to `result` only after it was built successfully.
401 std::unique_ptr<Region> region = std::make_unique<Region>();
402 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
403 outputTypes, result.attributes.getAttrs(),
404 regionBuilder, loc))
405 return failure();
406 result.addRegion(std::move(region));
407
408 return success();
409}
410
// Prints the optional `-> type-list` for `resultTypes`; prints nothing when
// the list is empty.
// NOTE(review): the declaration line (name and printer parameter) is missing
// from this listing.
412 TypeRange resultTypes) {
413 if (resultTypes.empty())
414 return;
415 p.printOptionalArrowTypeList(resultTypes);
416}
417
// Prints a named structured op: the attribute dictionary (minus
// `elidedAttrs`) followed by the ins/outs clauses.
// NOTE(review): the declaration line (name and leading parameters) and the
// statement following the "Results printing" comment are missing from this
// listing.
419 ValueRange inputs, ValueRange outputs,
420 ArrayRef<StringRef> elidedAttrs = {}) {
421 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
422
423 // Printing is shared with generic ops, except for the region and
424 // attributes.
425 printCommonStructuredOpParts(p, inputs, outputs);
426
427 // Results printing.
429
430 // Region is elided.
431}
432
433//===----------------------------------------------------------------------===//
434// Region builder helper.
435// TODO: Move this to a utility library.
436// The public methods on this class are referenced directly from generated code.
437// Helpers to build the unary, binary, ternary, and type conversion functions
438// defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code
439// that uses this class.
440//
441// Implementations of the math functions must be polymorphic over numeric types,
442// internally performing necessary casts. If the function application makes no
443// sense, then the only recourse is to assert and return nullptr. This can be
444// extended later if it becomes possible to fail construction of the region. The
445// invariant should be enforced at a higher level.
446//
447// TODO: These helpers are currently type polymorphic over the class of integer
448// and floating point types, but they will not internally cast within bit
449// widths of a class (mixed precision such as i8->i32) or across classes
450// (i.e. mixed float and integer). Many such combinations are ambiguous or need
451// to be handled with care and work is being considered to extend the op
452// language to make such cases explicit. In the meantime, violating this will
453// fail verification, which is deemed acceptable.
454//===----------------------------------------------------------------------===//
455
456namespace {
457
458class RegionBuilderHelper {
459public:
460 RegionBuilderHelper(OpBuilder &builder, Block &block)
461 : builder(builder), block(block) {}
462
463 // Build the unary functions defined by OpDSL.
464 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
465 function_ref<InFlightDiagnostic()> emitError = {}) {
466 if (!isFloatingPoint(arg)) {
467 if (emitError) {
468 emitError() << "unsupported non numeric type";
469 return nullptr;
470 }
471 llvm_unreachable("unsupported non numeric type");
472 }
473 OpBuilder::InsertionGuard g(builder);
474 builder.setInsertionPointToEnd(&block);
475 switch (unaryFn) {
476 case UnaryFn::exp:
477 return math::ExpOp::create(builder, arg.getLoc(), arg);
478 case UnaryFn::log:
479 return math::LogOp::create(builder, arg.getLoc(), arg);
480 case UnaryFn::abs:
481 return math::AbsFOp::create(builder, arg.getLoc(), arg);
482 case UnaryFn::ceil:
483 return math::CeilOp::create(builder, arg.getLoc(), arg);
484 case UnaryFn::floor:
485 return math::FloorOp::create(builder, arg.getLoc(), arg);
486 case UnaryFn::negf:
487 return arith::NegFOp::create(builder, arg.getLoc(), arg);
488 case UnaryFn::reciprocal: {
489 Attribute oneAttr = builder.getOneAttr(arg.getType());
490 auto one = arith::ConstantOp::create(builder, arg.getLoc(),
491 ::cast<TypedAttr>(oneAttr));
492 return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
493 }
494 case UnaryFn::round:
495 return math::RoundOp::create(builder, arg.getLoc(), arg);
496 case UnaryFn::sqrt:
497 return math::SqrtOp::create(builder, arg.getLoc(), arg);
498 case UnaryFn::rsqrt:
499 return math::RsqrtOp::create(builder, arg.getLoc(), arg);
500 case UnaryFn::square:
501 return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
502 case UnaryFn::tanh:
503 return math::TanhOp::create(builder, arg.getLoc(), arg);
504 case UnaryFn::erf:
505 return math::ErfOp::create(builder, arg.getLoc(), arg);
506 case UnaryFn::sin:
507 return math::SinOp::create(builder, arg.getLoc(), arg);
508 case UnaryFn::cos:
509 return math::CosOp::create(builder, arg.getLoc(), arg);
510 case UnaryFn::tan:
511 return math::TanOp::create(builder, arg.getLoc(), arg);
512 }
513 if (emitError) {
514 emitError() << "unsupported unary function";
515 return nullptr;
516 }
517 llvm_unreachable("unsupported unary function");
518 }
519
520 // Build the binary functions defined by OpDSL.
521 // If emitError is provided, an error will be emitted if the operation is not
522 // supported and a nullptr will be returned, otherwise an assertion will be
523 // raised.
524 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
525 function_ref<InFlightDiagnostic()> emitError = {}) {
526 bool allComplex = isComplex(arg0) && isComplex(arg1);
527 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
528 bool allInteger = isInteger(arg0) && isInteger(arg1);
529 bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
530 arg1.getType().getIntOrFloatBitWidth() == 1;
531 if (!allComplex && !allFloatingPoint && !allInteger) {
532 if (emitError) {
533 emitError()
534 << "Cannot build binary Linalg operation: expects allComplex, "
535 "allFloatingPoint, or allInteger, got "
536 << arg0.getType() << " and " << arg1.getType();
537 return nullptr;
538 }
539 llvm_unreachable("unsupported non numeric type");
540 }
541 OpBuilder::InsertionGuard g(builder);
542 builder.setInsertionPointToEnd(&block);
543 switch (binaryFn) {
544 case BinaryFn::add:
545 if (allComplex)
546 return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
547 if (allFloatingPoint)
548 return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
549 if (allBool)
550 return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
551 return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
552 case BinaryFn::sub:
553 if (allComplex)
554 return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
555 if (allFloatingPoint)
556 return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
557 if (allBool) {
558 if (emitError) {
559 emitError() << "unsupported operation: sub with bools";
560 return nullptr;
561 }
562 llvm_unreachable("unsupported operation: sub with bools");
563 }
564 return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
565 case BinaryFn::mul:
566 if (allComplex)
567 return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
568 if (allFloatingPoint)
569 return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
570 if (allBool)
571 return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
572 return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
573 case BinaryFn::div:
574 if (allComplex)
575 return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
576 if (allFloatingPoint)
577 return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
578 if (allBool) {
579 if (emitError) {
580 emitError() << "unsupported operation: div with bools";
581 return nullptr;
582 }
583 llvm_unreachable("unsupported operation: div with bools");
584 }
585 return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
586 case BinaryFn::div_unsigned:
587 if (!allInteger || allBool) {
588 if (emitError) {
589 emitError() << "unsupported operation: unsigned div not on uint";
590 return nullptr;
591 }
592 llvm_unreachable("unsupported operation: unsigned div not on uint");
593 }
594 return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
595 case BinaryFn::max_signed:
596 assert(!allComplex);
597 if (allFloatingPoint)
598 return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
599 return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
600 case BinaryFn::min_signed:
601 assert(!allComplex);
602 if (allFloatingPoint)
603 return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
604 return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
605 case BinaryFn::max_unsigned:
606 assert(!allComplex);
607 if (!allInteger || allBool) {
608 if (emitError) {
609 emitError() << "unsupported operation: unsigned max not on uint";
610 return nullptr;
611 }
612 llvm_unreachable("unsupported operation: unsigned max not on uint");
613 }
614 return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
615 case BinaryFn::min_unsigned:
616 assert(!allComplex);
617 if (!allInteger || allBool) {
618 if (emitError) {
619 emitError() << "unsupported operation: unsigned min not on uint";
620 return nullptr;
621 }
622 llvm_unreachable("unsupported operation: unsigned min not on uint");
623 }
624 return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
625 case BinaryFn::powf:
626 assert(allFloatingPoint);
627 return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
628 }
629 if (emitError) {
630 emitError() << "unsupported binary function";
631 return nullptr;
632 }
633 llvm_unreachable("unsupported binary function");
634 }
635
636 // Build the ternary functions defined by OpDSL.
637 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
638 function_ref<InFlightDiagnostic()> emitError = {}) {
639 OpBuilder::InsertionGuard g(builder);
640 builder.setInsertionPointToEnd(&block);
641 switch (ternaryFn) {
642 case TernaryFn::select:
643 return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
644 }
645 if (emitError) {
646 emitError() << "unsupported ternary function";
647 return nullptr;
648 }
649 llvm_unreachable("unsupported ternary function");
650 }
651
652 // Build the type functions defined by OpDSL.
653 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
654 function_ref<InFlightDiagnostic()> emitError = {}) {
655 switch (typeFn) {
656 case TypeFn::cast_signed:
657 return cast(toType, operand, false);
658 case TypeFn::cast_unsigned:
659 return cast(toType, operand, true);
660 }
661 if (emitError) {
662 emitError() << "unsupported type conversion function";
663 return nullptr;
664 }
665 llvm_unreachable("unsupported type conversion function");
666 }
667
668 void yieldOutputs(ValueRange values) {
669 OpBuilder::InsertionGuard g(builder);
670 builder.setInsertionPointToEnd(&block);
671 Location loc = builder.getUnknownLoc();
672 YieldOp::create(builder, loc, values);
673 }
674
675 Value constant(const std::string &value) {
676 OpBuilder::InsertionGuard g(builder);
677 builder.setInsertionPointToEnd(&block);
678 Location loc = builder.getUnknownLoc();
679 Attribute valueAttr = parseAttribute(value, builder.getContext());
680 return arith::ConstantOp::create(builder, loc,
681 ::cast<TypedAttr>(valueAttr));
682 }
683
684 Value index(int64_t dim) {
685 OpBuilder::InsertionGuard g(builder);
686 builder.setInsertionPointToEnd(&block);
687 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
688 }
689
690 Type getIntegerType(unsigned width) {
691 return IntegerType::get(builder.getContext(), width);
692 }
693
694 Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
695 Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
696
697private:
698 // Generates operations to cast the given operand to a specified type.
699 // If the cast cannot be performed, a warning will be issued and the
700 // operand returned as-is (which will presumably yield a verification
701 // issue downstream).
702 Value cast(Type toType, Value operand, bool isUnsignedCast) {
703 OpBuilder::InsertionGuard g(builder);
704 builder.setInsertionPointToEnd(&block);
705 auto loc = operand.getLoc();
706 if (isa<UnknownLoc>(loc)) {
707 if (operand.getDefiningOp())
708 loc = operand.getDefiningOp()->getLoc();
709 else if (operand.getParentBlock() &&
710 operand.getParentBlock()->getParentOp())
711 loc = operand.getParentBlock()->getParentOp()->getLoc();
712 }
713 return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
714 }
715
716 bool isComplex(Value value) {
717 return llvm::isa<ComplexType>(value.getType());
718 }
719 bool isFloatingPoint(Value value) {
720 return llvm::isa<FloatType>(value.getType());
721 }
722 bool isInteger(Value value) {
723 return llvm::isa<IntegerType>(value.getType());
724 }
725
726 OpBuilder &builder;
727 Block &block;
728};
729
730} // namespace
731
732//===----------------------------------------------------------------------===//
733// CopyOp
734//===----------------------------------------------------------------------===//
735
736namespace {
737
738struct EraseSelfCopy : OpRewritePattern<CopyOp> {
739 using OpRewritePattern<CopyOp>::OpRewritePattern;
740 LogicalResult matchAndRewrite(CopyOp copyOp,
741 PatternRewriter &rewriter) const override {
742 if (copyOp.getInputs() != copyOp.getOutputs())
743 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
744 if (copyOp.hasPureBufferSemantics())
745 rewriter.eraseOp(copyOp);
746 else
747 rewriter.replaceOp(copyOp, copyOp.getInputs());
748
749 return success();
750 }
751};
752
753} // namespace
754
755void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
756 MLIRContext *context) {
757 results.add<EraseSelfCopy>(context);
758}
759
760//===----------------------------------------------------------------------===//
761// FillOp
762//===----------------------------------------------------------------------===//
763
764namespace {
765
766/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
767///
768/// For such op chains, we can create new linalg.fill ops with the result
769/// type of the tensor.expand/collapse_shape op.
770template <typename TensorReshapeOp>
771struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
772 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
773 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
774 PatternRewriter &rewriter) const override {
775 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
776 if (!oldFill)
777 return failure();
778
779 Location loc = oldFill.getLoc();
780 TensorReshapeOp newInit;
781 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
782
783 newInit = TensorReshapeOp::create(
784 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
785 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
786 reshapeOp.getStaticOutputShape());
787 } else {
788 newInit = TensorReshapeOp::create(
789 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
790 reshapeOp.getReassociation());
791 }
792 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
793 ValueRange{newInit});
794 return success();
795 }
796};
797
798/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
799/// filling value are the same.
///
/// The replacement is a linalg.fill on a fresh tensor.empty of the padded
/// shape, followed by a tensor.cast when the fill result type differs from the
/// pad's result type.
// NOTE(review): the line between the struct header and matchAndRewrite is
// missing from this listing.
800struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
802
803 LogicalResult matchAndRewrite(tensor::PadOp padOp,
804 PatternRewriter &rewriter) const override {
805 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
806 if (!fillOp)
807 return failure();
808
809 // We can only fold if the padding value is the same as the original
810 // filling value.
811 Value padValue = padOp.getConstantPaddingValue();
812 if (!padValue || fillOp.value() != padValue)
813 return failure();
814
// Compute the (possibly dynamic) shape of the padded result.
815 ReifiedRankedShapedTypeDims reifiedShape;
816 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
817 return rewriter.notifyMatchFailure(
818 padOp, "failed to reify tensor.pad op result shape");
819
820 auto emptyTensor =
821 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
822 padOp.getResultType().getElementType());
823 Value replacement =
824 FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
825 ValueRange{emptyTensor})
826 .getResult(0);
// Cast back so the replacement has exactly the pad op's result type.
827 if (replacement.getType() != padOp.getResultType()) {
828 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
829 padOp.getResultType(), replacement);
830 }
831 rewriter.replaceOp(padOp, replacement);
832 return success();
833 }
834};
835
836/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
837/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
838/// filling value are the same.
///
/// The rewritten insert_slice takes the unpadded source, shifts each offset by
/// the corresponding low padding and uses the unpadded source sizes.
// NOTE(review): the line between the struct header and matchAndRewrite is
// missing from this listing.
839struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
841
842 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
843 PatternRewriter &rewriter) const override {
844 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
845 if (!srcPadOp)
846 return failure();
847
// Rank-reducing inserts are not handled.
848 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
849 return failure();
850
851 // Walk back the tensor.insert_slice chain and find the first destination
852 // value at the start of the chain.
853 Value firstDest = insertOp.getDest();
854 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
855 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
856 return failure();
857
858 // Make sure the range of values accessed are disjoint. Without this, we
859 // cannot fold tensor.pad away.
// One dimension with provably non-overlapping static ranges is enough.
860 bool disjoint = false;
861 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
862 // If the dimension has dynamic offset/size, we cannot guarantee
863 // disjoint. So just skip it.
864 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
865 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
866 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
867 continue;
868
869 // Get the range start and end, inclusively for both.
870 int64_t prevStart = prevOp.getStaticOffset(i);
871 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
872 prevOp.getStaticStride(i);
873 int64_t nextStart = insertOp.getStaticOffset(i);
874 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
875 insertOp.getStaticStride(i);
876 if (prevEnd < nextStart || nextEnd < prevStart) {
877 disjoint = true;
878 break;
879 }
880 }
881
// Stop walking at the first insert that may overlap; `firstDest` then stays
// that insert's result, so the fill check below fails.
882 if (!disjoint)
883 break;
884 firstDest = prevOp.getDest();
885 }
886
887 // Check whether the first destination is a fill op. For overlapped cases,
888 // this also cannot be true.
889 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
890 if (!dstFillOp)
891 return failure();
892
893 // We can only fold if the padding value is the same as the original
894 // filling value.
895 Value padValue = srcPadOp.getConstantPaddingValue();
896 if (!padValue || dstFillOp.value() != padValue)
897 return failure();
898
899 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
900 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
901
902 Location loc = insertOp.getLoc();
903 MLIRContext *context = getContext();
904
// Affine map ()[s0, s1] -> (s0 + s1), used to add low padding to each offset.
905 AffineExpr sym0, sym1;
906 bindSymbols(context, sym0, sym1);
907 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
908
909 // Calculate the new offsets for the insert. It should be the old offsets
910 // plus low padding sizes.
911 SmallVector<OpFoldResult, 4> newOffsets;
912 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
913 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
914 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
915 }
916
// The new sizes are those of the unpadded source: tensor.dim for dynamic
// dimensions, index attributes for static ones.
917 RankedTensorType srcPadType = srcPadOp.getSourceType();
918 SmallVector<OpFoldResult, 4> newSizes;
919 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
920 if (srcPadType.isDynamicDim(i)) {
921 newSizes.push_back(
922 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
923 .getResult());
924 } else {
925 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
926 }
927 }
928
929 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
930 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
931 newSizes, insertOp.getMixedStrides());
932 return success();
933 }
934};
935
936/// Fold tensor.extract(linalg.fill(<input>)) into <input>
937struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
938public:
939 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
940
941 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
942 PatternRewriter &rewriter) const override {
943 // See if tensor input of tensor.extract op is the result of a linalg.fill
944 // op.
945 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
946 if (!fillOp)
947 return failure();
948
949 // Get scalar input operand of linalg.fill op.
950 Value extractedScalar = fillOp.getInputs()[0];
951
952 // Replace tensor.extract op with scalar value used to fill the tensor.
953 rewriter.replaceOp(extractOp, extractedScalar);
954 return success();
955 }
956};
957
958/// Folds pack(fill) into a single fill op if
959/// 1. The pack op does not have padding value, or
960/// 2. The filled value and padding value are the same.
961static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
962 linalg::PackOp packOp) {
963 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
964 if (!fillOp)
965 return failure();
966
967 if (auto paddingValue = packOp.getPaddingValue())
968 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
969 return failure();
970
971 Value packOpDest = packOp.getDest();
972 if (!packOpDest.hasOneUse())
973 return failure();
974
975 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
976 packOp.getDest());
977}
978
979/// Wrapper pattern that applies foldFillPackIntoFillOp method.
980struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
981public:
982 FoldFillWithPack(MLIRContext *context)
983 : OpRewritePattern<linalg::PackOp>(context) {}
984
985 LogicalResult matchAndRewrite(linalg::PackOp packOp,
986 PatternRewriter &rewriter) const override {
987 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
988 if (failed(fillOp))
989 return failure();
990 rewriter.replaceOp(packOp, fillOp.value().result());
991 return success();
992 }
993};
994
995/// Fold fill with copy.
996struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
997 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
998
999 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
1000 PatternRewriter &rewriter) const override {
1001 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
1002 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
1003 fillOp.getInputs(),
1004 copyOp.getOutputs());
1005 return success();
1006 }
1007 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
1008 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
1009 fillOp.getOutputs());
1010 return success();
1011 }
1012 return failure();
1013 }
1014};
1015
1016/// Fold fill with transpose.
1017struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1018 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1019
1020 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1021 PatternRewriter &rewriter) const override {
1022 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1023 rewriter.replaceOpWithNewOp<FillOp>(
1024 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1025 transposeOp.getDpsInitOperand(0)->get());
1026 return success();
1027 }
1028 return failure();
1029 }
1030};
1031
1032/// Fold a concat with all elements being fills of the same value
1033/// into a fill of the concat result shape.
1034struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
1036
1037 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1038 PatternRewriter &rewriter) const override {
1039 auto concatOperands = concatOp.getInputs();
1040 if (concatOperands.empty()) {
1041 return failure();
1042 }
1043
1044 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1045 if (!firstFillOp) {
1046 return failure();
1047 }
1048 // Prefetch the fill value.
1049 OpFoldResult firstFillVal =
1050 getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
1051 // Collect all the outs values for the fill operations.
1052 SmallVector<Value> allOuts;
1053 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1054
1055 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
1056 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1057 if (!fillOp) {
1058 return false;
1059 }
1060
1061 OpFoldResult fillVal =
1062 getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
1063 if (fillVal != firstFillVal)
1064 return false;
1065
1066 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1067 return true;
1068 };
1069 if (!llvm::all_of(concatOperands.drop_front(),
1070 isDefinedByCompatibleFillOp)) {
1071 return rewriter.notifyMatchFailure(
1072 concatOp, "not all operands are defined by a compatible fill op");
1073 }
1074
1075 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1076 concatOp.getDim(), allOuts);
1077 rewriter.replaceOpWithNewOp<linalg::FillOp>(
1078 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
1079 return success();
1080 }
1081};
1082
1083} // namespace
1084
/// Registers the canonicalization patterns associated with `linalg.fill` in
/// `results`. All patterns are constructed with `context` only.
1085void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1086 MLIRContext *context) {
1087 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1088 FoldFillWithPack, FoldFillWithPad,
1089 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1090 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1091 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1092}
1093
1094//===----------------------------------------------------------------------===//
1095// GenericOp
1096//===----------------------------------------------------------------------===//
1097
1099 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1100 ValueRange outputs,
1101 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1102 SmallVector<Type, 4> blockArgTypes;
1103 SmallVector<Location, 4> blockArgLocs;
1104 for (ValueRange container : {inputs, outputs}) {
1105 for (Value v : container) {
1106 Type t = v.getType();
1107 blockArgTypes.push_back(
1108 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
1109 blockArgLocs.push_back(v.getLoc());
1110 }
1111 }
1112
1113 OpBuilder::InsertionGuard guard(builder);
1114 Block *bodyBlock =
1115 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1116 bodyBuild(builder, loc, bodyBlock->getArguments());
1117}
1118
1119void GenericOp::getAsmBlockArgumentNames(Region &region,
1120 OpAsmSetValueNameFn setNameFn) {
1121 for (Value v : getRegionInputArgs())
1122 setNameFn(v, "in");
1123 for (Value v : getRegionOutputArgs())
1124 setNameFn(v, "out");
1125}
1126
1127void GenericOp::build(
1128 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1129 ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1130 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1131 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1132 ArrayRef<NamedAttribute> attributes) {
1133 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1134 iteratorTypes, doc, libraryCall);
1135 result.addAttributes(attributes);
1136 if (bodyBuild)
1137 buildGenericRegion(builder, result.location, *result.regions.front(),
1138 inputs, outputs, bodyBuild);
1139}
1140
1141void GenericOp::build(
1142 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1143 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1144 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1145 StringRef libraryCall,
1146 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1147 ArrayRef<NamedAttribute> attributes) {
1148 build(builder, result, resultTensorTypes, inputs, outputs,
1149 builder.getAffineMapArrayAttr(indexingMaps),
1150 builder.getArrayAttr(llvm::map_to_vector(
1151 iteratorTypes,
1152 [&](utils::IteratorType iter) -> mlir::Attribute {
1153 return IteratorTypeAttr::get(builder.getContext(), iter);
1154 })),
1155 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1156 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1157 bodyBuild, attributes);
1158}
1159
/// Builder overload without explicit result types: forwards to the overload
/// taking `resultTensorTypes`, passing an empty `TypeRange`.
1160void GenericOp::build(
1161 OpBuilder &builder, OperationState &result, ValueRange inputs,
1162 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1163 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1164 StringRef libraryCall,
1165 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1166 ArrayRef<NamedAttribute> attributes) {
1167 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1168 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1169}
1170
/// Builder overload without result types, `doc` or `libraryCall`: forwards to
/// the overload taking both strings, passing empty strings for them.
1171void GenericOp::build(
1172 OpBuilder &builder, OperationState &result, ValueRange inputs,
1173 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1174 ArrayRef<utils::IteratorType> iteratorTypes,
1175 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1176 ArrayRef<NamedAttribute> attributes) {
1177 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1178 /*doc=*/"",
1179 /*libraryCall=*/"", bodyBuild, attributes);
1180}
1181
/// Builder overload with explicit result types but without `doc` or
/// `libraryCall`: forwards to the full overload with empty strings for both.
1182void GenericOp::build(
1183 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1184 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1185 ArrayRef<utils::IteratorType> iteratorTypes,
1186 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1187 ArrayRef<NamedAttribute> attributes) {
1188 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1189 iteratorTypes,
1190 /*doc=*/"",
1191 /*libraryCall=*/"", bodyBuild, attributes);
1192}
1193
1194void GenericOp::print(OpAsmPrinter &p) {
1195 p << " ";
1196
1197 // Print extra attributes.
1198 auto genericAttrNames = linalgTraitAttrNames();
1199
1200 llvm::StringSet<> genericAttrNamesSet;
1201 genericAttrNamesSet.insert_range(genericAttrNames);
1202 SmallVector<NamedAttribute, 8> genericAttrs;
1203 for (auto attr : (*this)->getAttrs()) {
1204 if (attr.getName() == getIteratorTypesAttrName()) {
1205 auto iteratorTypes =
1206 llvm::cast<ArrayAttr>(attr.getValue())
1207 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1208 // Convert IteratorType enums into the string representation. This is
1209 // needed, because tests still use the old format when 'iterator_types'
1210 // attribute is represented as an array of strings.
1211 // TODO: Remove this conversion once tests are fixed.
1212 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1213 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1214 return StringAttr::get(getContext(), stringifyIteratorType(t));
1215 });
1216
1217 genericAttrs.emplace_back(
1218 getIteratorTypesAttrName(),
1219 ArrayAttr::get(getContext(), iteratorTypeNames));
1220 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1221 genericAttrs.push_back(attr);
1222 }
1223 }
1224 if (!genericAttrs.empty()) {
1225 auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
1226 p << genericDictAttr;
1227 }
1228
1229 // Printing is shared with named ops, except for the region and attributes
1230 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1231
1232 genericAttrNames.push_back("operandSegmentSizes");
1233 genericAttrNamesSet.insert(genericAttrNames.back());
1234
1235 bool hasExtraAttrs = false;
1236 for (NamedAttribute n : (*this)->getAttrs()) {
1237 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1238 break;
1239 }
1240 if (hasExtraAttrs) {
1241 p << " attrs = ";
1242 p.printOptionalAttrDict((*this)->getAttrs(),
1243 /*elidedAttrs=*/genericAttrNames);
1244 }
1245
1246 // Print region.
1247 if (!getRegion().empty()) {
1248 p << ' ';
1249 p.printRegion(getRegion());
1250 }
1251
1252 // Print results.
1253 printNamedStructuredOpResults(p, getResultTensors().getTypes());
1254}
1255
1256ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1257 DictionaryAttr dictAttr;
1258 // Parse the core linalg traits that must check into a dictAttr.
1259 // The name is unimportant as we will overwrite result.attributes.
1260 // The core linalg traits must contain the information necessary to pass the
1261 // verifier.
1262 llvm::SMLoc attributeLocation = parser.getCurrentLocation();
1263 if (parser.parseAttribute(dictAttr, "_", result.attributes))
1264 return failure();
1265 result.attributes.assign(dictAttr.getValue().begin(),
1266 dictAttr.getValue().end());
1267
1268 // Convert array of string into an array of IteratorType enums. This is
1269 // needed, because tests still use the old format when 'iterator_types'
1270 // attribute is represented as an array of strings.
1271 // TODO: Remove this conversion once tests are fixed.
1272 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1273 result.attributes.get(getIteratorTypesAttrName(result.name)));
1274 if (!iteratorTypes) {
1275 return parser.emitError(attributeLocation)
1276 << "expected " << getIteratorTypesAttrName(result.name)
1277 << " array attribute";
1278 }
1279
1280 SmallVector<Attribute> iteratorTypeAttrs;
1281
1282 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1283 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1284 if (!maybeIteratorType.has_value())
1285 return parser.emitError(parser.getCurrentLocation())
1286 << "unexpected iterator_type (" << s << ")";
1287
1288 iteratorTypeAttrs.push_back(
1289 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
1290 }
1291 result.attributes.set(getIteratorTypesAttrName(result.name),
1292 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
1293
1294 // Parsing is shared with named ops, except for the region.
1295 SmallVector<Type, 1> inputTypes, outputTypes;
1296 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
1297 return failure();
1298
1299 // Optional attributes may be added.
1300 if (succeeded(parser.parseOptionalKeyword("attrs")))
1301 if (failed(parser.parseEqual()) ||
1302 failed(parser.parseOptionalAttrDict(result.attributes)))
1303 return failure();
1304
1305 std::unique_ptr<Region> region = std::make_unique<Region>();
1306 if (parser.parseRegion(*region, {}))
1307 return failure();
1308 result.addRegion(std::move(region));
1309
1310 // Generic ops may specify that a subset of its outputs are tensors. Such
1311 // outputs are specified in the result type.
1312 // TODO: may need to move output parsing before region parsing.
1313 // Need to wait for declarative assembly resolution to decide.
1314 SmallVector<Type, 1> outputTensorsTypes;
1315 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
1316 return failure();
1317 result.addTypes(outputTensorsTypes);
1318
1319 return success();
1320}
1321
1324 &effects,
1325 LinalgOp linalgOp) {
1326 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1327 if (!llvm::isa<MemRefType>(operand.getType()))
1328 continue;
1329 effects.emplace_back(
1330 MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1331 /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1332 }
1333
1334 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1335 if (!llvm::isa<MemRefType>(operand.get().getType()))
1336 continue;
1337 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1338 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1339 /*effectOnFullRegion=*/true,
1341 }
1342 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1343 /*effectOnFullRegion=*/true,
1345 }
1346}
1347
/// Reports the memory effects of this op by delegating to
/// `getGenericEffectsImpl` with the op viewed through the `LinalgOp` interface.
1348void GenericOp::getEffects(
1349 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1350 &effects) {
1351 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1352}
1353
1356 // Operands with value semantics are speculatable, while operands with memory
1357 // semantics are not.
1358 if (!linalgOp.hasPureTensorSemantics())
1360 // The body of the op can still have speculation in its region.
1362}
1363
/// Returns the speculatability of this op as computed by
/// `getGenericSpeculatabilityImpl` on its `LinalgOp` interface view.
1364Speculation::Speculatability GenericOp::getSpeculatability() {
1365 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1366}
1367
1368namespace {
1369
1370/// Remove linalg operations that are just copying the values from inputs to
1371/// results. In the memref case, the operation must be copying to and from the
1372/// same value. Requirements are:
1373/// 1) All iterator types are parallel
1374/// 2) The body contains just a yield operation with the yielded values being
1375/// the arguments corresponding to the operands.
1376template <typename OpTy>
1377struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1378 using OpRewritePattern<OpTy>::OpRewritePattern;
1379
1380 LogicalResult matchAndRewrite(OpTy linalgOp,
1381 PatternRewriter &rewriter) const override {
1382 // All indexing maps must be equal. It follows that they are permutations.
1383 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1384 return failure();
1385
1386 // Check that the body of the linalg operation is just a linalg.yield
1387 // operation.
1388 Block &body = linalgOp->getRegion(0).front();
1389 if (!llvm::hasSingleElement(body))
1390 return failure();
1391 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1392 if (!yieldOp)
1393 return failure();
1394
1395 // In the buffer case, we need to check exact buffer equality.
1396 if (linalgOp.hasPureBufferSemantics()) {
1397 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1398 linalgOp.getDpsInputOperand(0)->get() !=
1399 linalgOp.getDpsInitOperand(0)->get()) {
1400 return rewriter.notifyMatchFailure(
1401 linalgOp, "expected single input and output to be the same value");
1402 }
1403
1404 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1405 if (!yieldArg || yieldArg.getOwner() != &body) {
1406 return rewriter.notifyMatchFailure(linalgOp,
1407 "cannot fold fill-like op");
1408 }
1409
1410 rewriter.eraseOp(linalgOp);
1411 return success();
1412 }
1413
1414 if (!linalgOp.hasPureTensorSemantics()) {
1415 return rewriter.notifyMatchFailure(
1416 linalgOp, "mixed semantics is not supported yet");
1417 }
1418
1419 // Get the argument number of the returned values. That is the operand
1420 // number to use for replacing uses of this operation.
1421 SmallVector<Value> returnedArgs;
1422 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1423 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1424 if (!yieldArg || yieldArg.getOwner() != &body)
1425 return failure();
1426 unsigned argumentNumber = yieldArg.getArgNumber();
1427 Value returnedArg = linalgOp->getOperand(argumentNumber);
1428 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1429 // The input can have a different type than the result, e.g. a dynamic
1430 // input dimension can be turned into a static output dimension.
1431 Type returnType = returnedArg.getType();
1432 if (returnType != resultType) {
1433 // Distinguish between sparse conversion or dense tensor casting.
1434 // TODO: unify the two ops?
1437 returnedArg = sparse_tensor::ConvertOp::create(
1438 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1439 else {
1440 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1441 resultType))
1442 return failure();
1443 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1444 resultType, returnedArg);
1445 }
1446 }
1447 returnedArgs.push_back(returnedArg);
1448 }
1449
1450 if (returnedArgs.size() != linalgOp->getNumResults())
1451 return failure();
1452 rewriter.replaceOp(linalgOp, returnedArgs);
1453 return success();
1454 }
1455};
1456
1457} // namespace
1458
/// Registers the canonicalization patterns of `linalg.generic`: currently only
/// the identity-erasure pattern instantiated for `GenericOp`.
1459void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1460 MLIRContext *context) {
1461 results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1462}
1463
/// Folder hook: delegates entirely to `memref::foldMemRefCast` on this op. The
/// fold adaptor and the results vector are unused here.
// NOTE(review): presumably this folds `memref.cast` producers into the
// operands in place; confirm against the helper's definition.
1464LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1465 return memref::foldMemRefCast(*this);
1466}
1467
1468//===----------------------------------------------------------------------===//
1469// MapOp
1470//===----------------------------------------------------------------------===//
1471
1472static ParseResult parseDstStyleOp(
1474 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1475 nullptr) {
1476 // Parse `ins` and `outs`.
1477 SmallVector<Type, 4> inputTypes, outputTypes;
1478 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1479 /*addOperandSegmentSizes=*/false))
1480 return failure();
1481
1482 // Add result types.
1483 for (Type outputType : outputTypes) {
1484 if (llvm::isa<RankedTensorType>(outputType))
1485 result.addTypes(outputType);
1486 }
1487
1488 // Parse required attributes.
1489 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1490 return failure();
1491
1492 // Parse optional attributes.
1493 if (parser.parseOptionalAttrDict(result.attributes))
1494 return failure();
1495 return success();
1496}
1497
1498void MapOp::getAsmBlockArgumentNames(Region &region,
1499 OpAsmSetValueNameFn setNameFn) {
1500 for (Value v : getRegionInputArgs())
1501 setNameFn(v, "in");
1502 for (Value v : getRegionOutputArgs())
1503 setNameFn(v, "init");
1504}
1505
1506void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1507 if (!getResults().empty())
1508 setNameFn(getResults().front(), "mapped");
1509}
1510
1511void MapOp::build(
1512 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1513 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1514 ArrayRef<NamedAttribute> attributes) {
1515 build(builder, result, TypeRange{}, inputs, init);
1516 result.addAttributes(attributes);
1517
1518 // Add output types for `RankedTensorType` output arguments.
1519 Type initType = init.getType();
1520 if (llvm::isa<RankedTensorType>(initType))
1521 result.addTypes(initType);
1522
1523 if (bodyBuild)
1524 buildGenericRegion(builder, result.location, *result.regions.front(),
1525 inputs, /*outputs=*/{init}, bodyBuild);
1526}
1527
1529 const OperationName &payloadOpName,
1530 const NamedAttrList &payloadOpAttrs,
1531 ArrayRef<Value> operands,
1532 bool initFirst = false, bool mapInit = true) {
1533 OpBuilder b(parser.getContext());
1534 Region *body = result.addRegion();
1535 Block &block = body->emplaceBlock();
1536 b.setInsertionPointToStart(&block);
1537 for (auto &operand : operands) {
1538 block.addArgument(
1539 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1540 b.getUnknownLoc());
1541 }
1542 SmallVector<Value> payloadOpOperands;
1543 // If initFirst flag is enabled, we consider init as the first position of
1544 // payload operands.
1545 if (initFirst) {
1546 if (mapInit)
1547 payloadOpOperands.push_back(block.getArguments().back());
1548 for (const auto &arg : block.getArguments().drop_back())
1549 payloadOpOperands.push_back(arg);
1550 } else {
1551 payloadOpOperands = {block.getArguments().begin(),
1552 block.getArguments().end() - int(!mapInit)};
1553 }
1554
1555 Operation *payloadOp = b.create(
1556 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1557 payloadOpOperands,
1558 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1559 .getElementType()},
1560 payloadOpAttrs);
1561 YieldOp::create(b, result.location, payloadOp->getResults());
1562}
1563
1564ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1565 std::optional<OperationName> payloadOpName;
1566 NamedAttrList payloadOpAttrs;
1567 if (succeeded(parser.parseOptionalLBrace())) {
1568 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1569 if (failed(operationName))
1570 return failure();
1571 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1572 return failure();
1573 payloadOpName = operationName.value();
1574 if (parser.parseRBrace())
1575 return failure();
1576 }
1577
1578 if (parseDstStyleOp(parser, result))
1579 return failure();
1580
1581 if (payloadOpName.has_value()) {
1582 if (!result.operands.empty())
1583 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1584 payloadOpAttrs, ArrayRef(result.operands), false,
1585 false);
1586 else
1587 result.addRegion();
1588 } else {
1589 SmallVector<OpAsmParser::Argument> regionArgs;
1590 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1591 /*allowType=*/true, /*allowAttrs=*/true)) {
1592 return failure();
1593 }
1594 Region *body = result.addRegion();
1595 if (parser.parseRegion(*body, regionArgs))
1596 return failure();
1597 }
1598 return success();
1599}
1600
1601static bool canUseShortForm(Block *body, bool initFirst = false,
1602 bool mapInit = true) {
1603 // `intFirst == true` implies that we want to map init arg
1604 if (initFirst && !mapInit)
1605 return false;
1606 // Check if the body can be printed in short form. The following 4 conditions
1607 // must be satisfied:
1608
1609 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1610 if (body->getOperations().size() != 2)
1611 return false;
1612 Operation &payload = body->getOperations().front();
1613
1614 // 2) The payload op must have the same number of operands as the number of
1615 // block arguments.
1616 if (payload.getNumOperands() == 0 ||
1617 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1618 return false;
1619
1620 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1621 // must be the first operand of the payload op, otherwise, the operands
1622 // must match the block arguments in order.
1623 if (initFirst) {
1624 // check init
1625 if (payload.getOperands().back() != body->getArgument(0))
1626 return false;
1627 // check rest
1628 for (const auto &[operand, bbArg] :
1629 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1630 if (bbArg != operand)
1631 return false;
1632 }
1633 } else {
1634 for (const auto &[operand, bbArg] :
1635 llvm::zip(payload.getOperands(),
1636 body->getArguments().drop_back(int(!mapInit)))) {
1637 if (bbArg != operand)
1638 return false;
1639 }
1640 }
1641
1642 // 4) The `yield` operand must be the result of the payload op.
1643 auto yieldOp = cast<YieldOp>(body->getTerminator());
1644 return yieldOp.getNumOperands() == 1 &&
1645 yieldOp.getOperand(0).getDefiningOp() &&
1646 yieldOp.getOperand(0).getDefiningOp() == &payload;
1647}
1648
1649static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1650 SmallVector<StringRef> elidedAttrs;
1651 std::string attrToElide;
1652 p << " { " << payloadOp->getName().getStringRef();
1653 for (const auto &attr : payloadOp->getAttrs()) {
1654 auto fastAttr =
1655 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1656 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1657 attrToElide = attr.getName().str();
1658 elidedAttrs.push_back(attrToElide);
1659 break;
1660 }
1661 }
1662 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1663 p << " }";
1664}
1665
1666void MapOp::print(OpAsmPrinter &p) {
1667 Block *mapper = getBody();
1668 bool useShortForm =
1669 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
1670 if (useShortForm) {
1671 printShortForm(p, &mapper->getOperations().front());
1672 }
1673
1674 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1675 p.printOptionalAttrDict((*this)->getAttrs());
1676
1677 if (!useShortForm) {
1678 // Print region if the payload op was not detected.
1679 p.increaseIndent();
1680 p.printNewline();
1681 p << "(";
1682 llvm::interleaveComma(mapper->getArguments(), p,
1683 [&](auto arg) { p.printRegionArgument(arg); });
1684 p << ") ";
1685
1686 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1687 p.decreaseIndent();
1688 }
1689}
1690
1691LogicalResult MapOp::verify() {
1692 auto *bodyBlock = getBody();
1693 auto blockArgs = bodyBlock->getArguments();
1694
1695 // Checks if the number of `inputs` + `init` match the arity of the `mapper`
1696 // region.
1697 if (getInputs().size() + 1 != blockArgs.size())
1698 return emitOpError() << "expects number of operands to match the arity of "
1699 "mapper, but got: "
1700 << getInputs().size() + 1 << " and "
1701 << blockArgs.size();
1702
1703 // The parameters of mapper should all match the element type of inputs.
1704 for (const auto &[bbArgType, inputArg] :
1705 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1706 auto inputElemType =
1707 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1708 if (bbArgType != inputElemType) {
1709 return emitOpError() << "expected element type of input " << inputElemType
1710 << " to match bbArg type " << bbArgType;
1711 }
1712 }
1713
1714 // The shape of each input must match the shape of the output.
1715 auto outputShape = getInit().getType().getShape();
1716 for (Type inputArgType : TypeRange{getInputs()}) {
1717 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1718 if (inputElemShape != outputShape) {
1719 return emitOpError() << "expected shape of input (" << inputElemShape
1720 << ") to match shape of output (" << outputShape
1721 << ")";
1722 }
1723 }
1724
1725 return success();
1726}
1727
1728SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1729 int64_t rank = getInit().getType().getRank();
1730 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1731}
1732
1733ArrayAttr MapOp::getIndexingMaps() {
1734 Builder builder(getContext());
1735 int64_t rank = getInit().getType().getRank();
1736 int64_t numIndexingMaps = getOperands().size();
1737 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1738 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1739}
1740
/// Reports the memory effects of this op by delegating to
/// `getGenericEffectsImpl` with the op viewed through the `LinalgOp` interface.
1741void MapOp::getEffects(
1742 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1743 &effects) {
1744 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1745}
1746
/// Returns the speculatability of this op as computed by
/// `getGenericSpeculatabilityImpl` on its `LinalgOp` interface view.
1747Speculation::Speculatability MapOp::getSpeculatability() {
1748 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1749}
1750
1751//===----------------------------------------------------------------------===//
1752// ReduceOp
1753//===----------------------------------------------------------------------===//
1754
1755void ReduceOp::getAsmBlockArgumentNames(Region &region,
1756 OpAsmSetValueNameFn setNameFn) {
1757 for (Value v : getRegionInputArgs())
1758 setNameFn(v, "in");
1759 for (Value v : getRegionOutputArgs())
1760 setNameFn(v, "init");
1761}
1762
1763void ReduceOp::getAsmResultNames(
1764 function_ref<void(Value, StringRef)> setNameFn) {
1765 if (!getResults().empty())
1766 setNameFn(getResults().front(), "reduced");
1767}
1768
1769void ReduceOp::build(
1770 OpBuilder &builder, OperationState &result, ValueRange inputs,
1771 ValueRange inits, ArrayRef<int64_t> dimensions,
1772 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1773 ArrayRef<NamedAttribute> attributes) {
1774 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1775 result.addAttributes(attributes);
1776
1777 // Add output types for `RankedTensorType` output arguments.
1778 for (Value init : inits) {
1779 Type initType = init.getType();
1780 if (llvm::isa<RankedTensorType>(initType))
1781 result.addTypes(initType);
1782 }
1783
1784 if (bodyBuild)
1785 buildGenericRegion(builder, result.location, *result.regions.front(),
1786 inputs, inits, bodyBuild);
1787}
1788
1789SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1790 int64_t inputRank =
1791 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1792 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1793 utils::IteratorType::parallel);
1794 for (int64_t reductionDim : getDimensions())
1795 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1796 return iteratorTypes;
1797}
1798
1799ArrayAttr ReduceOp::getIndexingMaps() {
1800 int64_t inputRank =
1801 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1802 SmallVector<AffineMap> affineMaps(
1803 getNumDpsInputs(),
1805 AffineMap resultMap =
1807 .dropResults(getDimensions());
1808 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1809 affineMaps.push_back(resultMap);
1810 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1811}
1812
1813void ReduceOp::getEffects(
1814 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1815 &effects) {
1816 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1817}
1818
1819Speculation::Speculatability ReduceOp::getSpeculatability() {
1820 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1821}
1822
1823static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1824 NamedAttrList &attributes,
1825 StringRef attributeName) {
1826 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1827 return failure();
1828
1829 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1830 return success();
1831}
1832
1833ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1834 std::optional<OperationName> payloadOpName;
1835 NamedAttrList payloadOpAttrs;
1836 if (succeeded(parser.parseOptionalLBrace())) {
1837 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1838 if (failed(operationName))
1839 return failure();
1840 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1841 return failure();
1842 payloadOpName = operationName.value();
1843 if (parser.parseRBrace())
1844 return failure();
1845 }
1846
1847 if (parseDstStyleOp(
1848 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1849 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1850 }))
1851 return failure();
1852
1853 if (payloadOpName.has_value()) {
1854 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1855 ArrayRef(result.operands), /*initFirst=*/true);
1856 } else {
1857 SmallVector<OpAsmParser::Argument> regionArgs;
1858 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1859 /*allowType=*/true, /*allowAttrs=*/true)) {
1860 return failure();
1861 }
1862
1863 Region *body = result.addRegion();
1864 if (parser.parseRegion(*body, regionArgs))
1865 return failure();
1866 }
1867
1868 return success();
1869}
1870
1871static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1872 ArrayRef<int64_t> attributeValue) {
1873 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1874}
1875
1876void ReduceOp::print(OpAsmPrinter &p) {
1877 Block *mapper = getBody();
1878 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1879 if (useShortForm) {
1880 printShortForm(p, &mapper->getOperations().front());
1881 }
1882
1883 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1884 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1885 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1886 if (!useShortForm) {
1887 // Print region if the payload op was not detected.
1888 p.increaseIndent();
1889 p.printNewline();
1890 p << "(";
1891 llvm::interleaveComma(mapper->getArguments(), p,
1892 [&](auto arg) { p.printRegionArgument(arg); });
1893 p << ") ";
1894
1895 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1896 p.decreaseIndent();
1897 }
1898}
1899
1900LogicalResult ReduceOp::verify() {
1901 ArrayRef<int64_t> dimensionsRef = getDimensions();
1902
1903 // The ReduceOp uses `SameVariadicOperandSize`, which requires equal numbers
1904 // of inputs and inits. Detect a mismatch early: when they differ, the
1905 // ODS-generated getInputs()/getInits() accessors compute each group's size
1906 // via floordiv of the total operand count, producing incorrect slices that
1907 // would cause out-of-bounds accesses below.
1908 if (getInputs().size() != static_cast<size_t>(getNumDpsInputs()))
1909 return emitOpError()
1910 << "expected equal number of inputs and outputs (required by "
1911 "SameVariadicOperandSize), got "
1912 << getNumDpsInputs() << " input(s) and " << getNumDpsInits()
1913 << " output(s)";
1914
1915 if (getInputs().empty())
1916 return emitOpError() << "expected at least one input";
1917
1918 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1919 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1920 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1921 return emitOpError() << "expects all inputs to have the same shapes. "
1922 "Shape at input-index "
1923 << i
1924 << " is not equal to the shape at input-index 0.";
1925 }
1926 }
1927 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1928 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1929 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1930 return emitOpError() << "expects all outputs to have the same shapes. "
1931 "Shape at output-index "
1932 << i
1933 << " is not equal to the shape at output-index 0.";
1934 }
1935 }
1936 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1937 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1938
1939 DenseSet<int64_t> dimensionsToReduce;
1940 for (int64_t dimension : dimensionsRef) {
1941 if (dimension < 0 || dimension >= inputType.getRank()) {
1942 return emitOpError()
1943 << "dimensions for reduction should be in the range [0, "
1944 << inputType.getRank() - 1 << "].";
1945 }
1946 dimensionsToReduce.insert(dimension);
1947 }
1948
1949 auto inputDims = inputType.getShape();
1950 auto initDims = initType.getShape();
1951
1952 // Input dimensions that will be left after the reduction.
1953 SmallVector<int64_t> reducedInputDims;
1954 for (const auto &en : llvm::enumerate(inputDims)) {
1955 if (!dimensionsToReduce.count(en.index()))
1956 reducedInputDims.push_back(en.value());
1957 }
1958
1959 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1960 return emitOpError() << "number of dimensions after reduction "
1961 << reducedInputDims.size()
1962 << " doesn't match the init rank "
1963 << initType.getRank();
1964 }
1965
1966 if (reducedInputDims != initDims)
1967 return emitOpError() << "init dimensions [" << initDims
1968 << "] doesn't match input dimensions after reduction ["
1969 << reducedInputDims << "]";
1970
1971 Block *block = getBody();
1972 if (block->getNumArguments() != this->getNumOperands())
1973 return emitOpError()
1974 << "mismatching number of operands and block arguments";
1975
1976 // Check that the first block arguments match the element type of the inputs.
1977 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1978 Type inputElementType =
1979 llvm::cast<ShapedType>(input.getType()).getElementType();
1980 if (inputElementType != bbArg.getType())
1981 return emitOpError()
1982 << "input element type " << inputElementType
1983 << " does not match corresponding block argument type "
1984 << bbArg.getType();
1985 }
1986
1987 // Check that the last block arguments match the element type of the outputs.
1988 for (auto [output, bbArg] : llvm::zip(
1989 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1990 auto outputElementType =
1991 llvm::cast<ShapedType>(output.getType()).getElementType();
1992 if (outputElementType != bbArg.getType())
1993 return emitOpError()
1994 << "output element type " << outputElementType
1995 << " does not match corresponding block argument type "
1996 << bbArg.getType();
1997 }
1998 return success();
1999}
2000
2001//===----------------------------------------------------------------------===//
2002// TransposeOp
2003//===----------------------------------------------------------------------===//
2004
2005static void buildIdentityRegion(OpBuilder &builder, Location loc,
2006 Region &region, ValueRange inputs,
2007 ValueRange outputs) {
2008 buildGenericRegion(builder, loc, region, inputs, outputs,
2009 [](OpBuilder &b, Location loc, ValueRange args) {
2010 if (!args.empty())
2011 linalg::YieldOp::create(b, loc, args[0]);
2012 });
2013}
2014
2015void TransposeOp::build(::mlir::OpBuilder &builder,
2016 ::mlir::OperationState &result, Value input, Value init,
2017 DenseI64ArrayAttr permutation,
2018 ArrayRef<NamedAttribute> attributes) {
2019 result.addOperands(input);
2020 result.addOperands(init);
2021 result.addAttribute(getPermutationAttrName(result.name), permutation);
2022 result.addAttributes(attributes);
2023
2024 // Add output types for `RankedTensorType` output arguments.
2025 Type initType = init.getType();
2026 if (llvm::isa<RankedTensorType>(initType))
2027 result.addTypes(initType);
2028
2029 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2030 init);
2031}
2032
2033void TransposeOp::build(::mlir::OpBuilder &builder,
2034 ::mlir::OperationState &result, Value input, Value init,
2035 ArrayRef<int64_t> permutation,
2036 ArrayRef<NamedAttribute> attributes) {
2037 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2038 attributes);
2039}
2040
2041ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2043 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2044 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2045 })))
2046 return failure();
2047
2048 OpBuilder builder(parser.getContext());
2049 buildIdentityRegion(builder, result.location, *result.addRegion(),
2050 /*inputs=*/result.operands,
2051 /*outputs=*/{});
2052 return success();
2053}
2054
2055void TransposeOp::getAsmResultNames(
2056 function_ref<void(Value, StringRef)> setNameFn) {
2057 if (!getResults().empty())
2058 setNameFn(getResults().front(), "transposed");
2059}
2060
2061void TransposeOp::print(OpAsmPrinter &p) {
2062 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2063 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2064 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2065}
2066
2067LogicalResult TransposeOp::verify() {
2068 ArrayRef<int64_t> permutationRef = getPermutation();
2069
2070 if (!isPermutationVector(permutationRef))
2071 return emitOpError("permutation is not valid");
2072
2073 auto inputType = getInput().getType();
2074 auto initType = getInit().getType();
2075
2076 int64_t rank = inputType.getRank();
2077
2078 if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
2079 "init")))
2080 return failure();
2081
2082 if (rank != static_cast<int64_t>(permutationRef.size()))
2083 return emitOpError() << "size of permutation " << permutationRef.size()
2084 << " does not match the argument rank " << rank;
2085
2086 auto inputDims = inputType.getShape();
2087 auto initDims = initType.getShape();
2088
2089 for (int64_t i = 0; i < rank; ++i) {
2090 int64_t inputDim = inputDims[permutationRef[i]];
2091 int64_t initDim = initDims[i];
2092
2093 if (inputDim != initDim) {
2094 return emitOpError() << "dim(result, " << i << ") = " << initDim
2095 << " doesn't match dim(input, permutation[" << i
2096 << "]) = " << inputDim;
2097 }
2098 }
2099
2100 return success();
2101}
2102
2103SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2104 int64_t rank = getInit().getType().getRank();
2105 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2106}
2107
2108ArrayAttr TransposeOp::getIndexingMaps() {
2109 Builder builder(getContext());
2110 int64_t rank = getInit().getType().getRank();
2111 return builder.getAffineMapArrayAttr(
2113 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2114 builder.getMultiDimIdentityMap(rank)});
2115}
2116
2117void TransposeOp::getEffects(
2118 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2119 &effects) {
2120 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2121}
2122
2123Speculation::Speculatability TransposeOp::getSpeculatability() {
2124 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2125}
2126
2127LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2128 SmallVectorImpl<OpFoldResult> &result) {
2129 // Only the tensor type is supported.
2130 if (!isa<TensorType>(getInput().getType()))
2131 return failure();
2132
2133 // Single dimension transpose.
2134 if (getPermutation().empty()) {
2135 result.push_back(getInput());
2136 return success();
2137 }
2138 // Identity permutation.
2139 if (isIdentityPermutation(getPermutation())) {
2140 result.push_back(getInput());
2141 return success();
2142 }
2143
2144 return failure();
2145}
2146
2147/// Fold transpose with transpose.
2148struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2149 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2150
2151 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2152 PatternRewriter &rewriter) const override {
2153 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2154 if (!defTransposeOp)
2155 return failure();
2156 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2157 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2158 SmallVector<int64_t> foldedPerms;
2159 foldedPerms.reserve(perms.size());
2160 for (int64_t perm : perms)
2161 foldedPerms.push_back(defPerms[perm]);
2162
2163 rewriter.replaceOpWithNewOp<TransposeOp>(
2164 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2165 foldedPerms);
2166 return success();
2167 }
2168};
2169
2170/// Rewrite a transpose of a dense splat constant into a dense splat constant of
2171/// the transposed output shape.
2172struct FoldTransposeSplatConstant : OpRewritePattern<linalg::TransposeOp> {
2173 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2174
2175 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2176 PatternRewriter &rewriter) const override {
2177 if (!transposeOp.hasPureTensorSemantics())
2178 return failure();
2179
2180 auto splatValue =
2181 getScalarConstantAttrFromDenseSplat(transposeOp.getInput());
2182 if (!splatValue.has_value())
2183 return failure();
2184
2185 auto resultType =
2186 cast<RankedTensorType>(transposeOp.getResult()[0].getType());
2187
2188 auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
2189 rewriter.replaceOpWithNewOp<arith::ConstantOp>(transposeOp, resultType,
2190 resultAttr);
2191 return success();
2192 }
2193};
2194
2195/// This pattern canonicalize transpose by swapping the order of
2196/// broadcast and transpose:
2197/// transpose(broadcast(input)) -> broadcast(transpose(input))
2198struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2199 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2200
2201 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2202 PatternRewriter &rewriter) const override {
2203 Value input = transposeOp.getInput();
2204 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2205 if (!input.hasOneUse() || !broadcastOp)
2206 return failure();
2207
2208 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2209 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2210
2211 // Get new perms and new dimensions.
2212 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2214 SmallVector<int64_t> resultDimensions;
2215 unsigned dimensionSize = dimensions.size();
2216 for (unsigned i = 0; i < dimensionSize; ++i)
2217 resultDimensions.push_back(invertPerm[dimensions[i]]);
2218
2219 // Create transpose result.
2220 Value broadcastInput = broadcastOp.getInput();
2221 Location loc = transposeOp.getLoc();
2222 MLIRContext *ctx = transposeOp.getContext();
2224 auto broadcastInputTy =
2225 mlir::cast<RankedTensorType>(broadcastInput.getType());
2226 unsigned inputRank = broadcastInputTy.getRank();
2227 for (unsigned i = 0; i < inputRank; ++i) {
2228 if (broadcastInputTy.isDynamicDim(i)) {
2229 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2230 ->getResult(0));
2231 } else {
2232 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2233 broadcastInputTy.getDimSize(i)));
2234 }
2235 }
2236 SmallVector<OpFoldResult> transposeResultShapes =
2237 applyPermutation(dims, resultPerms);
2238 Value transposeInit = tensor::EmptyOp::create(
2239 rewriter, transposeOp.getLoc(), transposeResultShapes,
2240 broadcastInputTy.getElementType());
2241
2242 // Create broadcast(transpose(input)).
2243 Value transposeResult =
2244 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2245 transposeInit, resultPerms)
2246 ->getResult(0);
2247 rewriter.replaceOpWithNewOp<BroadcastOp>(
2248 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2249 return success();
2250 }
2251};
2252
2253void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2254 MLIRContext *context) {
2255 results.add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
2256 SwapTransposeWithBroadcast>(context);
2257}
2258
2259//===----------------------------------------------------------------------===//
2260// BroadcastOp
2261//===----------------------------------------------------------------------===//
2262
2263void BroadcastOp::build(::mlir::OpBuilder &builder,
2264 ::mlir::OperationState &result, Value input, Value init,
2265 DenseI64ArrayAttr dimensions,
2266 ArrayRef<NamedAttribute> attributes) {
2267 result.addOperands(input);
2268 result.addOperands(init);
2269 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2270 result.addAttributes(attributes);
2271
2272 // Add output types for `RankedTensorType` output arguments.
2273 Type initType = init.getType();
2274 if (llvm::isa<RankedTensorType>(initType))
2275 result.addTypes(initType);
2276
2277 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2278 init);
2279}
2280
2281void BroadcastOp::build(::mlir::OpBuilder &builder,
2282 ::mlir::OperationState &result, Value input, Value init,
2283 ArrayRef<int64_t> dimensions,
2284 ArrayRef<NamedAttribute> attributes) {
2285 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2286 attributes);
2287}
2288
2289ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2291 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2292 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2293 })))
2294 return failure();
2295
2296 OpBuilder builder(parser.getContext());
2297 buildIdentityRegion(builder, result.location, *result.addRegion(),
2298 /*inputs=*/result.operands,
2299 /*outputs=*/{});
2300 return success();
2301}
2302
2303void BroadcastOp::getAsmResultNames(
2304 function_ref<void(Value, StringRef)> setNameFn) {
2305 if (!getResults().empty())
2306 setNameFn(getResults().front(), "broadcasted");
2307}
2308
2309void BroadcastOp::print(OpAsmPrinter &p) {
2310 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2311 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2312 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2313}
2314
2315LogicalResult BroadcastOp::verify() {
2316 ArrayRef<int64_t> dimensionsRef = getDimensions();
2317
2318 auto inputType = getInput().getType();
2319 auto initType = getInit().getType();
2320
2321 int64_t inputRank = inputType.getRank();
2322 int64_t initRank = initType.getRank();
2323
2324 auto inputShape = inputType.getShape();
2325 auto initShape = initType.getShape();
2326
2327 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2328 return emitOpError() << "input rank plus added dimensions does not "
2329 "match init rank. input rank: "
2330 << inputRank
2331 << ", dimensions size: " << dimensionsRef.size()
2332 << ", init rank: " << initRank;
2333
2334 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2335 if (dim < 0 || dim >= initRank)
2336 return emitOpError() << "dimension " << idx
2337 << " is out of range. expected range: [0, "
2338 << initRank - 1 << "], got: " << dim;
2339 }
2340
2341 // Mapping from input dims to init dims.
2342 SmallVector<int64_t> dimMap;
2343 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2344 if (!llvm::is_contained(dimensionsRef, dim))
2345 dimMap.push_back(dim);
2346 }
2347
2348 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2349 // This dimensions is mapped from the input. Init and input dims should
2350 // match.
2351 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2352 return emitOpError() << "input dim " << inputDimIdx
2353 << " should match init dim " << initDimIdx
2354 << ". input: " << inputShape[inputDimIdx]
2355 << ", init: " << initShape[initDimIdx];
2356 }
2357
2358 return success();
2359}
2360
2361SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2362 int64_t rank = getInit().getType().getRank();
2363 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2364}
2365
2366ArrayAttr BroadcastOp::getIndexingMaps() {
2367 Builder builder(getContext());
2368 int64_t rank = getInit().getType().getRank();
2369 return builder.getAffineMapArrayAttr(
2370 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2371 builder.getMultiDimIdentityMap(rank)});
2372}
2373
2374void BroadcastOp::getEffects(
2375 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2376 &effects) {
2377 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2378}
2379
2380Speculation::Speculatability BroadcastOp::getSpeculatability() {
2381 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2382}
2383
2384/// Fold back-to-back broadcasts together.
2385struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2386 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2387
2388 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2389 PatternRewriter &rewriter) const override {
2390 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2391 if (!defBroadcastOp)
2392 return failure();
2393 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2394 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2395 SmallVector<int64_t> foldedDims(dimensions);
2396 Value init = broadcastOp.getInit();
2397 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2398 // Mapping from input dims to init dims.
2399 SmallVector<int64_t> dimMap;
2400 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2401 if (!llvm::is_contained(dimensions, dim))
2402 dimMap.push_back(dim);
2403 }
2404 for (auto dim : defDimensions)
2405 foldedDims.push_back(dimMap[dim]);
2406
2407 llvm::sort(foldedDims);
2408 rewriter.replaceOpWithNewOp<BroadcastOp>(
2409 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2410 return success();
2411 }
2412};
2413
2414/// Rewrite a broadcast of a dense splat constant into a dense splat constant of
2415/// the broadcast output shape.
2416struct FoldBroadcastSplatConstant : OpRewritePattern<linalg::BroadcastOp> {
2417 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2418
2419 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2420 PatternRewriter &rewriter) const override {
2421 if (!broadcastOp.hasPureTensorSemantics())
2422 return failure();
2423
2424 auto splatValue =
2425 getScalarConstantAttrFromDenseSplat(broadcastOp.getInput());
2426
2427 if (!splatValue.has_value())
2428 return failure();
2429
2430 auto resultType =
2431 cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
2432 if (!resultType.hasStaticShape())
2433 return rewriter.notifyMatchFailure(broadcastOp,
2434 "result type has dynamic shape");
2435
2436 auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
2437 rewriter.replaceOpWithNewOp<arith::ConstantOp>(broadcastOp, resultType,
2438 resultAttr);
2439 return success();
2440 }
2441};
2442
2443void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2444 MLIRContext *context) {
2445 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
2446 FoldBroadcastSplatConstant>(context);
2447}
2448
2449//===----------------------------------------------------------------------===//
2450// YieldOp
2451//===----------------------------------------------------------------------===//
2452
2453void linalg::YieldOp::print(OpAsmPrinter &p) {
2454 if (getNumOperands() > 0)
2455 p << ' ' << getOperands();
2456 p.printOptionalAttrDict((*this)->getAttrs());
2457 if (getNumOperands() > 0)
2458 p << " : " << getOperandTypes();
2459}
2460
2461ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2462 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2463 SmallVector<Type, 2> types;
2464 SMLoc loc = parser.getCurrentLocation();
2465 return failure(parser.parseOperandList(opInfo) ||
2466 parser.parseOptionalAttrDict(result.attributes) ||
2467 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2468 parser.resolveOperands(opInfo, types, loc, result.operands));
2469}
2470
2471// Check the operand number and types must match the element types of the
2472// LinalgOp interface's shaped operands.
2473static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2474 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2475 return op.emitOpError("expected number of yield values (")
2476 << op.getNumOperands()
2477 << ") to match the number of inits / outs operands of the enclosing "
2478 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2479
2480 for (OpOperand &opOperand : op->getOpOperands()) {
2481 OpOperand *outputOperand =
2482 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2483 Type elementType = outputOperand->get().getType();
2484 if (isa<MemRefType, RankedTensorType>(elementType))
2485 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2486 if (opOperand.get().getType() != elementType)
2487 return op.emitOpError("type of yield operand ")
2488 << (opOperand.getOperandNumber() + 1) << " ("
2489 << opOperand.get().getType() << ") doesn't match "
2490 << "the element type of the enclosing linalg.generic op ("
2491 << elementType << ")";
2492 }
2493 return success();
2494}
2495
2496LogicalResult linalg::YieldOp::verify() {
2497 auto *parentOp = (*this)->getParentOp();
2498 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2499 return emitOpError("expected single non-empty parent region");
2500
2501 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2502 return verifyYield(*this, linalgOp);
2503
2504 return emitOpError("expected parent op with LinalgOp interface");
2505}
2506
2507//===----------------------------------------------------------------------===//
2508// IndexOp
2509//===----------------------------------------------------------------------===//
2510
2511LogicalResult IndexOp::verify() {
2512 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2513 if (!linalgOp)
2514 return emitOpError("expected parent op with LinalgOp interface");
2515 if (linalgOp.getNumLoops() <= getDim())
2516 return emitOpError("expected dim (")
2517 << getDim() << ") to be lower than the number of loops ("
2518 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2519 return success();
2520}
2521
2522OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2523 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2524 // Bail out if `linalg.index` does not have a proper parent yet at this
2525 // point, e.g., when calling `createOrFold` during IR construction in
2526 // `genericOp::build`.
2527 if (!linalgOp)
2528 return OpFoldResult{};
2529
2530 // Index of unit dims is always 0.
2531 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2532 uint64_t dim = getDim();
2533 assert(dim < loopBounds.size() && "Dim is out of bounds");
2534 if (loopBounds[dim] == 1)
2535 return IntegerAttr::get(IndexType::get(getContext()), 0);
2536
2537 return OpFoldResult{};
2538}
2539
2540/////// Operations corresponding to library calls defined with Tablegen ////////
2541
2542#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2543
2544#define GET_OP_CLASSES
2545#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2546
2547#define GET_OP_CLASSES
2548#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2549#define GET_OP_CLASSES
2550#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2551
2552AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2553 unsigned rank,
2554 MLIRContext *context) {
2555 if (maybeMap)
2556 return *maybeMap;
2557 if (rank == 0)
2558 return AffineMap::get(context);
2559 return AffineMap::getMultiDimIdentityMap(rank, context);
2560}
2561
2563mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2564 MLIRContext *context) {
2566 res.reserve(num);
2567 for (unsigned i = 0; i < num; ++i)
2568 res.push_back(getAffineDimExpr(startIdx++, context));
2569 return res;
2570}
2571
2574 auto rangeA = llvm::make_range(a.begin(), a.end());
2575 auto rangeB = llvm::make_range(b.begin(), b.end());
2576 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2577 return llvm::to_vector<4>(concatRanges);
2578}
2579
2580static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2581 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2582 ss << "view";
2583 for (auto size : memref.getShape())
2584 if (size < 0)
2585 ss << "sx";
2586 else
2587 ss << size << "x";
2588 if (failed(appendMangledType(ss, memref.getElementType())))
2589 return failure();
2590 if (auto as = memref.getMemorySpace()) {
2591 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2592 ss << "as" << attr.getInt();
2593 else
2594 return failure();
2595 }
2596 return success();
2597 }
2598 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2599 ss << "vector";
2600 llvm::interleave(
2601 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2602 if (failed(appendMangledType(ss, vec.getElementType())))
2603 return failure();
2604 return success();
2605 }
2607 ss << t;
2608 return success();
2609 }
2610 return failure();
2611}
2612
2614 assert(isa<LinalgOp>(op));
2615 std::string name(op->getName().getStringRef().str());
2616 std::string fun = "";
2617 for (NamedAttribute kv : op->getAttrs()) {
2618 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2619 fun = stringifyEnum(ufa.getValue()).str() + "_";
2620 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2621 fun = stringifyEnum(bfa.getValue()).str() + "_";
2622 }
2623 }
2624 name.reserve(128);
2625 llvm::replace(name, '.', '_');
2626 llvm::raw_string_ostream ss(name);
2627 ss << "_" << fun;
2628 for (Type t : op->getOperandTypes()) {
2629 if (failed(appendMangledType(ss, t)))
2630 return std::string();
2631 ss << "_";
2632 }
2633 name.pop_back();
2634 return name;
2635}
2636
2637//===----------------------------------------------------------------------===//
2638// Canonicalizers and Folders.
2639//===----------------------------------------------------------------------===//
2640
2641namespace {
2642struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2644
2645 LogicalResult matchAndRewrite(LinalgOp op,
2646 PatternRewriter &rewriter) const override {
2647 for (OpOperand &opOperand : op->getOpOperands()) {
2648 // Linalg "inputs" may be either tensor or memref type.
2649 // tensor<0xelt_type> is a convention that may not always mean
2650 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2651 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2652 if (!mt)
2653 continue;
2654 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2655 rewriter.eraseOp(op);
2656 return success();
2657 }
2658 }
2659 return failure();
2660 }
2661};
2662
2663/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2664/// result that is more static than the linalg op.
2665struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2666 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2667
2668 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2669 PatternRewriter &rewriter) const override {
2670 if (!tensor::canFoldIntoProducerOp(castOp))
2671 return failure();
2672
2673 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2674 if (!linalgOp)
2675 return failure();
2676
2677 // Cast can be in conditionally reachable region, if which case folding will
2678 // generate invalid code. Only conservatively fold ops in same block for
2679 // now.
2680 if (castOp->getBlock() != linalgOp->getBlock())
2681 return failure();
2682
2683 OpBuilder::InsertionGuard guard(rewriter);
2684 rewriter.setInsertionPoint(linalgOp);
2685
2686 Location loc = linalgOp.getLoc();
2687 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2688 unsigned resultNumber = resultValue.getResultNumber();
2689 auto resultType =
2690 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2691 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2692 // going from a more dynamic shape to a less dynamic shape. If the producer
2693 // for this cast, i.e. producer of the out operand, is also an operation
2694 // that folds with tensor.cast consumer (like this pattern), the cast will
2695 // continue to propagate as far up the stack as it can go.
2696 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2697 Value newOperand =
2698 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2699 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2700 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2701 linalgOp.getDpsInits().end());
2702 outputOperands[resultNumber] = newOperand;
2703 newOperands.append(outputOperands.begin(), outputOperands.end());
2704
2705 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2706 linalgOp->result_type_end());
2707 resultTypes[resultNumber] = resultType;
2708 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2709
2710 // Create a tensor.cast operation back to the original type.
2711 Value castBack = tensor::CastOp::create(
2712 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2713
2714 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2715 results[resultNumber] = castBack;
2716 rewriter.replaceOp(linalgOp, results);
2717 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2718 return success();
2719 }
2720};
2721
2722/// For each of the operand in `operands` this function maps the static sizes of
2723/// dimensions to their affine dim expressions.
2724static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2725 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2726 for (OpOperand &opOperand : operands) {
2727 if (linalgOp.isScalar(&opOperand))
2728 continue;
2729 Value src = opOperand.get();
2730 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2731 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2732
2733 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2734 // `tensor.cast` operation and source of the cast operation has a static
2735 // shape, then assign it to the `sourceShape`.
2736 auto *parentOp = src.getDefiningOp();
2737 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2738 if (parentOp) {
2739 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2740 Value castSource = castOp.getSource();
2741 auto castSourceType =
2742 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2743 if (castSourceType && castSourceType.hasStaticShape())
2744 sourceShape = castSourceType.getShape();
2745 }
2746 }
2747
2748 // If the source shape's dimension has a static shape, map the affine dim
2749 // expression to the known static size.
2750 for (unsigned i = 0; i < sourceShape.size(); i++) {
2751 if (sourceType.isDynamicDim(i))
2752 continue;
2753 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2754 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2755 }
2756 }
2757}
2758
2759/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2760/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2761/// their result types is stored in `resultTypes`. If `opOperand` requires no
2762/// change then `changeNeeded` is false and same operand is added in the
2763/// `newOperands` list.
2764static void createNewOperandWithStaticSizes(
2765 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2766 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2767 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2768 bool &changeNeeded) {
2769 Value src = opOperand->get();
2770 newOperands.push_back(src);
2771 if (linalgOp.isScalar(opOperand))
2772 return;
2773 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2774 Type resultType = sourceType;
2775 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2776 resultTypes.push_back(resultType);
2777 return;
2778 }
2779 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2780 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2781 SmallVector<int64_t> newShape;
2782 // If operand is updated with new shape, `newOperandNeeded` will be
2783 // true.
2784 bool newOperandNeeded = false;
2785 for (unsigned i = 0; i < sourceShape.size(); i++) {
2786 int64_t dimShape = sourceShape[i];
2787 AffineExpr dimExpr = sourceMap.getResult(i);
2788 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2789 newShape.push_back(dimShape);
2790 continue;
2791 }
2792 // Dimension has a dynamic shape and corresponding affine dim
2793 // expression is present in the map. So assign the size for the
2794 // given affine dim expression to the dimension.
2795 newShape.push_back(affineExprToSize[dimExpr]);
2796 newOperandNeeded = true;
2797 }
2798 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2799 sourceType.getEncoding());
2800 if (newOperandNeeded) {
2801 changeNeeded = true;
2802 // Get the new operand value given its size and element type by
2803 // casting it.
2804 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2805 unsigned index = opOperand->getOperandNumber();
2806 newOperands[index] = newOperand;
2807 }
2808 if (linalgOp.isDpsInit(opOperand))
2809 resultTypes.push_back(resultType);
2810}
2811
2812/// Static shapes for the operands can be inferred if any one of the operands
2813/// have a static shape. This can be done by referring to the affine dim
2814/// expressions for the operand.
2815struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2816 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2817
2818 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2819 PatternRewriter &rewriter) const override {
2820 if (!linalgOp.hasPureTensorSemantics())
2821 return failure();
2822
2823 // Maps must be projected permutations.
2824 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2825 return !map.isProjectedPermutation();
2826 }))
2827 return failure();
2828
2829 // Maps affine dim expressions to the static size of that dimension.
2830 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2831 Location loc = linalgOp.getLoc();
2832
2833 // For each of the affine dim expression, check if the size is known. If
2834 // known add that in the map.
2835 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2836
2837 SmallVector<Value> newOperands;
2838 SmallVector<Type> resultTypes;
2839
2840 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2841 // change in their types.
2842 bool changeNeeded = false;
2843 newOperands.reserve(linalgOp->getNumOperands());
2844 resultTypes.reserve(linalgOp.getNumDpsInits());
2845
2846 // Iterate over all the operands and update the static sizes.
2847 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2848 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2849 affineExprToSize, linalgOp, newOperands,
2850 resultTypes, changeNeeded);
2851 }
2852
2853 // If the generic op has all the required static information, no
2854 // canonicalization needed.
2855 if (!changeNeeded)
2856 return failure();
2857
2858 // Clone op.
2859 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2860 SmallVector<Value> replacements;
2861 replacements.reserve(newOp->getNumResults());
2862 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2863 Value newResult = std::get<1>(it);
2864 Value oldResult = std::get<0>(it);
2865 Type newType = newResult.getType();
2866 Type oldType = oldResult.getType();
2867 replacements.push_back(
2868 (newType != oldType)
2869 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2870 : newResult);
2871 }
2872 rewriter.replaceOp(linalgOp, replacements);
2873 return success();
2874 }
2875};
2876
2877} // namespace
2878
2879// All named ops canonicalizers and folders are auto-generated in the
2880// .cpp.inc.
2881
2882//===----------------------------------------------------------------------===//
2883// SoftmaxOp
2884//===----------------------------------------------------------------------===//
2885
2886LogicalResult SoftmaxOp::verify() {
2887 ShapedType inputType = getInputOperandType();
2888 ShapedType outputType = getOutputOperandType();
2889
2890 ArrayRef<int64_t> inputShape = inputType.getShape();
2891 ArrayRef<int64_t> outputShape = outputType.getShape();
2892 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2893 return emitOpError("incompatible output shape");
2894
2895 int64_t inputRank = getInputOperandRank();
2896 int64_t dimension = getDimension();
2897 if ((dimension < 0) || (dimension >= inputRank))
2898 return emitOpError("incorrect dimension specified");
2899
2900 return success();
2901}
2902
2903SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2904 int64_t operandRank = getInputOperandRank();
2905 SmallVector<Range> loopBounds(operandRank);
2906 Location loc = getLoc();
2907 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2908 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2909 Value source = getInput();
2910 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2911 loopBounds[dim].offset = zero;
2912 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2913 loopBounds[dim].stride = one;
2914 }
2915 return loopBounds;
2916}
2917
2918SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2919 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2920 utils::IteratorType::parallel);
2921 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2922 return iteratorTypes;
2923}
2924
2925FailureOr<TilingResult>
2926SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2927 ArrayRef<OpFoldResult> offsets,
2928 ArrayRef<OpFoldResult> sizes) {
2929 int64_t rank = getInputOperandRank();
2930 auto oneAttr = builder.getI64IntegerAttr(1);
2931 SmallVector<OpFoldResult> strides(rank, oneAttr);
2932 SmallVector<Value> tiledOperands;
2933 Operation *inputSlice =
2934 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2935 if (!inputSlice) {
2936 return emitOpError("failed to compute input slice");
2937 }
2938 tiledOperands.emplace_back(inputSlice->getResult(0));
2939 Operation *outputSlice =
2940 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2941 if (!outputSlice) {
2942 return emitOpError("failed to compute output slice");
2943 }
2944 tiledOperands.emplace_back(outputSlice->getResult(0));
2945
2946 SmallVector<Type, 4> resultTypes;
2947 if (hasPureTensorSemantics())
2948 resultTypes.push_back(tiledOperands[1].getType());
2949 Operation *tiledOp =
2950 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2951
2952 return TilingResult{
2953 {tiledOp},
2954 SmallVector<Value>(tiledOp->getResults()),
2955 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2956}
2957
2958LogicalResult SoftmaxOp::getResultTilePosition(
2959 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2960 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2961 SmallVector<OpFoldResult> &resultSizes) {
2962 if (resultNumber == 0) {
2963 resultOffsets.assign(offsets.begin(), offsets.end());
2964 resultSizes.assign(sizes.begin(), sizes.end());
2965 return success();
2966 }
2967 return failure();
2968}
2969
2970// cast(dynamic) -> static.
2971LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2972 return memref::foldMemRefCast(*this);
2973}
2974
2975LogicalResult
2976SoftmaxOp::reifyResultShapes(OpBuilder &b,
2977 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2978 SmallVector<OpFoldResult> shapes;
2979 Location loc = getOperation()->getLoc();
2980 IRRewriter rewriter(b);
2981 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2982 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2983 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2984 if (!outputShapedType.isDynamicDim(dim)) {
2985 // Static dim: Return IntegerAttr.
2986 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2987 } else {
2988 // Dynamic dim: Return Value.
2989 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2990 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2991 }
2992 }
2993 reifiedReturnShapes.emplace_back(std::move(shapes));
2994 return success();
2995}
2996
2997void SoftmaxOp::getEffects(
2998 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2999 &effects) {
3000 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
3001 if (!llvm::isa<MemRefType>(operand.getType()))
3002 continue;
3003 effects.emplace_back(MemoryEffects::Read::get(),
3004 &getOperation()->getOpOperand(index), /*stage=*/0,
3005 /*effectOnFullRegion=*/true,
3007 }
3008
3009 for (OpOperand &operand : getDpsInitsMutable()) {
3010 if (!llvm::isa<MemRefType>(operand.get().getType()))
3011 continue;
3012 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
3013 /*effectOnFullRegion=*/true,
3015 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
3016 /*effectOnFullRegion=*/true,
3018 }
3019}
3020
3021// Helper functions for softmax decomposition.
3022// @{
3023
3024// Helper function to produce the iterator types (reduction or parallel) and
3025// affine maps for the iterators used in the decomposition of softmax.
3026// This method creates:
3027// If allParallel == true:
3028// - iterator type: {parallel, ..., parallel}
3029// - affine maps:
3030// -- identity with inputRank dimensions.
3031// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
3032// where N == inputRank.
3033//
3034// If allParallel == false:
3035// - iterator type at dim(i) == parallel for i != \p dim and
3036// dim(dim) == reduction.
3037// - affine map:
3038// -- identity with inputRank dimensions.
3039// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
3040// where N == inputRank.
3041static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
3043 int64_t dim, bool allParallel = false) {
3044 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
3045 utils::IteratorType::parallel);
3046 if (!allParallel)
3047 iteratorTypes[dim] = utils::IteratorType::reduction;
3048 MLIRContext *ctxt = builder.getContext();
3049 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
3050 SmallVector<AffineExpr, 2> affineExprs;
3051 for (int i = 0; i < inputRank; i++) {
3052 if (i != dim)
3053 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
3054 }
3055 auto reductionMap =
3056 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
3057 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
3058 return std::make_tuple(iteratorTypes, indexingMaps);
3059}
3060
3061// Helper function to produce a linalg.generic that computes a reduction on
3062// dimension \p dim with the operation type \p T.
3063template <typename T>
3064static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
3065 int64_t dim) {
3066 auto inputType = cast<ShapedType>(input.getType());
3067 ArrayRef<int64_t> inputShape = inputType.getShape();
3068 int64_t inputRank = inputShape.size();
3069 auto [iteratorTypes, indexingMaps] =
3070 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
3071 assert(indexingMaps.size() == 2 &&
3072 "We should have two maps: 1 for the input, 1 for the output");
3073 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3074
3075 auto genericOp = linalg::GenericOp::create(
3076 builder, loc, output.getType(), input, output, indexingMaps,
3077 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
3078 Value result = T::create(b, loc, args[0], args[1]);
3079 linalg::YieldOp::create(b, loc, result);
3080 });
3081 return genericOp.getResult(0);
3082}
3083
3084/// Produce a linalg generic that computes the second step of the softmax
3085/// decomposition: res = exp(input - max), where \p max is the max of \p input
3086/// on dimension \p dim.
3087static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3088 Value max, Value output, int64_t dim) {
3089 auto inputType = cast<ShapedType>(input.getType());
3090 ArrayRef<int64_t> inputShape = inputType.getShape();
3091 int64_t inputRank = inputShape.size();
3092 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3093 builder, inputRank, dim, /*allParallel=*/true);
3094 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3095 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3096 // Add the affine map for the output argument.
3097 indexingMaps.push_back(indexingMaps[0]);
3098 auto genericOp = linalg::GenericOp::create(
3099 builder, loc, input.getType(), ValueRange{input, max}, output,
3100 indexingMaps, iteratorTypes,
3101 [&](OpBuilder &b, Location loc, ValueRange args) {
3102 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3103 Value result = math::ExpOp::create(b, loc, diff);
3104 linalg::YieldOp::create(b, loc, result);
3105 });
3106 return genericOp.getResult(0);
3107}
3108
3109/// Produce a linalg generic that computes the final step of the softmax
3110/// decomposition.
3111/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3112/// yield n / d
3113/// }
3114static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3115 Value denominator, Value output, int64_t dim) {
3116 auto inputType = cast<ShapedType>(numerator.getType());
3117 ArrayRef<int64_t> inputShape = inputType.getShape();
3118 int64_t inputRank = inputShape.size();
3119 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3120 builder, inputRank, dim, /*allParallel=*/true);
3121 assert(indexingMaps.size() == 2 &&
3122 "We should have one map for each input (2)");
3123 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3124 // Add the affine map for the output tensor.
3125 indexingMaps.push_back(indexingMaps[0]);
3126 auto genericOp = linalg::GenericOp::create(
3127 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3128 output, indexingMaps, iteratorTypes,
3129 [&](OpBuilder &b, Location loc, ValueRange args) {
3130 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3131 linalg::YieldOp::create(b, loc, result);
3132 });
3133 return genericOp.getResult(0);
3134}
3135// @} End helper functions for softmax decomposition.
3136
3137/// Given an N-dimensional tensor x, this method converts
3138/// softmax(x) to the following sequence of operations:
3139///
3140/// 1. Compute the max of x along dimension d. This results
3141/// in a N-1 dimensional tensor m.
3142/// m = max(x, dim = d)
3143///
3144/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3145/// a N dimensional tensor z.
3146/// z = exp(x - m)
3147///
3148/// 3. Compute the sum of z along dimension d. This results in
3149/// a N-1 dimensional tensor l.
3150/// l = sum(z, dim = d)
3151///
3152/// 4. Divide z and l. This gives the N-dimensional softmax.
3153/// softmax = z / l
3154///
3155FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3156 OpBuilder::InsertionGuard guard(b);
3157 b.setInsertionPoint(*this);
3158 Location loc = getLoc();
3159 Value input = getInput();
3160 ShapedType inputType = getInputOperandType();
3161 Type elementType = inputType.getElementType();
3162 int64_t reductionDim = getDimension();
3163 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3164 Value output = getOutput();
3165 dims.erase(dims.begin() + reductionDim);
3166 // Step 1: Compute max along dim.
3167 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3168 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3169 elementType, b, loc,
3170 /*useOnlyFiniteValue=*/true);
3171 Value neutralForMaxFInit =
3172 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3173 .result();
3174 Value max =
3175 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3176
3177 // Step 2: Subtract max from input and exponentiate.
3178 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3179
3180 // Step 3: Compute sum along dim.
3181 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3182 b, loc, /*useOnlyFiniteValue=*/true);
3183 Value zeroInit =
3184 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3185 Value denominator =
3186 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3187
3188 // Step 4: Compute softmax.
3189 Value result =
3190 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3191 return SmallVector<Value>{result};
3192}
3193
3194//===----------------------------------------------------------------------===//
3195// WinogradFilterTransformOp
3196//===----------------------------------------------------------------------===//
3197
3198LogicalResult WinogradFilterTransformOp::verify() {
3199 auto filterType = cast<ShapedType>(getFilter().getType());
3200 ArrayRef<int64_t> filterShape = filterType.getShape();
3201 int64_t filterH = filterShape[getFilterHDim()];
3202 int64_t filterW = filterShape[getFilterWDim()];
3203 WinogradConv2DFmr fmr = getFmr();
3204 int64_t m, r;
3205 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3206
3207 if (filterH != r && filterH != 1)
3208 return emitOpError("expect filter height either equals to r or 1");
3209 if (filterW != r && filterW != 1)
3210 return emitOpError("expect filter width either equals to r or 1");
3211 if (filterH == 1 && filterW == 1)
3212 return emitOpError("expect either filter height or width equals to r");
3213
3214 SmallVector<int64_t> expectedOutputShape;
3215 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3216 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3217 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3218 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3219
3220 auto outputType = cast<ShapedType>(getOutput().getType());
3221 ArrayRef<int64_t> outputShape = outputType.getShape();
3222 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3223 return emitOpError("the output shape is not expected");
3224 }
3225 return success();
3226}
3227
3228SmallVector<Range>
3229WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3230 Location loc = getLoc();
3231 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3232 IntegerAttr oneAttr = builder.getIndexAttr(1);
3233 Value filter = getFilter();
3234 int64_t filterRank = getFilterOperandRank();
3235 SmallVector<Range> loopBounds(filterRank);
3236 for (unsigned dim = 0; dim < filterRank; ++dim) {
3237 loopBounds[dim].offset = zeroAttr;
3238 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3239 loopBounds[dim].stride = oneAttr;
3240 }
3241 return loopBounds;
3242}
3243
3244SmallVector<utils::IteratorType>
3245WinogradFilterTransformOp::getLoopIteratorTypes() {
3246 int64_t filterRank = getFilterOperandRank();
3247 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3248 utils::IteratorType::parallel);
3249 return iteratorTypes;
3250}
3251
3252LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3253 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3254 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3255 SmallVector<OpFoldResult> &resultSizes) {
3256 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3257 ShapedType filterType = getFilterOperandType();
3258 ArrayRef<int64_t> filterShape = filterType.getShape();
3259 int64_t filterH = filterShape[getFilterHDim()];
3260 int64_t filterW = filterShape[getFilterWDim()];
3261 WinogradConv2DFmr fmr = getFmr();
3262 int64_t m, r;
3263 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3264 int64_t alpha = m + r - 1;
3265 int64_t alphaH = filterH != 1 ? alpha : 1;
3266 int64_t alphaW = filterW != 1 ? alpha : 1;
3267 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3268 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3269
3270 resultOffsets.append(
3271 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3272 resultSizes.append(
3273 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3274
3275 return success();
3276}
3277
3278/// Implement tiling for winograd_filter_transform
3279/// The input of winograd_filter_transform is (F, KH, KW, C).
3280/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
3281/// Users can specify the tile sizes of F and C.
3282/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3283/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3284FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3285 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3286 ArrayRef<OpFoldResult> sizes) {
3287 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3288 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3289 ShapedType filterType = getFilterOperandType();
3290 ArrayRef<int64_t> filterShape = filterType.getShape();
3291 int64_t filterH = filterShape[getFilterHDim()];
3292 int64_t filterW = filterShape[getFilterWDim()];
3293 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3294 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3295 SmallVector<Value> tiledOperands;
3296 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3297
3298 sliceOffsets.append(
3299 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3300 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3301 sizes[getFilterCDim()]});
3302 int64_t filterRank = getFilterOperandRank();
3303 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3304 Location loc = getLoc();
3305 auto filterSlice = tensor::ExtractSliceOp::create(
3306 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3307 tiledOperands.emplace_back(filterSlice);
3308
3309 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3310 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3311 resultSizes)))
3312 return failure();
3313
3314 int64_t outputRank = getOutputOperandRank();
3315 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3316 auto outputSlice = tensor::ExtractSliceOp::create(
3317 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3318 tiledOperands.emplace_back(outputSlice);
3319
3320 SmallVector<Type> resultTypes;
3321 resultTypes.push_back(tiledOperands[1].getType());
3322 Operation *tiledOp =
3323 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3324
3325 return TilingResult{
3326 {tiledOp},
3327 SmallVector<Value>(tiledOp->getResults()),
3328 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3329}
3330
3331//===----------------------------------------------------------------------===//
3332// WinogradInputTransformOp
3333//===----------------------------------------------------------------------===//
3334
3335LogicalResult WinogradInputTransformOp::verify() {
3336 auto inputType = cast<ShapedType>(getInput().getType());
3337 ArrayRef<int64_t> inputShape = inputType.getShape();
3338 int64_t inputH = inputShape[getInputHDim()];
3339 int64_t inputW = inputShape[getInputWDim()];
3340 WinogradConv2DFmr fmr = getFmr();
3341 int64_t m, r;
3342 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3343 int64_t tileSize = m + r - 1;
3344
3345 auto outputType = cast<ShapedType>(getOutput().getType());
3346 ArrayRef<int64_t> outputShape = outputType.getShape();
3347 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3348 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3349
3350 SmallVector<int64_t> expectedOutputShape(6, inputH);
3351 if (ShapedType::isDynamic(inputH)) {
3352 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3353 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3354 } else {
3355 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3356 expectedOutputShape[getOutputTileHDim()] =
3357 leftTransform ? (inputH - (r - 1)) / m : inputH;
3358 }
3359 if (ShapedType::isDynamic(inputW)) {
3360 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3361 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3362 } else {
3363 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3364 expectedOutputShape[getOutputTileWDim()] =
3365 rightTransform ? (inputW - (r - 1)) / m : inputW;
3366 }
3367 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3368 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3369
3370 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3371 return emitOpError("the output shape is not expected");
3372 }
3373 return success();
3374}
3375
3376SmallVector<Range>
3377WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3378 Location loc = getLoc();
3379 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3380 IntegerAttr oneAttr = builder.getIndexAttr(1);
3381 Value output = getOutput();
3382 int64_t outputRank = getOutputOperandRank();
3383 SmallVector<Range> loopBounds(outputRank);
3384 for (unsigned dim = 0; dim < outputRank; ++dim) {
3385 loopBounds[dim].offset = zeroAttr;
3386 // alphaH, alphaW, tileH, tileW, N, C
3387 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3388 loopBounds[dim].stride = oneAttr;
3389 }
3390 return loopBounds;
3391}
3392
3393SmallVector<utils::IteratorType>
3394WinogradInputTransformOp::getLoopIteratorTypes() {
3395 int64_t outputRank = getOutputOperandRank();
3396 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3397 utils::IteratorType::parallel);
3398 return iteratorTypes;
3399}
3400
3401LogicalResult WinogradInputTransformOp::getResultTilePosition(
3402 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3403 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3404 SmallVector<OpFoldResult> &resultSizes) {
3405 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3406 ShapedType outputType = getOutputOperandType();
3407 ArrayRef<int64_t> outputShape = outputType.getShape();
3408 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3409 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3410
3411 WinogradConv2DFmr fmr = getFmr();
3412 int64_t m, r;
3413 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3414 int64_t alpha = m + r - 1;
3415 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3416 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3417
3418 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3419 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3420
3421 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3422 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3423 offsets[getOutputCDim()]});
3424 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3425 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3426 sizes[getOutputCDim()]});
3427
3428 return success();
3429}
3430
3431/// Implement tiling for winograd_input_transform
3432/// The input of winograd_input_transform is (N, H, W, C).
3433/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3434/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3435/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3436/// the values for the sizes of tileH, tileW, N, C for one tile.
3437FailureOr<TilingResult>
3438WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3439 ArrayRef<OpFoldResult> offsets,
3440 ArrayRef<OpFoldResult> sizes) {
3441 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3442 WinogradConv2DFmr fmr = getFmr();
3443 int64_t m, r;
3444 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3445
3446 ShapedType outputType = getOutputOperandType();
3447 ArrayRef<int64_t> outputShape = outputType.getShape();
3448 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3449 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3450
3451 Location loc = getLoc();
3452 MLIRContext *context = builder.getContext();
3453 auto identityAffineMap =
3454 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3455 auto offsetAffineMap =
3456 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3457 Value mappedOffsetH = affine::makeComposedAffineApply(
3458 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3459 offsets[getOutputTileHDim()]);
3460 Value mappedOffsetW = affine::makeComposedAffineApply(
3461 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3462 offsets[getOutputTileWDim()]);
3463 auto sizeAffineMap = AffineMap::get(
3464 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3465 Value mappedSizeH = affine::makeComposedAffineApply(
3466 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3467 Value mappedSizeW = affine::makeComposedAffineApply(
3468 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3469
3470 SmallVector<Value> tiledOperands;
3471 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3472
3473 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3474 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3475 sliceOffsets.append(
3476 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3477 OpFoldResult sizeH =
3478 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3479 OpFoldResult sizeW =
3480 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3481 sliceSizes.append(
3482 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3483 int64_t inputRank = getInputOperandRank();
3484 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3485 auto inputSlice = tensor::ExtractSliceOp::create(
3486 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3487 tiledOperands.emplace_back(inputSlice);
3488
3489 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3490 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3491 resultSizes)))
3492 return failure();
3493
3494 int64_t outputRank = getOutputOperandRank();
3495 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3496 auto outputSlice = tensor::ExtractSliceOp::create(
3497 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3498 tiledOperands.emplace_back(outputSlice);
3499
3500 SmallVector<Type> resultTypes;
3501 resultTypes.push_back(tiledOperands[1].getType());
3502 Operation *tiledOp =
3503 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3504
3505 return TilingResult{
3506 {tiledOp},
3507 SmallVector<Value>(tiledOp->getResults()),
3508 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3509}
3510
3511//===----------------------------------------------------------------------===//
3512// WinogradOutputTransformOp
3513//===----------------------------------------------------------------------===//
3514
3515LogicalResult WinogradOutputTransformOp::verify() {
3516 auto valueType = cast<ShapedType>(getValue().getType());
3517 ArrayRef<int64_t> valueShape = valueType.getShape();
3518 int64_t valueH = valueShape[getValueAlphaHDim()];
3519 int64_t valueW = valueShape[getValueAlphaWDim()];
3520 int64_t valueTileH = valueShape[getValueTileHDim()];
3521 int64_t valueTileW = valueShape[getValueTileWDim()];
3522 WinogradConv2DFmr fmr = getFmr();
3523 int64_t m, r;
3524 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3525 bool leftTransform = valueH != 1;
3526 bool rightTransform = valueW != 1;
3527
3528 int64_t outputRank = getOutputOperandRank();
3529 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3530 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3531 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3532 } else {
3533 if (valueH != (leftTransform ? m + r - 1 : 1))
3534 return emitOpError("expect input height equals to input tile size");
3535 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3536 }
3537 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3538 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3539 } else {
3540 if (valueW != (rightTransform ? m + r - 1 : 1))
3541 return emitOpError("expect input width equals to input tile size");
3542 expectedOutputShape[getOutputWDim()] =
3543 (rightTransform ? m : 1) * valueTileW;
3544 }
3545 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3546 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3547
3548 auto outputType = cast<ShapedType>(getOutput().getType());
3549 ArrayRef<int64_t> outputShape = outputType.getShape();
3550 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3551 return emitOpError("the output shape is not expected");
3552 }
3553 return success();
3554}
3555
3556SmallVector<Range>
3557WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3558 Location loc = getLoc();
3559 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3560 IntegerAttr oneAttr = builder.getIndexAttr(1);
3561 Value value = getValue();
3562 int64_t valueRank = getValueOperandRank();
3563 SmallVector<Range> loopBounds(valueRank);
3564 for (unsigned dim = 0; dim < valueRank; ++dim) {
3565 loopBounds[dim].offset = zeroAttr;
3566 // alphaH, alphaW, tileH, tileW, N, F
3567 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3568 loopBounds[dim].stride = oneAttr;
3569 }
3570 return loopBounds;
3571}
3572
3573SmallVector<utils::IteratorType>
3574WinogradOutputTransformOp::getLoopIteratorTypes() {
3575 int64_t valueRank = getValueOperandRank();
3576 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3577 utils::IteratorType::parallel);
3578 return iteratorTypes;
3579}
3580
3581LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3582 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3583 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3584 SmallVector<OpFoldResult> &resultSizes) {
3585 WinogradConv2DFmr fmr = getFmr();
3586 int64_t m, r;
3587 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3588
3589 Location loc = getLoc();
3590 MLIRContext *context = builder.getContext();
3591 auto identityAffineMap =
3592 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3593 auto affineMap =
3594 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3595
3596 ShapedType valueType = getValueOperandType();
3597 ArrayRef<int64_t> valueShape = valueType.getShape();
3598 int64_t valueH = valueShape[0];
3599 int64_t valueW = valueShape[1];
3600 Value mappedOffsetH = affine::makeComposedAffineApply(
3601 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3602 offsets[getValueTileHDim()]);
3603 Value mappedOffsetW = affine::makeComposedAffineApply(
3604 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3605 offsets[getValueTileWDim()]);
3606 Value mappedSizeH = affine::makeComposedAffineApply(
3607 builder, loc, affineMap, sizes[getValueTileHDim()]);
3608 Value mappedSizeW = affine::makeComposedAffineApply(
3609 builder, loc, affineMap, sizes[getValueTileWDim()]);
3610
3611 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3612 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3613 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3614 OpFoldResult sizeH =
3615 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3616 OpFoldResult sizeW =
3617 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3618
3619 resultOffsets.append(
3620 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3621 resultSizes.append(
3622 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3623 return success();
3624}
3625
3626/// Implement tiling for winograd_output_transform
3627/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3628/// F). The output of winograd_output_transform is (N, H, W, F) Users can
3629/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3630/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3631/// for the sizes of tileH, tileW, N, F for one tile.
3632FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3633 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3634 ArrayRef<OpFoldResult> sizes) {
3635 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3636 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3637 Location loc = getLoc();
3638 SmallVector<Value> tiledOperands;
3639 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3640
3641 ShapedType valueType = getValueOperandType();
3642 ArrayRef<int64_t> valueShape = valueType.getShape();
3643 int64_t alphaH = valueShape[getValueAlphaHDim()];
3644 int64_t alphaW = valueShape[getValueAlphaWDim()];
3645 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3646 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3647
3648 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3649 offsets[getValueTileWDim()], offsets[getValueNDim()],
3650 offsets[getValueFDim()]});
3651 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3652 sizes[getValueTileWDim()], sizes[getValueNDim()],
3653 sizes[getValueFDim()]});
3654 int64_t valueRank = getValueOperandRank();
3655 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3656 auto valueSlice = tensor::ExtractSliceOp::create(
3657 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3658 tiledOperands.emplace_back(valueSlice);
3659
3660 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3661 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3662 resultSizes)))
3663 return failure();
3664
3665 int64_t outputRank = getOutputOperandRank();
3666 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3667 auto outputSlice = tensor::ExtractSliceOp::create(
3668 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3669 tiledOperands.emplace_back(outputSlice);
3670
3671 SmallVector<Type> resultTypes;
3672 resultTypes.push_back(tiledOperands[1].getType());
3673 Operation *tiledOp =
3674 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3675
3676 return TilingResult{
3677 {tiledOp},
3678 SmallVector<Value>(tiledOp->getResults()),
3679 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3680}
3681
3682//===----------------------------------------------------------------------===//
3683// LinalgDialect
3684// TODO: Merge with the LinalgDialect block at the bottom
3685//===----------------------------------------------------------------------===//
3686
3687// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3688static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3689 auto explicitRange = subMap.getResults();
3690 auto defaultRange = fullMap.getResults();
3691 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3692 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3693 llvm::set_union(explicitSet, defaultSet);
3694 return explicitSet == defaultSet;
3695}
3696
3697/// Check if the user defined map is valid broadcast map. Here broadcast
3698/// indexing maps are defined in context of corresponding default indexing maps
3699/// for the given Op. This way the check becomes very simple i.e just check the
3700/// number of result dims.
3701/// Returns true if the explictMap is broadcasted with respect to the
3702/// defaultMap.
3703static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3704 return explictMap.getNumResults() < defaultMap.getNumResults();
3705}
3706
3707/// Verifies the broadcast and transpose semantic sepecified by the explicit
3708/// indexing map for the MatmulOp \p op for each operand specified by \p
3709/// opIndex.
3710static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3711 unsigned opIndex) {
3712 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3713 SmallVector<AffineMap, 3> defaultIndexingMaps =
3714 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3715
3716 auto opIndexingMap = opIndexingMaps[opIndex];
3717 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3718 // Check general validity of indexing map results.
3719 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3720 return matmulOp->emitOpError()
3721 << "Unexpected dim expression in map result.";
3722
3723 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3724 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3725 return matmulOp->emitOpError()
3726 << "Invalid broadcast requested, should be (d2).";
3727 }
3728 return success();
3729 }
3730 return success();
3731}
3732
3733// Check general validity of input indexing map of
3734// BatchMatmulOp/BatchReduceMatmulOp.
3735template <typename OpTy>
3736static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3737 AffineMap opIndexingMap,
3738 AffineMap defaultIndexingMap, bool isLHS) {
3739 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3740 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3741 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3742 // Check the result dims are valid.
3743 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3744 return batchVariantMatmulOp->emitOpError()
3745 << "Unexpected result dim expression (outside the set of default "
3746 "result dims).";
3747
3748 // Check for valid number of result dims of input maps.
3749 if (opIndexingMap.getNumResults() > 3)
3750 return batchVariantMatmulOp->emitOpError()
3751 << "no. of result dim expressions exceeds 3.";
3752
3753 auto hasValidBatchDim = [](AffineMap map) {
3754 AffineExpr batchDim = map.getResult(0);
3755 return batchDim.isFunctionOfDim(0);
3756 };
3757
3758 // Check if the requested broadcast is valid.
3759 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3760 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3761 return batchVariantMatmulOp->emitOpError()
3762 << "Invalid broadcast requested.";
3763 } else if (!hasValidBatchDim(opIndexingMap)) {
3764 return batchVariantMatmulOp->emitOpError()
3765 << "Invalid batch dimension expression.";
3766 }
3767 return success();
3768}
3769
3770/// This function checks if the given AffineMap for the output of a
3771/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3772/// dimensions and if the output map result dimensions are valid.
3773template <typename OpTy>
3774static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3775 AffineMap opIndexingMap) {
3776 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3777 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3778 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3779 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3780 opIndexingMap.getNumResults() != 3) {
3781
3782 return batchVariantMatmulOp->emitOpError()
3783 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3784 << ").";
3785 }
3786 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3787 opIndexingMap.getNumResults() != 2) {
3788 return batchVariantMatmulOp->emitOpError()
3789 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3790 << ").";
3791 }
3792
3793 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3794 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3795 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3796 outputMap.getResult(1).isFunctionOfDim(1) &&
3797 outputMap.getResult(2).isFunctionOfDim(2)
3798 : outputMap.getResult(0).isFunctionOfDim(1) &&
3799 outputMap.getResult(1).isFunctionOfDim(2);
3800 };
3801
3802 if (!areValidOutputResultDim(opIndexingMap)) {
3803 return batchVariantMatmulOp->emitOpError()
3804 << "Invalid output map result dimension.";
3805 }
3806
3807 return success();
3808}
3809
3810/// Verifies the broadcast and transpose semantic specified by the explicit
3811/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3812/// specified by opIndex.
/// Operand index 2 is validated with verifyOutputMap; indices 0 and 1 are
/// validated with verifyInputMaps, with index 0 flagged as the LHS. Fails with
/// an op error if the op does not carry exactly three indexing maps.
// NOTE(review): the listing's embedded numbering jumps from 3814 to 3816, so
// the line holding the function name and its first parameter (used below as
// `batchVariantMatmulOp`) is missing from this extract. Restore it from the
// upstream file before compiling.
3813template <typename OpTy>
3814static LogicalResult
3816 unsigned opIndex) {
3817 SmallVector<AffineMap, 3> opIndexingMaps =
3818 batchVariantMatmulOp.getIndexingMapsArray();
3819 SmallVector<AffineMap, 3> defaultIndexingMaps =
3820 batchVariantMatmulOp.getDefaultIndexingMaps(
3821 batchVariantMatmulOp->getContext());
3822
3823 if (opIndexingMaps.size() != 3)
3824 return batchVariantMatmulOp->emitOpError()
3825 << "Indexing_map attribute must have 3 affine maps.";
3826
3827 auto opIndexingMap = opIndexingMaps[opIndex];
3828 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3829
// The third map is checked against the output-map rules.
3830 if (opIndex == 2 &&
3831 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3832 return failure();
3833
// The first two maps are checked against the input-map rules.
3834 if (opIndex != 2 &&
3835 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3836 defaultIndexingMap, opIndex == 0)))
3837 return failure();
3838
3839 return success();
3840}
3841
3842namespace mlir {
3843namespace linalg {
3844
3845std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3846 if (m == 2 && r == 3)
3847 return WinogradConv2DFmr::F_2_3;
3848 if (m == 4 && r == 3)
3849 return WinogradConv2DFmr::F_4_3;
3850 if (m == 2 && r == 5)
3851 return WinogradConv2DFmr::F_2_5;
3852 return std::nullopt;
3853}
3854
3855std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3856 switch (fmr) {
3857 case WinogradConv2DFmr::F_2_3:
3858 return {2, 3};
3859 case WinogradConv2DFmr::F_4_3:
3860 return {4, 3};
3861 case WinogradConv2DFmr::F_2_5:
3862 return {2, 5};
3863 }
3864 llvm_unreachable("Unkown WinogradConv2DFmr");
3865}
3866
3867//===----------------------------------------------------------------------===//
3868// MatMulOp
3869//===----------------------------------------------------------------------===//
3870
3871static FailureOr<SmallVector<SmallVector<int64_t>>>
3874 for (auto map : maps) {
3875 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3876 if (!attr)
3877 return failure();
3879 for (auto result : attr.getAffineMap().getResults()) {
3880 auto dim = dyn_cast<AffineDimExpr>(result);
3881 if (!dim)
3882 return failure();
3883 pos.push_back(dim.getPosition());
3884 }
3885 positions.push_back(pos);
3886 }
3887 return positions;
3888}
3889
3890/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3891SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3892 AffineExpr d0, d1, d2;
3893 SmallVector<AffineMap> indexingMaps;
3894 bindDims(context, d0, d1, d2);
3895 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3896 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3897 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3898 return indexingMaps;
3899}
3900
3901bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3902 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3903 if (!maps)
3904 return false;
3905 if (maps.size() != 3)
3906 return false;
3907 auto positions = getAffineResultPositions(maps);
3908 if (failed(positions))
3909 return false;
3910 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3911 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3912 (*positions)[2] == SmallVector<int64_t>{0, 1};
3913}
3914
3915SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3916 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3917 utils::IteratorType::parallel,
3918 utils::IteratorType::reduction};
3919}
3920
3921unsigned MatmulOp::getNumRegionArgs() { return 3; }
3922
3923std::string MatmulOp::getLibraryCallName() {
3924 return generateLibraryCallName(getOperation());
3925}
3926
3927bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3928
3929/// Check if the op has broadcast and/or transpose semantic. Returns true if
3930/// the user defined indexing maps are not equal to default map.
3931bool MatmulOp::hasUserDefinedMaps() {
3932 SmallVector<AffineMap, 3> defaultMaps =
3933 getDefaultIndexingMaps(this->getContext());
3934 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3935 return defaultMaps != explicitMaps;
3936}
3937
3938/// Implements the block region builder for the MatmulOp. This is called by
3939/// 'fillStructuredOpRegion'.
3940void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3941 ArrayRef<NamedAttribute> attrs,
3942 function_ref<InFlightDiagnostic()> emitError) {
3943 if (emitError && block.getNumArguments() != 3) {
3944 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3945 << block.getNumArguments();
3946 return;
3947 }
3948 assert(block.getNumArguments() == 3 &&
3949 "MatmulOp regionBuilder expects 3 args");
3950 RegionBuilderHelper helper(b, block);
3951 SmallVector<Value> yields;
3952
3953 TypeFn castVal = TypeFn::cast_signed;
3954 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3955 return attr.getName() == "cast";
3956 });
3957 if (castIter != attrs.end()) {
3958 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3959 castVal = attr.getValue();
3960 }
3961
3962 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3963 block.getArgument(0));
3964 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3965 block.getArgument(1));
3966 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3967 if (!value1 || !value2 || !value3)
3968 return;
3969 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3970 value3, emitError);
3971 if (!value4)
3972 return;
3973 yields.push_back(value4);
3974 helper.yieldOutputs(yields);
3975}
3976
3977/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3978/// broadcast map must include K dimension.
3979/// TODO: Strict inclusion of K dimension in the broadcast map is not
3980/// necessary for both input matrices simultaneously. We can relax this
3981/// condition to have K dimension for one input matrix map and infer the K
3982/// dimension for other input matrix map from the one already having K
3983/// dimension.
3984bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3985 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3986 AffineExpr expr = bcastMap.getResult(0);
3987 // Invalid map if the common dimension of matmul not found.
3988 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3989}
3990
3991static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3992 if (parser.parseOptionalKeyword("indexing_maps"))
3993 return ArrayAttr{
3994 nullptr}; // Success in case indexing_maps was not provided.
3995
3996 ArrayAttr arrayAttr;
3997 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3998 return failure();
3999
4000 if (llvm::any_of(arrayAttr,
4001 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
4002 return parser.emitError(parser.getCurrentLocation())
4003 << "element of indexing_maps array is not an affine_map";
4004
4005 return arrayAttr;
4006}
4007
4008ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4009 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4010 if (failed(indexingMapsAttr))
4011 return failure();
4012
4013 if (*indexingMapsAttr == nullptr) {
4014 auto indexingMapAttrs = llvm::map_to_vector(
4015 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
4016 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4017 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
4018 }
4019
4020 result.addAttribute("indexing_maps", *indexingMapsAttr);
4021 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
4022 MatmulOp::getRegionBuilder());
4023}
4024
4025void MatmulOp::print(OpAsmPrinter &p) {
4026 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4027 MatmulOp::getDefaultIndexingMaps(getContext()),
4028 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4029 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4030 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4031
4032 std::array<StringRef, 3> elidedAttrs = {
4033 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4034 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4035 elidedAttrs);
4036}
4037
4038/// Verify the user defined indexing maps.
4039LogicalResult MatmulOp::verify() {
4040 // Verification of pure matmul is handled by verifyStructuredOpInterface().
4041 if (!hasUserDefinedMaps())
4042 return success();
4043
4044 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
4045 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
4046 return failure();
4047 }
4048 return success();
4049}
4050
4051LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4052 return memref::foldMemRefCast(*this);
4053}
4054
4055void MatmulOp::getEffects(
4056 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4057 &effects) {
4058 if (hasPureTensorSemantics())
4059 return;
4060 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4061}
4062
4063Speculation::Speculatability MatmulOp::getSpeculatability() {
4064 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4065}
4066
4067SmallVector<AffineMap>
4068MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4069 AffineExpr d0, d1, d2;
4070 MLIRContext *context = builder.getContext();
4071 bindDims(context, d0, d1, d2);
4072 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
4073 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
4074 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4075 return {mapLHS, mapRHS, mapOut};
4076}
4077
// Returns true if `attr` is an array of exactly three affine maps whose
// results are plain dims at positions {2, 0}, {2, 1} and {0, 1}.
// NOTE(review): the listing's embedded numbering skips 4078, so the line that
// opens this definition is missing from this extract. The positions match the
// MatmulTransposeAOp default maps, so this is presumably
// MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) — confirm against
// the upstream file.
4079 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4080 if (!maps)
4081 return false;
4082 if (maps.size() != 3)
4083 return false;
4084 auto positions = getAffineResultPositions(maps);
4085 if (failed(positions))
4086 return false;
4087 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4088 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4089 (*positions)[2] == SmallVector<int64_t>{0, 1};
4090}
4091
// Forwards to buildMatmulOp with no result types (std::nullopt), the MatmulOp
// region builder and the default indexing maps from getDefaultIndexingMaps.
// NOTE(review): the listing's embedded numbering skips 4092-4093, so the start
// of this signature (which must declare `builder` and `result`) is missing
// from this extract. Presumably a MatmulTransposeAOp `build` overload —
// confirm against the upstream file.
4094 ValueRange inputs, ValueRange outputs,
4095 ArrayRef<NamedAttribute> attributes) {
4096 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4097 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4098}
4099
// Fills an OperationState for this op's name through build(), creates the op
// and returns it as a MatmulTransposeAOp, asserting that the cast succeeded.
// NOTE(review): the listing's embedded numbering skips 4100-4101, so the start
// of this signature (which must declare `builder` and `location`) is missing
// from this extract. Presumably a MatmulTransposeAOp `create` overload —
// confirm against the upstream file.
4102 ValueRange inputs, ValueRange outputs,
4103 ArrayRef<NamedAttribute> attributes) {
4104 OperationState state(location, getOperationName());
4105 build(builder, state, inputs, outputs, attributes);
4106 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4107 assert(res && "builder didn't return the right type");
4108 return res;
4109}
4110
// Forwards to buildMatmulOp with explicit result tensor types, the MatmulOp
// region builder and the default indexing maps from getDefaultIndexingMaps.
// NOTE(review): the listing's embedded numbering skips 4111-4112, so the start
// of this signature (which must declare `builder` and `result`) is missing
// from this extract. Presumably a MatmulTransposeAOp `build` overload —
// confirm against the upstream file.
4113 TypeRange resultTensorTypes,
4114 ValueRange inputs, ValueRange outputs,
4115 ArrayRef<NamedAttribute> attributes) {
4116 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4117 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4118}
4119
// Fills an OperationState for this op's name through the result-typed build(),
// creates the op and returns it as a MatmulTransposeAOp, asserting that the
// cast succeeded.
// NOTE(review): the listing's embedded numbering skips 4120-4121, so the start
// of this signature (which must declare `builder` and `location`) is missing
// from this extract. Presumably a MatmulTransposeAOp `create` overload —
// confirm against the upstream file.
4122 TypeRange resultTensorTypes, ValueRange inputs,
4123 ValueRange outputs,
4124 ArrayRef<NamedAttribute> attributes) {
4125 OperationState state(location, getOperationName());
4126 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4127 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4128 assert(res && "builder didn't return the right type");
4129 return res;
4130}
4131
// Records `cast` as the "cast" attribute on `result`, then forwards to
// buildMatmulOp with explicit result tensor types, the MatmulOp region builder
// and the default indexing maps from getDefaultIndexingMaps.
// NOTE(review): the listing's embedded numbering skips 4132-4133, so the start
// of this signature (which must declare `builder` and `result`) is missing
// from this extract. Presumably a MatmulTransposeAOp `build` overload —
// confirm against the upstream file.
4134 TypeRange resultTensorTypes,
4135 ValueRange inputs, ValueRange outputs,
4136 Attribute cast,
4137 ArrayRef<NamedAttribute> attributes) {
4138 result.addAttribute("cast", cast);
4139 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4140 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4141}
4142
// Fills an OperationState for this op's name through the build() overload that
// takes a cast attribute, creates the op and returns it as a
// MatmulTransposeAOp, asserting that the cast succeeded.
// NOTE(review): the listing's embedded numbering skips 4143-4144, so the start
// of this signature (which must declare `builder` and `location`) is missing
// from this extract. Presumably a MatmulTransposeAOp `create` overload —
// confirm against the upstream file.
4145 TypeRange resultTensorTypes, ValueRange inputs,
4146 ValueRange outputs, Attribute cast,
4147 ArrayRef<NamedAttribute> attributes) {
4148 OperationState state(location, getOperationName());
4149 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4150 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4151 assert(res && "builder didn't return the right type");
4152 return res;
4153}
4154
// Accepts only operations that are a linalg::MatmulOp and that also satisfy a
// second condition evaluated on the op's "indexing_maps" attribute.
// NOTE(review): the listing's embedded numbering skips 4155 and 4157, so both
// the opening line of this definition and the callee applied to the attribute
// are missing from this extract. Presumably this is MatmulTransposeAOp's
// classof calling its isDefaultIndexingMaps — confirm against the upstream
// file.
4156 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4158 op->getAttr("indexing_maps"));
4159}
4160
// Returns the default indexing maps of the transposed-RHS matmul over three
// dims, in operand order: (d0, d2) for the LHS, (d1, d2) for the RHS and
// (d0, d1) for the output.
// NOTE(review): the listing's embedded numbering skips 4161, so the
// return-type line of this definition is missing from this extract. The
// returned braced list of AffineMap suggests SmallVector<AffineMap>, as in the
// MatmulTransposeAOp counterpart — confirm against the upstream file.
4162MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4163 AffineExpr d0, d1, d2;
4164 MLIRContext *context = builder.getContext();
4165 bindDims(context, d0, d1, d2);
4166 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4167 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4168 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4169 return {mapLHS, mapRHS, mapOut};
4170}
4171
// Returns true if `attr` is an array of exactly three affine maps whose
// results are plain dims at positions {0, 2}, {1, 2} and {0, 1}.
// NOTE(review): the listing's embedded numbering skips 4172, so the line that
// opens this definition is missing from this extract. The positions match the
// MatmulTransposeBOp default maps, so this is presumably
// MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) — confirm against
// the upstream file.
4173 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4174 if (!maps)
4175 return false;
4176 if (maps.size() != 3)
4177 return false;
4178 auto positions = getAffineResultPositions(maps);
4179 if (failed(positions))
4180 return false;
4181 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4182 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4183 (*positions)[2] == SmallVector<int64_t>{0, 1};
4184}
4185
4188 ValueRange inputs, ValueRange outputs,
4189 ArrayRef<NamedAttribute> attributes) {
4190 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4191 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4192}
4193
4196 ValueRange inputs, ValueRange outputs,
4197 ArrayRef<NamedAttribute> attributes) {
4198 OperationState state(location, getOperationName());
4199 build(builder, state, inputs, outputs, attributes);
4200 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4201 assert(res && "builder didn't return the right type");
4202 return res;
4203}
4204
4207 TypeRange resultTensorTypes,
4208 ValueRange inputs, ValueRange outputs,
4209 ArrayRef<NamedAttribute> attributes) {
4210 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4211 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4212}
4213
4216 TypeRange resultTensorTypes, ValueRange inputs,
4217 ValueRange outputs,
4218 ArrayRef<NamedAttribute> attributes) {
4219 OperationState state(location, getOperationName());
4220 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4221 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4222 assert(res && "builder didn't return the right type");
4223 return res;
4224}
4225
4228 TypeRange resultTensorTypes,
4229 ValueRange inputs, ValueRange outputs,
4230 Attribute cast,
4231 ArrayRef<NamedAttribute> attributes) {
4232 result.addAttribute("cast", cast);
4233 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4234 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4235}
4236
4239 TypeRange resultTensorTypes, ValueRange inputs,
4240 ValueRange outputs, Attribute cast,
4241 ArrayRef<NamedAttribute> attributes) {
4242 OperationState state(location, getOperationName());
4243 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4244 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4245 assert(res && "builder didn't return the right type");
4246 return res;
4247}
4248
4250 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4252 op->getAttr("indexing_maps"));
4253}
4254
4256BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4257 AffineExpr d0, d1, d2, d3;
4258 MLIRContext *context = builder.getContext();
4259 bindDims(context, d0, d1, d2, d3);
4260 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4261 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4262 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4263 return {mapLHS, mapRHS, mapOut};
4264}
4265
4267 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4268 if (!maps)
4269 return false;
4270 if (maps.size() != 3)
4271 return false;
4272 auto positions = getAffineResultPositions(maps);
4273 if (failed(positions))
4274 return false;
4275 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4276 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4277 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4278}
4279
4281 OpBuilder &builder, OperationState &result, ValueRange inputs,
4282 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4283 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4284 BatchMatmulOp::getRegionBuilder(),
4285 getDefaultIndexingMaps(builder));
4286}
4287
4290 ValueRange inputs, ValueRange outputs,
4291 ArrayRef<NamedAttribute> attributes) {
4292 OperationState state(location, getOperationName());
4293 build(builder, state, inputs, outputs, attributes);
4294 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4295 assert(res && "builder didn't return the right type");
4296 return res;
4297}
4298
4300 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4301 ValueRange inputs, ValueRange outputs,
4302 ArrayRef<NamedAttribute> attributes) {
4303 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4304 BatchMatmulOp::getRegionBuilder(),
4305 getDefaultIndexingMaps(builder));
4306}
4307
4310 TypeRange resultTensorTypes, ValueRange inputs,
4311 ValueRange outputs,
4312 ArrayRef<NamedAttribute> attributes) {
4313 OperationState state(location, getOperationName());
4314 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4315 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4316 assert(res && "builder didn't return the right type");
4317 return res;
4318}
4319
4321 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4322 ValueRange inputs, ValueRange outputs, Attribute cast,
4323 ArrayRef<NamedAttribute> attributes) {
4324 result.addAttribute("cast", cast);
4325 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4326 BatchMatmulOp::getRegionBuilder(),
4327 getDefaultIndexingMaps(builder));
4328}
4329
4332 TypeRange resultTensorTypes, ValueRange inputs,
4333 ValueRange outputs, Attribute cast,
4334 ArrayRef<NamedAttribute> attributes) {
4335 OperationState state(location, getOperationName());
4336 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4337 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4338 assert(res && "builder didn't return the right type");
4339 return res;
4340}
4341
4343 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4345 op->getAttr("indexing_maps"));
4346}
4347
4349BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4350 AffineExpr d0, d1, d2, d3;
4351 MLIRContext *context = builder.getContext();
4352 bindDims(context, d0, d1, d2, d3);
4353 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4354 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4355 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4356 return {mapLHS, mapRHS, mapOut};
4357}
4358
4360 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4361 if (!maps)
4362 return false;
4363 if (maps.size() != 3)
4364 return false;
4365 auto positions = getAffineResultPositions(maps);
4366 if (failed(positions))
4367 return false;
4368 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4369 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4370 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4371}
4372
4374 OpBuilder &builder, OperationState &result, ValueRange inputs,
4375 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4376 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4377 BatchMatmulOp::getRegionBuilder(),
4378 getDefaultIndexingMaps(builder));
4379}
4380
4383 ValueRange inputs, ValueRange outputs,
4384 ArrayRef<NamedAttribute> attributes) {
4385 OperationState state(location, getOperationName());
4386 build(builder, state, inputs, outputs, attributes);
4387 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4388 assert(res && "builder didn't return the right type");
4389 return res;
4390}
4391
4393 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4394 ValueRange inputs, ValueRange outputs,
4395 ArrayRef<NamedAttribute> attributes) {
4396 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4397 BatchMatmulOp::getRegionBuilder(),
4398 getDefaultIndexingMaps(builder));
4399}
4400
4403 TypeRange resultTensorTypes, ValueRange inputs,
4404 ValueRange outputs,
4405 ArrayRef<NamedAttribute> attributes) {
4406 OperationState state(location, getOperationName());
4407 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4408 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4409 assert(res && "builder didn't return the right type");
4410 return res;
4411}
4412
4414 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4415 ValueRange inputs, ValueRange outputs, Attribute cast,
4416 ArrayRef<NamedAttribute> attributes) {
4417 result.addAttribute("cast", cast);
4418 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4419 BatchMatmulOp::getRegionBuilder(),
4420 getDefaultIndexingMaps(builder));
4421}
4422
4425 TypeRange resultTensorTypes, ValueRange inputs,
4426 ValueRange outputs, Attribute cast,
4427 ArrayRef<NamedAttribute> attributes) {
4428 OperationState state(location, getOperationName());
4429 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4430 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4431 assert(res && "builder didn't return the right type");
4432 return res;
4433}
4434
4436 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4438 op->getAttr("indexing_maps"));
4439}
4440
4441//===----------------------------------------------------------------------===//
4442// ContractOp
4443//===----------------------------------------------------------------------===//
4444
4445SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4446 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4447 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4448 // domains are all the same, and each implements a projected permutation.
4449 // Each iteration space dim must occur for at least one operand and either
4450 // takes part in a contraction/reduction or else has parallel iteration type.
4451 // We have that a dim is a contraction/reduction dim if and only if the dim
4452 // occurs for the output operand. We use this fact for fast inference:
4453 // NB: In case we allow dims to occur solely for one input, the above still
4454 // holds: per the einsum semantics, these are reduction dims as well.
4455 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4456 for (auto result : outAffineMap.getResults()) {
4457 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4458 assert(dimExpr && "affine_map is a projected permutation");
4459 dimsInOutput[dimExpr.getPosition()] = true;
4460 }
4461
4463 for (auto dimOccursInOutput : dimsInOutput)
4464 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4465 : utils::IteratorType::reduction);
4466
4467 return iteratorTypes;
4468}
4469
4470unsigned ContractOp::getNumRegionArgs() { return 3; }
4471
4472/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4473void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4474 ArrayRef<NamedAttribute> attrs,
4475 function_ref<InFlightDiagnostic()> emitError) {
4476 if (emitError && block.getNumArguments() != 3) {
4477 emitError() << "ContractOp regionBuilder expects 3 args, got "
4478 << block.getNumArguments();
4479 return;
4480 }
4481 assert(block.getNumArguments() == 3 &&
4482 "ContractOp regionBuilder expects 3 args");
4483 RegionBuilderHelper helper(b, block);
4484
4485 TypeFn castSignedness = TypeFn::cast_signed;
4486 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4487 return attr.getName() == "cast";
4488 });
4489 if (castIter != attrs.end()) {
4490 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4491 castSignedness = attr.getValue();
4492 }
4493
4494 // TODO: Support fields with operators besides mult & add.
4495 Type outType = block.getArgument(2).getType();
4496 Value lhsAtOutType =
4497 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4498 Value rhsAtOutType =
4499 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4500 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4501 rhsAtOutType, emitError);
4502 if (!productAtOutType)
4503 return;
4504 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4505 productAtOutType, emitError);
4506 if (!result)
4507 return;
4508 helper.yieldOutputs({result});
4509}
4510
4511ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4512 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4513 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4514 return parser.emitError(parser.getCurrentLocation(),
4515 "expected 'indexing_maps' attribute");
4516 result.addAttribute("indexing_maps", *indexingMapsAttr);
4517
4518 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4519 regionBuilder);
4520}
4521
4522void ContractOp::print(OpAsmPrinter &p) {
4523 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4525 p, getOperation(), getInputs(), getOutputs(),
4526 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4527}
4528
4529LogicalResult ContractOp::verify() {
4530 int iterationSpaceDims = -1;
4531 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4532 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4533 // access an input operand (so occurrence count can be at most 2) and
4534 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4535 SmallVector<size_t> inOccurrences;
4536 SmallVector<size_t> outOccurrences;
4537
4538 // A helper so that for each operand's affine_map and type we check that ...
4539 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4540 bool isInput) -> LogicalResult {
4541 // ... the affine_map is a projected permutation;
4542 if (!affineMap.isProjectedPermutation())
4543 return emitError("provided affine_map is not a projected permutation");
4544
4545 // ... the rank of the affine_map's results and corresponding type match;
4546 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4547 if (affineMap.getNumResults() != shapedType.getRank())
4548 return emitError("ranks of shaped operand and results of corresponding "
4549 "affine_map differ");
4550 } else if (affineMap.getNumResults() != 0) {
4551 return emitError("affine_map specifies shaped access while operand has "
4552 "non-shaped type");
4553 }
4554
4555 // ... the rank of the affine_map's domain is the same as those seen prior;
4556 if (iterationSpaceDims == -1) {
4557 iterationSpaceDims = affineMap.getNumDims();
4558 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4559 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4560 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4561 return emitError("iteration spaces of provided affine_maps differ");
4562 }
4563
4564 // ... update counts of dims used to access either an input or the output.
4565 for (AffineExpr affineExpr : affineMap.getResults()) {
4566 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4567 if (!affineDimExpr)
4568 llvm_unreachable("affine_map is a projected permutation");
4569
4570 if (isInput)
4571 inOccurrences[affineDimExpr.getPosition()] += 1;
4572 else
4573 outOccurrences[affineDimExpr.getPosition()] += 1;
4574 }
4575
4576 return success();
4577 };
4578
4579 for (auto &&[affineMap, operandType, isInput] :
4580 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4581 SmallVector<bool>{true, true, false})) {
4582 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4583 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4584 }
4585
4586 bool hasContractingDim = false;
4587 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4588 size_t inOccCount = inOccurrences[dimIndex];
4589 size_t outOccCount = outOccurrences[dimIndex];
4590
4591 // We have a contracting dim if and only if ...
4592 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4593
4594 if (inOccCount == 0 && outOccCount == 0)
4595 return emitError() << "iteration space dim at index " << dimIndex
4596 << " not used to access any operand";
4597
4598 // NB: We disallow a dim which occurs for only one input operand and not
4599 // for the output. In terms of einsum semantics such dims have a
4600 // sensible meaning - namely an additional reduction per each such dim.
4601 // By contrast, the ContractionOpInterface does not know about this
4602 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4603 // while vector.contract's verifier accepts dims of this kind many of
4604 // its lowerings give up on encountering these dims.
4605 // TODO: Remove following once we have comprehensive support for input-only
4606 // reduction dims, at both the linalg- and vector-dialect levels.
4607 if (inOccCount == 1 && outOccCount != 1)
4608 return emitError()
4609 << "iteration space dim at index " << dimIndex
4610 << " is neither a contracting dim nor of parallel iteration type";
4611 }
4612
4613 if (!hasContractingDim)
4614 return emitError("'indexing_maps' do not specify a contracting dimension");
4615
4616 return success();
4617}
4618
4619LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4620 return memref::foldMemRefCast(*this);
4621}
4622
4623void ContractOp::getEffects(
4624 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4625 &effects) {
4626 if (hasPureTensorSemantics())
4627 return;
4628 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4629}
4630
4631Speculation::Speculatability ContractOp::getSpeculatability() {
4632 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4633}
4634
4635//===----------------------------------------------------------------------===//
4636// Implementation of BatchMatmulOp
4637//===----------------------------------------------------------------------===//
4638SmallVector<AffineMap>
4639BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4640 AffineExpr d0, d1, d2, d3;
4641 SmallVector<AffineMap> indexingMaps;
4642 bindDims(context, d0, d1, d2, d3);
4643 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4644 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4645 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4646 return indexingMaps;
4647}
4648
4649bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4650 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4651 if (!maps)
4652 return false;
4653 if (maps.size() != 3)
4654 return false;
4655 auto positions = getAffineResultPositions(maps);
4656 if (failed(positions))
4657 return false;
4658 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4659 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4660 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4661}
4662
4663SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4664 return SmallVector<utils::IteratorType>{
4665 utils::IteratorType::parallel, utils::IteratorType::parallel,
4666 utils::IteratorType::parallel, utils::IteratorType::reduction};
4667}
4668
4669unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4670
4671std::string BatchMatmulOp::getLibraryCallName() {
4672 return generateLibraryCallName(getOperation());
4673}
4674
4675/// Check if the op has broadcast and/or transpose semantic. Returns true if
4676/// the user defined indexing maps are not equal to default map.
4677bool BatchMatmulOp::hasUserDefinedMaps() {
4678 SmallVector<AffineMap, 3> defaultMaps =
4679 getDefaultIndexingMaps(this->getContext());
4680 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4681 return defaultMaps != explicitMaps;
4682}
4683
4684/// Returns true if the given bcastMap map is a valid broadcast map. A valid
4685/// broadcast map must include K dimension.
4686/// TODO: Strict inclusion of K dimension in the broadcast map is not
4687/// necessary for both input matrices simultaneously. We can relax this
4688/// condition to have K dimension for one input matrix map and infer the K
4689/// dimension for other input matrix map from the one already having K
4690/// dimension.
4691bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4692 assert(bcastMap.getNumResults() < 3 &&
4693 "Expected less than 3 result dim expr.");
4694 bool isValid = false;
4695 enum Indices { batchPos, mPos, nPos, kPos };
4696 if (bcastMap.getNumResults() == 1) {
4697 AffineExpr expr = bcastMap.getResult(0);
4698 isValid = expr.isFunctionOfDim(kPos);
4699 } else if (bcastMap.getNumResults() == 2) {
4700 AffineExpr expr0 = bcastMap.getResult(0);
4701 AffineExpr expr1 = bcastMap.getResult(1);
4702 isValid =
4703 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4704 expr0.isFunctionOfDim(mPos)) &&
4705 expr1.isFunctionOfDim(kPos))
4706 : ((expr0.isFunctionOfDim(batchPos) &&
4707 expr1.isFunctionOfDim(kPos)) ||
4708 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4709 }
4710 return isValid;
4711}
4712
4713void BatchMatmulOp::regionBuilder(
4714 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4715 function_ref<InFlightDiagnostic()> emitError) {
4716 if (emitError && block.getNumArguments() != 3) {
4717 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4718 << block.getNumArguments();
4719 return;
4720 }
4721 assert(block.getNumArguments() == 3 &&
4722 "BatchMatmulOp regionBuilder expects 3 args");
4723 RegionBuilderHelper helper(b, block);
4724 SmallVector<Value> yields;
4725
4726 TypeFn castVal = TypeFn::cast_signed;
4727 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4728 return attr.getName() == "cast";
4729 });
4730 if (castIter != attrs.end()) {
4731 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4732 castVal = attr.getValue();
4733 }
4734
4735 auto toType = block.getArgument(2).getType();
4736 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4737 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4738 Value mulVal =
4739 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
4740 if (!castValA || !castValB || !mulVal)
4741 return;
4742 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4743 mulVal, emitError);
4744 if (!addVal)
4745 return;
4746 yields.push_back(addVal);
4747 helper.yieldOutputs(yields);
4748}
4749
4750ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4751 SmallVector<Attribute, 3> indexingMapsAttr;
4752 Attribute mapAttr;
4753 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4754 if (parser.parseEqual())
4755 return failure();
4756
4757 if (parser.parseLSquare())
4758 return failure();
4759
4760 do {
4761 if (parser.parseAttribute(mapAttr))
4762 return failure();
4763 if (!isa<AffineMapAttr>(mapAttr)) {
4764 return parser.emitError(parser.getCurrentLocation(),
4765 "expected affine map attribute");
4766 }
4767 indexingMapsAttr.push_back(mapAttr);
4768
4769 if (parser.parseOptionalComma())
4770 break;
4771 } while (true);
4772
4773 if (parser.parseRSquare())
4774 return failure();
4775 }
4776 // Initialize indexingMaps, if not supplied explicitly.
4777 if (indexingMapsAttr.empty()) {
4778 indexingMapsAttr = llvm::map_to_vector(
4779 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4780 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4781 }
4782 result.addAttribute("indexing_maps",
4783 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4784
4785 return ::parseNamedStructuredOp(parser, result,
4786 BatchMatmulOp::getNumRegionArgs(),
4787 BatchMatmulOp::getRegionBuilder());
4788}
4789
4790void BatchMatmulOp::print(OpAsmPrinter &p) {
4791 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4792 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4793 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4794 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4795 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4796
4797 std::array<StringRef, 3> elidedAttrs = {
4798 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4799 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4800 elidedAttrs);
4801}
4802
// Op verifier. Succeeds immediately when the op uses the default indexing
// maps; otherwise it visits the three operand indices (0, 1, 2) in order.
4803/// Verify the user defined indexing maps.
4804LogicalResult BatchMatmulOp::verify() {
4805 // Verification of pure batch_matmul is handled by
4806 // verifyStructuredOpInterface().
4807 if (!hasUserDefinedMaps())
4808 return success();
4809
4810 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
// NOTE(review): as shown here the loop body returns failure() unconditionally.
// The per-operand check that presumably guards this return (listing line
// 4811) is not visible in this view - confirm against the full source.
4812 return failure();
4813 }
4814 return success();
4815}
4816
4817LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4818 SmallVectorImpl<OpFoldResult> &) {
4819 return memref::foldMemRefCast(*this);
4820}
4821
4822void BatchMatmulOp::getEffects(
4823 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4824 &effects) {
4825 if (hasPureTensorSemantics())
4826 return;
4827 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4828}
4829
4830Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4831 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4832}
4833
4834//===----------------------------------------------------------------------===//
4835// ElementwiseOp
4836//===----------------------------------------------------------------------===//
4837//
4838namespace {
4839struct ArityGroupAndKind {
4840 // The enum class {Unary, Binary, Ternary, ..}
4841 ElementwiseArityGroup arityGroup;
4842
4843 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4844 union Kind {
4845 UnaryFn unaryFn;
4846 BinaryFn binaryFn;
4847 TernaryFn ternaryFn;
4848 } kind;
4849};
4850
4851unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4852 return static_cast<unsigned>(arityGroup);
4853}
4854} // namespace
4855
4856static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4857 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4858 constexpr int lastBinary =
4859 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4860 constexpr int lastTernary =
4861 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4862
4863 int val = static_cast<int>(kind);
4864 ArityGroupAndKind result;
4865
4866 if (val < lastUnary) {
4867 result.arityGroup = ElementwiseArityGroup::Unary;
4868 result.kind.unaryFn = static_cast<UnaryFn>(val);
4869 return result;
4870 }
4871 if (val < lastBinary) {
4872 result.arityGroup = ElementwiseArityGroup::Binary;
4873 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4874 return result;
4875 }
4876 if (val >= lastTernary) {
4877 llvm_unreachable("unhandled ElementwiseFn");
4878 }
4879 result.arityGroup = ElementwiseArityGroup::Ternary;
4880 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4881 return result;
4882}
4883
4884SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4885 auto rank = getResultRank();
4886 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4887}
4888
4890ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4891 MLIRContext *context) {
4892 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4893 return SmallVector<AffineMap>(numMaps, map);
4894}
4895
4896ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4897 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4898 Attribute attr;
4899 mlir::linalg::ElementwiseKind elemwiseKindVal;
4900 if (parser.parseKeyword("kind") || parser.parseEqual())
4901 return failure();
4902
4903 if (succeeded(parser.parseAttribute(attr))) {
4904 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4905 if (!elemwiseKindAttr)
4906 return parser.emitError(parser.getCurrentLocation(),
4907 "expected ElementwiseKind attribute");
4908 elemwiseKindVal = elemwiseKindAttr.getValue();
4909 } else {
4910 return parser.emitError(parser.getCurrentLocation(),
4911 "expected operation 'kind' attribute");
4912 }
4913 result.addAttribute(
4914 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4915
4916 // Parse optional `indexing_maps`
4917 SmallVector<Attribute, 3> indexingMapsAttr;
4918 Attribute mapAttr;
4919 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4920 if (parser.parseEqual())
4921 return failure();
4922 if (parser.parseLSquare())
4923 return failure();
4924 do {
4925 if (parser.parseAttribute(mapAttr))
4926 return failure();
4927 if (!isa<AffineMapAttr>(mapAttr))
4928 return parser.emitError(parser.getCurrentLocation(),
4929 "expected affine map attribute");
4930 indexingMapsAttr.push_back(mapAttr);
4931 if (parser.parseOptionalComma())
4932 break;
4933 } while (true);
4934 if (parser.parseRSquare())
4935 return failure();
4936 }
4937 // At this stage of parsing the only way to infer number of region
4938 // args is through op kind, as input output tensors are not parsed yet.
4939 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4940 int numRegionArgs =
4941 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4942 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4943 ElementwiseOp::getRegionBuilder())) {
4944 return parser.emitError(parser.getCurrentLocation(),
4945 "unable to parse elemwise op");
4946 }
4947
4948 // Initialize indexingMaps, if not supplied explicitly.
4949 if (indexingMapsAttr.empty()) {
4950 // We need to infer the numDims of the indexing maps from the output
4951 // type which is already parsed by now.
4952 auto resultType = result.operands[result.operands.size() - 1].getType();
4953 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4954 if (!shapedType)
4955 return parser.emitError(parser.getCurrentLocation(),
4956 "return type needs to be shaped type");
4957 auto numDims = shapedType.getRank();
4958 indexingMapsAttr = llvm::map_to_vector(
4959 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4960 parser.getContext()),
4961 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4962 }
4963
4964 result.addAttribute("indexing_maps",
4965 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4966 return success();
4967}
4968
4969void ElementwiseOp::print(OpAsmPrinter &p) {
4970 p << " kind=";
4971 p.printAttribute(getKindAttr());
4972 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4973 "indexing_maps"};
4974 unsigned arity =
4975 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4976 unsigned numDims = getResultRank();
4977
4978 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4979 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4980 getContext()),
4981 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4982
4983 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4984 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4985
4986 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4987 elidedAttrs);
4988}
4989
4990/// Implements the block region builder for the ElementwiseOp. This is called by
4991/// 'fillStructuredOpRegion'.
4992void ElementwiseOp::regionBuilder(
4993 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4994 function_ref<InFlightDiagnostic()> emitError) {
4995 ElementwiseKind elemwiseKind;
4996 for (auto attr : attrs) {
4997 if (attr.getName() == b.getStringAttr("kind")) {
4998 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4999 assert(kindAttr && "op kind attribute incorrectly set");
5000 elemwiseKind = kindAttr.getValue();
5001 break;
5002 }
5003 }
5004
5005 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
5006 auto arityGroup = groupAndKind.arityGroup;
5007 auto kind = groupAndKind.kind;
5008 if (emitError && block.getNumArguments() !=
5009 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
5010 emitError() << "Elementwise regionBuilder expects "
5011 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
5012 << block.getNumArguments();
5013 return;
5014 }
5015 assert(block.getNumArguments() ==
5016 getArityGroupAsUInt(arityGroup) + 1 /*output*/
5017 && "Elementwise regionBuilder number of block args mismatch");
5018
5019 RegionBuilderHelper helper(b, block);
5020 SmallVector<Value> yields;
5021 Value result;
5022
5023 if (arityGroup == ElementwiseArityGroup::Unary) {
5024 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
5025
5026 } else if (arityGroup == ElementwiseArityGroup::Binary) {
5027 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
5028 block.getArgument(1));
5029
5030 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
5031 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
5032 block.getArgument(1), block.getArgument(2));
5033
5034 } else {
5035 assert(false && "found unhandled category in elemwise");
5036 }
5037
5038 yields.push_back(result);
5039 helper.yieldOutputs(yields);
5040}
5041
/// Folds this op by delegating to `memref::foldMemRefCast`. Neither the fold
/// adaptor nor the fold-results vector is used here.
5042LogicalResult ElementwiseOp::fold(FoldAdaptor,
5043 SmallVectorImpl<OpFoldResult> &) {
5044 return memref::foldMemRefCast(*this);
5045}
5046
5047void ElementwiseOp::getEffects(
5048 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5049 &effects) {
5050 if (hasPureTensorSemantics())
5051 return;
5052 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5053}
5054
/// Reports speculatability by delegating to the generic LinalgOp
/// implementation, `getGenericSpeculatabilityImpl`.
5055Speculation::Speculatability ElementwiseOp::getSpeculatability() {
5056 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5057}
5058
5059//===----------------------------------------------------------------------===//
5060// PackOp/UnPackOp Common
5061//===----------------------------------------------------------------------===//
5062
/// Works on the packed type's leading dims (as many as the unpacked rank),
/// where the packed type is the dest type for PackOp and the source type for
/// UnPackOp. When `outer_dims_perm` is non-empty, its inverse permutation takes
/// part in producing the returned `result`.
// NOTE(review): the lines numbered 5065, 5072 and 5075 are absent from this
// listing (apparently the function name/parameter list, the declaration of
// `result`, and the start of the call that takes the inverted permutation) --
// confirm against the original file.
5063template <typename OpTy, typename>
5064SmallVector<int64_t>
5066 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5067 ? packOrUnPack.getDestType()
5068 : packOrUnPack.getSourceType();
5069 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5070 ? packOrUnPack.getSourceType()
5071 : packOrUnPack.getDestType();
5073 packedType.getShape().take_front(unpackedType.getRank()));
5074 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5076 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
5077 }
5078 return result;
5079}
5084
5085// Given the (potentially) updated packed type, `newPackedTy`, generates an
5086// updated mixed-tile-sizes list. For each inner packed dimension that is static
5087// in `newPackedTy`, the tile is set to that static size (replacing SSA values
5088// or mismatched constants). Dynamic packed dimensions preserve the original
5089// tile. The folded tensor type is treated as authoritative for static extents.
5090// Note - packed-type-dim and mixed-tile-size should always match!
// NOTE(review): the lines numbered 5091-5092 (return type, function name and
// leading parameters, including `rewriter` and `newPackedTy`) are absent from
// this listing -- confirm against the original file.
5093 ArrayRef<OpFoldResult> mixedTiles) {
5094 SmallVector<OpFoldResult> newMixedTileSizes;
// Pair each trailing (inner-tile) dim of the packed shape with its tile.
5095 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5096 .getShape()
5097 .take_back(mixedTiles.size()),
5098 mixedTiles)) {
5099 int64_t dimSize = std::get<0>(it);
5100 if (dimSize == ShapedType::kDynamic) {
// Dynamic packed dim: keep the original tile untouched.
5101 newMixedTileSizes.push_back(std::get<1>(it));
5102 continue;
5103 }
// Static packed dim: the tile becomes an index attribute of that size.
5104 newMixedTileSizes.push_back(rewriter.getIndexAttr(dimSize));
5105 }
5106
5107 return newMixedTileSizes;
5108}
5109
/// Reifies the result shape of a pack/unpack op from its dest operand: a
/// single shape with one entry per dest dimension, each entry produced by
/// `createFoldedDimOp` on `op.getDest()`. Always returns success.
// NOTE(review): line 5112 (function name and leading parameters) is absent
// from this listing; a caller in this file invokes it as
// `reifyResultShapesImpl(*this, builder, reifiedReturnShapes)`.
5110template <typename OpTy>
5111static LogicalResult
5113 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5114 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5115 "applies to only pack or unpack operations");
5116 int64_t destRank = op.getDestRank();
5117 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5118 for (auto dim : llvm::seq<int64_t>(0, destRank))
5119 reifiedReturnShapes[0][dim] =
5120 createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
5121 return success();
5122}
5123
/// Builds a map from each tiled dimension (an entry of `inner_dims_pos`) to the
/// tile size (static attribute or SSA value) at the same position in the op's
/// mixed tiles.
// NOTE(review): line 5125 (return type, function name and parameter list) is
// absent from this listing; a caller in this file invokes it as
// `getDimAndTileMappingImpl(*this)`.
5124template <typename OpTy>
5126 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5127 "applies to only pack or unpack operations");
5128 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5129 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5130 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5131 assert(tiles.size() == dimsToTile.size() &&
5132 "tiles must match indices of dimension to block");
5133 // Bind the dimension `dimsToTile[i]` to the `i`-th tile factor.
5134 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5135 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5136 return dimAndTileMapping;
5137}
5138
/// Rebuilds the mixed (static or dynamic) inner-tile list of a pack/unpack op.
/// It walks the static inner tiles in order: a static entry becomes an I64
/// integer attribute, and each non-static entry consumes the next SSA value
/// from `op.getInnerTiles()`.
// NOTE(review): line 5140 (return type, function name and parameter list) is
// absent from this listing; a caller in this file invokes it as
// `getMixedTilesImpl(*this)`.
5139template <typename OpTy>
5141 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5142 "applies to only pack or unpack operations");
5143 Builder builder(op);
5144 SmallVector<OpFoldResult> mixedInnerTiles;
// Index of the next unused dynamic tile operand.
5145 unsigned dynamicValIndex = 0;
5146 for (int64_t staticTile : op.getStaticInnerTiles()) {
5147 if (ShapedType::isStatic(staticTile))
5148 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5149 else
5150 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5151 }
5152 return mixedInnerTiles;
5153}
5154
/// Returns the static part of the op's mixed tiles, as produced by
/// `dispatchIndexOpFoldResults`; the dynamic values it collects are discarded.
// NOTE(review): line 5156 (return type, function name and parameter list) is
// absent from this listing; a caller in this file invokes it as
// `getStaticTilesImpl(*this)`. Presumably dynamic tiles show up as
// `ShapedType::kDynamic` in the result -- confirm against
// `dispatchIndexOpFoldResults`.
5155template <typename OpTy>
5157 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5158 "applies to only pack or unpack operations");
5159 SmallVector<Value> dynamicTiles;
5160 SmallVector<int64_t> staticTiles;
5161 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5162 return staticTiles;
5163}
5164
5165/// Returns true if `dimsPos` is invalid. It is invalid when:
5166/// a) It contains duplicates.
5167/// b) At least one dimension is out of bounds (a valid `dimPos` is >= 0 and
///    < rank).
5168/// c) The number of elements in `dimsPos` is greater than `rank`.
// NOTE(review): line 5169 (return type, function name and the `dimsPos`
// parameter) is absent from this listing; the verifier in this file calls it
// as `isInvalidPackingPosSpecification(dimsPos, rank)`.
5170 size_t rank) {
5171 size_t dimsPosSize = dimsPos.size();
5172 if (dimsPosSize > rank)
5173 return true;
// A set built from `dimsPos` is smaller than `dimsPos` iff there are duplicates.
5174 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5175 if (dimsPosSize != uniqued.size())
5176 return true;
5177 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5178 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5179 });
5180}
5181
5182template <typename OpTy>
5183static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5184 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5185 "applies to only pack or unpack operations");
5186 Operation *op = packOrUnPack.getOperation();
5187
5188 // Return true if we have a zero-value tile.
5189 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5190 return llvm::any_of(tiles, [](OpFoldResult tile) {
5191 return isa<Attribute>(tile) && isZeroInteger(tile);
5192 });
5193 };
5194
5195 // Verify that the source and destination are ranked types.
5196 if (!packOrUnPack.getSourceType().hasRank() ||
5197 !packOrUnPack.getDestType().hasRank())
5198 return op->emitError("expected both source and destination to have rank");
5199
5200 // Verify that the Operation does not have mixed tensor/buffer semantics.
5201 if (!packOrUnPack.hasPureBufferSemantics() &&
5202 !packOrUnPack.hasPureTensorSemantics())
5203 return op->emitError("mixing tensor and buffer semantics is not allowed");
5204 const unsigned numResults = packOrUnPack.getNumResults();
5205 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5206 return op->emitError("expected 1 result, got ") << numResults;
5207 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5208 return op->emitError("expected 0 results, got ") << numResults;
5209
5210 // Verify tiles. Do not allow zero tiles.
5211 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5212 if (hasZeros(mixedTiles))
5213 return op->emitError("invalid zero tile factor");
5214
5215 // Verify inner_dims_pos and outer_dims_perm.
5216 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5217 ? packOrUnPack.getSourceType()
5218 : packOrUnPack.getDestType();
5219 size_t unpackedRank = unpackedType.getRank();
5220 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5221 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5222 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5223 return op->emitError("invalid inner_dims_pos vector");
5224 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5225 return op->emitError("invalid outer_dims_perm vector");
5226 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5227 return op->emitError("outer_dims_perm must be a permutation or empty");
5228
5229 // Tiling factors must be less than or equal to the input rank for pack (or
5230 // output rank for unpack), and must match the number of `inner_dims_pos`.
5231 if (mixedTiles.size() > unpackedRank) {
5232 return op->emitError("tiling factors must be less than or equal to the "
5233 "input rank for pack or output rank for unpack");
5234 }
5235 if (mixedTiles.size() != innerDimsPos.size()) {
5236 return op->emitError(
5237 "tiling factors must equal the number of dimensions to tile");
5238 }
5239
5240 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5241 ? packOrUnPack.getDestType()
5242 : packOrUnPack.getSourceType();
5243 size_t packedRank = packedType.getRank();
5244 // Require output rank to match input rank + number of blocking factors.
5245 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5246 if (expectedPackedRank != packedRank) {
5247 return op->emitError(
5248 "packed rank != (unpacked rank + num tiling factors), got ")
5249 << packedRank << " != " << expectedPackedRank;
5250 }
5251
5252 // Verify result shape is greater than the minimum expected
5253 // by the pack operation, and that the output shape
5254 // represents full tiles.
5255 SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
5256 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5257 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5258 for (auto it : llvm::enumerate(llvm::zip(
5259 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5260 int64_t dimSize = std::get<0>(it.value());
5261 if (Attribute attr =
5262 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5263 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5264 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5265 if (dimSize != staticTileSize)
5266 return op->emitError(
5267 "mismatch in inner tile sizes specified and shaped of "
5268 "tiled dimension in the packed type at index ")
5269 << it.index() << ": got " << dimSize << " != " << staticTileSize;
5270 } else if (!ShapedType::isDynamic(dimSize)) {
5271 return op->emitError("mismatch in inner tile sizes specified at index ")
5272 << it.index() << ": got static shape " << dimSize
5273 << " but dynamic tile size";
5274 }
5275 }
5276 if (failed(
5277 verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
5278 auto elementType = unpackedType.getElementType();
5279 Type expectedType, actualType;
5280 if (packOrUnPack.hasPureTensorSemantics()) {
5281 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5282 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5283 } else {
5284 expectedType = MemRefType::get(expectedPackedShape, elementType);
5285 actualType = MemRefType::get(packedType.getShape(), elementType);
5286 }
5287 return op->emitError("expected ")
5288 << expectedType << " for the packed domain value, got "
5289 << actualType;
5290 }
5291 return success();
5292}
5293
5294namespace {
5295/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5296/// various permutations to the op.
5297// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5298// these. These may or may not become true foldings / canonicalizations
5299// depending on how aggressive we want to be in automatically folding
5300// transposes.
5301struct PackOrUnPackTransposeResult {
// Tiled dimension positions (`inner_dims_pos`) after the inner permutation.
5302 SmallVector<int64_t> innerDimsPos;
// Mixed static/dynamic tile sizes, permuted the same way as `innerDimsPos`.
5303 SmallVector<OpFoldResult> innerTiles;
// Outer dimension permutation after the outer permutation has been applied.
5304 SmallVector<int64_t> outerDimsPerm;
5305};
5306} // namespace
5307
/// Computes the `inner_dims_pos`, inner tiles and `outer_dims_perm` that result
/// from applying `innerPermutation` and/or `outerPermutation` to a pack/unpack
/// op. An empty permutation leaves the corresponding fields unchanged; at
/// least one of the two must be non-empty. When the op has no
/// `outer_dims_perm`, the identity over the outer dims is used as the starting
/// point.
// NOTE(review): line 5310 (function name and the `packOrUnPackOp` parameter)
// is absent from this listing; a caller in this file invokes it as
// `commonPermutationOfPackAndUnPackOp(*this, innerPermutation,
// outerPermutation)`.
5308template <typename OpTy>
5309static PackOrUnPackTransposeResult
5311 ArrayRef<int64_t> innerPermutation,
5312 ArrayRef<int64_t> outerPermutation) {
5313 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5314 "applies to only pack or unpack operations");
5315 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5316 "some permutation must be non-empty");
5317 PackOrUnPackTransposeResult metadata;
5318 metadata.innerDimsPos =
5319 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5320 metadata.innerTiles =
5321 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
// The outer dims are those of the unpacked side: source for pack, dest for
// unpack.
5322 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5323 ? packOrUnPackOp.getSourceRank()
5324 : packOrUnPackOp.getDestRank();
5325 metadata.outerDimsPerm =
5326 packOrUnPackOp.getOuterDimsPerm().empty()
5327 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5328 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5329 if (!innerPermutation.empty()) {
5330 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5331 isPermutationVector(innerPermutation) &&
5332 "invalid inner permutation");
// Positions and tiles are permuted together so they stay paired.
5333 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5334 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5335 }
5336 if (!outerPermutation.empty()) {
5337 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5338 isPermutationVector(outerPermutation) &&
5339 "invalid outer permutation");
5340 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5341 }
5342 return metadata;
5343}
5344
5345//===----------------------------------------------------------------------===//
5346// PackOp
5347//===----------------------------------------------------------------------===//
5348
5349void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5350 if (!getResults().empty())
5351 setNameFn(getResult(), "pack");
5352}
5353
/// Parses the custom assembly form of the op, in this order:
///   source operand
///   optional `padding_value(` operand `:` type `)`
///   optional `outer_dims_perm =` integer list
///   `inner_dims_pos =` integer list
///   `inner_tiles =` dynamic index list
///   `into` dest operand
///   optional attribute dictionary
///   `:` source type `->` dest type
/// The `->` and dest type are required. A result type (equal to the dest type)
/// is added only when the source type is not a memref.
// NOTE(review): the lines numbered 5356-5357 (presumably the declarations of
// `paddingValue` and `dynamicTiles`) and 5379 / 5394 (the calls that drive the
// two integer-list lambdas below) are absent from this listing -- confirm
// against the original file.
5354ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
5355 OpAsmParser::UnresolvedOperand source, dest;
5358 SmallVector<Type> paddingValueType;
5359 SmallVector<int64_t> staticTiles;
5360 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
5361 Type sourceType, destType, resultType;
5362
5363 if (parser.parseOperand(source))
5364 return failure();
5365
5366 if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
5367 if (parser.parseLParen() ||
5368 parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
5369 parser.parseColon() || parser.parseTypeList(paddingValueType) ||
5370 parser.parseRParen())
5371 return failure();
5372 }
5373
5374 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
5375 if (parser.parseEqual())
5376 return failure();
5377
5378 SmallVector<int64_t> outerDimsPermVec;
5380 int64_t value;
5381 if (parser.parseInteger(value))
5382 return failure();
5383 outerDimsPermVec.push_back(value);
5384 return success();
5385 }))
5386 return failure();
5387 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
5388 }
5389
5390 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
5391 return failure();
5392
5393 SmallVector<int64_t> innerDimsPosVec;
5395 int64_t value;
5396 if (parser.parseInteger(value))
5397 return failure();
5398 innerDimsPosVec.push_back(value);
5399 return success();
5400 }))
5401 return failure();
5402 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
5403
5404 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
5405 return failure();
5406
5407 DenseI64ArrayAttr staticTilesAttr;
5408 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
5409 return failure();
5410 for (auto val : staticTilesAttr.asArrayRef())
5411 staticTiles.push_back(val);
5412
5413 if (parser.parseKeyword("into") || parser.parseOperand(dest))
5414 return failure();
5415
5416 if (parser.parseOptionalAttrDict(result.attributes))
5417 return failure();
5418
5419 if (parser.parseColon() || parser.parseType(sourceType))
5420 return failure();
5421
5422 bool hasArrow = succeeded(parser.parseOptionalArrow());
5423 if (hasArrow) {
5424 if (parser.parseType(destType))
5425 return failure();
5426 }
5427
5428 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5429 if (!hasArrow) {
5430 return parser.emitError(parser.getCurrentLocation(),
5431 "pack/unpack requires '->' and destination type");
5432 }
5433
// Only the non-memref form produces a result; its type is the dest type.
5434 if (!isMemRef)
5435 resultType = destType;
5436
// Operands are resolved in the order source, dest, padding value, dynamic
// tiles, matching the segment sizes recorded below.
5437 if (parser.resolveOperand(source, sourceType, result.operands) ||
5438 parser.resolveOperand(dest, destType, result.operands))
5439 return failure();
5440
5441 if (!paddingValue.empty() &&
5442 parser.resolveOperands(paddingValue, paddingValueType[0],
5443 result.operands))
5444 return failure();
5445
5446 if (!dynamicTiles.empty() &&
5447 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
5448 result.operands))
5449 return failure();
5450
5451 result.addAttribute("static_inner_tiles",
5452 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
5453 result.addAttribute("inner_dims_pos", innerDimsPos);
5454 if (outerDimsPerm)
5455 result.addAttribute("outer_dims_perm", outerDimsPerm);
5456
5457 SmallVector<int32_t> segmentSizes = {
5458 1, 1, static_cast<int32_t>(paddingValue.size()),
5459 static_cast<int32_t>(dynamicTiles.size())};
5460 result.addAttribute("operandSegmentSizes",
5461 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
5462
5463 if (!isMemRef)
5464 result.addTypes(resultType);
5465
5466 return success();
5467}
5468
5469void PackOp::print(OpAsmPrinter &p) {
5470 p << " " << getSource();
5471
5472 if (getPaddingValue()) {
5473 p << " padding_value(" << getPaddingValue() << " : "
5474 << getPaddingValue().getType() << ")";
5475 }
5476
5477 if (!getOuterDimsPerm().empty()) {
5478 p << " outer_dims_perm = [";
5479 llvm::interleaveComma(getOuterDimsPerm(), p);
5480 p << "]";
5481 }
5482
5483 p << " inner_dims_pos = [";
5484 llvm::interleaveComma(getInnerDimsPos(), p);
5485 p << "]";
5486
5487 p << " inner_tiles = ";
5488 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
5489
5490 p << " into " << getDest();
5491
5492 p.printOptionalAttrDict((*this)->getAttrs(),
5493 {"static_inner_tiles", "inner_dims_pos",
5494 "outer_dims_perm", "operandSegmentSizes"});
5495
5496 p << " : " << getSource().getType();
5497 p << " -> " << getDest().getType();
5498}
5499
5500void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5501 Value dest, ArrayRef<int64_t> innerDimsPos,
5502 ArrayRef<OpFoldResult> innerTiles,
5503 std::optional<Value> paddingValue,
5504 ArrayRef<int64_t> outerDimsPerm) {
5505 assert(innerDimsPos.size() == innerTiles.size() &&
5506 "number of tile sizes specified must match the specified number of "
5507 "original dimensions to be tiled");
5508 SmallVector<int64_t> staticTileSizes;
5509 SmallVector<Value> dynamicTileSizes;
5510 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5511 build(builder, state, dest.getType(), source, dest,
5512 paddingValue ? *paddingValue : nullptr,
5513 outerDimsPerm.empty() ? nullptr
5514 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5515 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5516 builder.getDenseI64ArrayAttr(staticTileSizes));
5517}
5518
/// Reifies the result shape by delegating to the shared pack/unpack helper
/// `reifyResultShapesImpl`.
5519LogicalResult
5520PackOp::reifyResultShapes(OpBuilder &builder,
5521 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5522 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5523}
5524
/// Returns the tiled-dimension-to-tile mapping by delegating to the shared
/// pack/unpack helper `getDimAndTileMappingImpl`.
5525DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5526 return getDimAndTileMappingImpl(*this);
5527}
5528
/// Returns the inner tiles as a mixed list of static attributes and SSA values
/// by delegating to the shared pack/unpack helper `getMixedTilesImpl`.
5529SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5530 return getMixedTilesImpl(*this);
5531}
5532
/// Returns the static tile sizes by delegating to the shared pack/unpack
/// helper `getStaticTilesImpl`.
5533SmallVector<int64_t> PackOp::getStaticTiles() {
5534 return getStaticTilesImpl(*this);
5535}
5536
5537ArrayRef<int64_t> PackOp::getAllOuterDims() {
5538 ShapedType inputType = getSourceType();
5539 int64_t inputRank = inputType.getRank();
5540 return getDestType().getShape().take_front(inputRank);
5541}
5542
5543SmallVector<int64_t> PackOp::getTiledOuterDims() {
5544 auto innerDimsPos = getInnerDimsPos();
5545 SmallVector<int64_t> outerDims(getAllOuterDims());
5546 SmallVector<int64_t> res;
5547
5548 // Recover the original order of the outer dims.
5549 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5550 invertPermutationVector(outerDimPermInv);
5551 if (!outerDimPermInv.empty())
5552 applyPermutationToVector(outerDims, outerDimPermInv);
5553
5554 // Collect the outer dims corresponding to the tilled inner dims.
5555 for (auto index : innerDimsPos)
5556 res.push_back(outerDims[index]);
5557
5558 return res;
5559}
5560
5561bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5562 ArrayRef<int64_t> innerDimsPos,
5563 ArrayRef<int64_t> outputShape,
5564 ArrayRef<int64_t> outerDimsPerm,
5565 ArrayRef<OpFoldResult> innerTiles) {
5566 SmallVector<int64_t> outputTileSizes(
5567 outputShape.take_front(inputShape.size()));
5568 if (!outerDimsPerm.empty()) {
5569 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5570 "expected output and outer_dims_perm to have same size");
5571 applyPermutationToVector(outputTileSizes,
5572 invertPermutationVector(outerDimsPerm));
5573 }
5574 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5575 if (ShapedType::isDynamic(inputShape[pos]))
5576 continue;
5577 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5578 if (!constantTile) {
5579 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5580 (inputShape[pos] % outputTileSizes[pos] != 0))
5581 return true;
5582 } else {
5583 assert(*constantTile != 0 && "static tile size can't be zero");
5584 if (inputShape[pos] % (*constantTile) != 0) {
5585 return true;
5586 }
5587 }
5588 }
5589 return false;
5590}
5591
5592bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5593 ArrayRef<int64_t> innerDimsPos,
5594 ArrayRef<int64_t> outputShape,
5595 ArrayRef<int64_t> outerDimsPerm,
5596 ArrayRef<OpFoldResult> innerTiles) {
5597 SmallVector<int64_t> outputTileSizes(
5598 outputShape.take_front(inputShape.size()));
5599 if (!outerDimsPerm.empty()) {
5600 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5601 "expected output and outer_dims_perm to have same size");
5602 applyPermutationToVector(outputTileSizes,
5603 invertPermutationVector(outerDimsPerm));
5604 }
5605 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5606 if (ShapedType::isDynamic(inputShape[pos]) ||
5607 ShapedType::isDynamic(outputTileSizes[pos]))
5608 return true;
5609 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5610 if (!constantTile)
5611 return true;
5612 assert(*constantTile != 0 && "static tile size can't be zero");
5613 if (inputShape[pos] % (*constantTile) != 0)
5614 return true;
5615 }
5616 return false;
5617}
5618
/// Verifies PackOp-specific invariants: a padding value, when present, must
/// have the source element type; without a padding value, `requirePaddingValue`
/// must not report a partial tile.
// NOTE(review): line 5620 is absent from this listing; it is presumably the
// condition guarding the first `return failure()` (the shared pack/unpack
// verifier) -- confirm against the original file.
5619LogicalResult PackOp::verify() {
5621 return failure();
5622
5623 // Verify padding value, and bail out if the tile does not divide the
5624 // dimension fully. In the case of dynamic tile factors or dimensions, having
5625 // a partial tile is undefined behavior.
5626 auto paddingValue = getPaddingValue();
5627 if (paddingValue &&
5628 paddingValue.getType() != getSourceType().getElementType()) {
5629 return emitOpError("expected padding_value has ")
5630 << getSourceType().getElementType()
5631 << " but got: " << paddingValue.getType();
5632 }
5633
5634 if (!paddingValue &&
5635 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5636 getDestType().getShape(), getOuterDimsPerm(),
5637 getMixedTiles())) {
5638 return emitOpError(
5639 "invalid tile factor or output size provided. Only full tiles are "
5640 "supported when padding_value is not set");
5641 }
5642 return success();
5643}
5644
5645/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5646/// Value's to kDynamic, even if they are arith.constant values. Attribute
/// entries map to their constant integer value, or to kDynamic when no
/// constant can be extracted.
// NOTE(review): the lines numbered 5648-5649 (function name, the `ofrs`
// parameter and the declaration of `result`) are absent from this listing;
// a caller in this file invokes it as `asShapeWithAnyValueAsDynamic(...)`.
5647static SmallVector<int64_t>
5650 for (auto o : ofrs) {
5651 // Have to do this first, as getConstantIntValue special-cases constants.
5652 if (llvm::dyn_cast_if_present<Value>(o))
5653 result.push_back(ShapedType::kDynamic);
5654 else
5655 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5656 }
5657 return result;
5658}
5659
5660SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5661 ArrayRef<int64_t> innerTileSizes,
5662 ArrayRef<int64_t> innerDimsPos,
5663 ArrayRef<int64_t> outerDimsPerm) {
5664 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5665 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5666 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5667 continue;
5668 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5669 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5670 continue;
5671 }
5672 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5673 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5674 }
5675
5676 // Swap tile loops if outer_dims_perm is available.
5677 if (!outerDimsPerm.empty())
5678 applyPermutationToVector(resultShape, outerDimsPerm);
5679
5680 // Append the inner tile dimensions.
5681 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5682 return resultShape;
5683}
5684
5685SmallVector<OpFoldResult> PackOp::getResultShape(
5686 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5687 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5688 ArrayRef<int64_t> outerDimsPerm) {
5689 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5690
5691 AffineExpr s0, s1;
5692 bindSymbols(builder.getContext(), s0, s1);
5693 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5694 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5695 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5696 builder, loc, ceilDivExpr,
5697 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5698 }
5699 if (!outerDimsPerm.empty())
5700 applyPermutationToVector(resultDims, outerDimsPerm);
5701 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5702
5703 SmallVector<int64_t> resultTypeShape =
5704 inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
5705 asShapeWithAnyValueAsDynamic(innerTileSizes),
5706 innerDimsPos, outerDimsPerm);
5707
5708 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5709 // result type shape says it's a dynamic dim. This is needed as callers may
5710 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5711 // dynamic dims returned by that.
5712 for (unsigned i = 0; i < resultDims.size(); ++i) {
5713 if (ShapedType::isStatic(resultTypeShape[i]))
5714 continue;
5715 resultDims[i] =
5716 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5717 }
5718
5719 return resultDims;
5720}
5721
5722RankedTensorType PackOp::inferPackedTensorType(
5723 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5724 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5725 SmallVector<int64_t> resultShape = inferPackedShape(
5726 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5727 return RankedTensorType::get(resultShape, sourceType.getElementType());
5728}
5729
5730MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5731 ArrayRef<int64_t> innerTileSizes,
5732 ArrayRef<int64_t> innerDimsPos,
5733 ArrayRef<int64_t> outerDimsPerm) {
5734 SmallVector<int64_t> resultShape = inferPackedShape(
5735 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5736 return MemRefType::get(resultShape, sourceType.getElementType());
5737}
5738
5739Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5740 ArrayRef<OpFoldResult> innerTileSizes,
5741 ArrayRef<int64_t> innerDimsPos,
5742 ArrayRef<int64_t> outerDimsPerm) {
5743 AffineExpr dim0, dim1;
5744 bindDims(b.getContext(), dim0, dim1);
5745 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5746 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5747 {v1, v2});
5748 };
5749
5750 SmallVector<OpFoldResult> mixedSizes;
5751 for (auto [index, value] : llvm::enumerate(
5752 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5753 if (ShapedType::isDynamic(value))
5754 mixedSizes.push_back(
5755 tensor::DimOp::create(b, loc, source, index).getResult());
5756 else
5757 mixedSizes.push_back(b.getIndexAttr(value));
5758 }
5759 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5760 int64_t dimPos = std::get<0>(it);
5761 OpFoldResult tileSize = std::get<1>(it);
5762 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5763 }
5764 if (!outerDimsPerm.empty())
5765 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5766
5767 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5768 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5769 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5770}
5771
5772PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5773 ArrayRef<int64_t> innerPermutation,
5774 ArrayRef<int64_t> outerPermutation) {
5775 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5776 *this, innerPermutation, outerPermutation);
5777 Value transposedDest =
5778 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5779 metadata.innerDimsPos, metadata.outerDimsPerm);
5780 return PackOp::create(b, loc, getSource(), transposedDest,
5781 metadata.innerDimsPos, metadata.innerTiles,
5782 getPaddingValue(), metadata.outerDimsPerm);
5783}
5784
/// Collects memory effects for a pack/unpack op. Ops with pure tensor
/// semantics have none. Otherwise, for memref operands only: the source
/// operand gets a read effect, and the dest operand gets a read and a write
/// effect.
// NOTE(review): the lines numbered 5786-5787 (function name and leading
// parameters) and 5800 / 5804 / 5807 (the final argument of each
// `emplace_back`) are absent from this listing; callers in this file invoke it
// as `getPackUnPackEffectsImpl(*this, effects)`.
5785template <typename OpTy>
5788 &effects) {
5789 // No memory effects for pure tensor semantics
5790 if (op.hasPureTensorSemantics())
5791 return;
5792
5793 for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
// Non-memref operands (e.g. tile sizes) carry no memory effect here.
5794 if (!llvm::isa<MemRefType>(opOperand.get().getType()))
5795 continue;
5796
5797 if (&opOperand == &op.getSourceMutable()) {
5798 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5799 /*effectOnFullRegion=*/true,
5801 } else if (&opOperand == &op.getDestMutable()) {
5802 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5803 /*effectOnFullRegion=*/true,
5805 effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
5806 /*effectOnFullRegion=*/true,
5808 }
5809 }
5810}
5811
/// Reports PackOp memory effects via the shared pack/unpack helper.
// NOTE(review): line 5813 (the type of the `effects` parameter) is absent from
// this listing.
5812void PackOp::getEffects(
5814 &effects) {
5815 getPackUnPackEffectsImpl(*this, effects);
5816}
5817
/// Reports UnPackOp memory effects via the shared pack/unpack helper.
// NOTE(review): line 5819 (the type of the `effects` parameter) is absent from
// this listing.
5818void UnPackOp::getEffects(
5820 &effects) {
5821 getPackUnPackEffectsImpl(*this, effects);
5822}
5823
5824/// Returns true if the tiles and the tiled dims are constant, i.e. every mixed
/// tile has a constant integer value and every trailing (inner-tile) dim of
/// the packed type is static. The packed type is the dest type for PackOp and
/// the source type for UnPackOp.
// NOTE(review): line 5826 (return type, function name and parameter list) is
// absent from this listing -- confirm against the original file.
5825template <typename OpTy>
5827 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5828 "applies to only pack or unpack operations");
5829 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5830 ? op.getDestType()
5831 : op.getSourceType();
5832 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5833 for (auto [dimDest, tile] : llvm::zip(
5834 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5835 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5836 if (!constTileSize || ShapedType::isDynamic(dimDest))
5837 return false;
5838 }
5839 return true;
5840}
5841
/// Decides speculatability from, in order: whether the op has pure tensor
/// semantics, whether a padding value is present, and whether the tiles and
/// tiled inner dimensions are constant.
// NOTE(review): the lines numbered 5844, 5846, 5851-5852 and 5854 (the return
// statements and the final condition) are absent from this listing, so the
// value returned in each case cannot be read from here -- confirm against the
// original file.
5842Speculation::Speculatability PackOp::getSpeculatability() {
5843 if (!hasPureTensorSemantics())
5845 if (getPaddingValue())
5847
5848 // The verifier already rejects operations if we can statically prove that
5849 // the tile sizes do not perfectly divide the dimension; thus, only check
5850 // for constant tiles and constant tiled inner dimensions.
5853
5855}
5856
5857// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5858// dimensions for pack and unpack.
5859static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5860 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5861 return false;
5862 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5863 return true;
5864 // Outer dims permutation is optional.
5865 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5866 // identity permutation.
5867 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5868 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5869}
5870
5871// Return true if pack and unpack have the same tiles.
5872// Same SSA values or same integer constants.
5873static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5874 auto packTiles = packOp.getMixedTiles();
5875 auto unPackTiles = unPackOp.getMixedTiles();
5876 if (packTiles.size() != unPackTiles.size())
5877 return false;
5878 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5879 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5880 return false;
5881 }
5882 return true;
5883}
5884
5885/// Returns true if the pack op does not need a padding value.
5886static bool paddingIsNotNeeded(PackOp op) {
5887 auto srcType = op.getSourceType();
5888 if (llvm::any_of(op.getInnerDimsPos(),
5889 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5890 return false;
5891 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5892 return false;
5893 return !PackOp::requirePaddingValue(
5894 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5895 op.getOuterDimsPerm(), op.getMixedTiles());
5896}
5897
5898/// Returns true if the `srcShape` or `destShape` is different from the one in
5899/// `packOp` and populates each with the inferred static shape.
5900static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5901 SmallVectorImpl<int64_t> &destShape) {
5902 bool changeNeeded = false;
5903 srcShape.assign(packOp.getSourceType().getShape().begin(),
5904 packOp.getSourceType().getShape().end());
5905 destShape.assign(packOp.getDestType().getShape().begin(),
5906 packOp.getDestType().getShape().end());
5907 llvm::SmallSetVector<int64_t, 4> innerDims;
5908 innerDims.insert_range(packOp.getInnerDimsPos());
5909 SmallVector<int64_t> inverseOuterDimsPerm;
5910 if (!packOp.getOuterDimsPerm().empty())
5911 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5912 int srcRank = packOp.getSourceRank();
5913 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5914 if (innerDims.contains(i))
5915 continue;
5916 int64_t srcPos = i;
5917 int64_t destPos = i;
5918 if (!inverseOuterDimsPerm.empty())
5919 destPos = inverseOuterDimsPerm[srcPos];
5920 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5921 ShapedType::isDynamic(destShape[destPos])) {
5922 continue;
5923 }
5924 int64_t size = srcShape[srcPos];
5925 if (ShapedType::isDynamic(size))
5926 size = destShape[destPos];
5927 srcShape[srcPos] = size;
5928 destShape[destPos] = size;
5929 changeNeeded = true;
5930 }
5931 return changeNeeded;
5932}
5933
/// Canonicalizes `packOp`. Only ops with pure tensor semantics are handled;
/// the rewrites below are tried in order and the first one that applies wins:
///  1. pack(unpack(x)) is replaced by x when the unpack source type equals the
///     pack dest type, no padding value is set, and both ops agree on
///     inner/outer attributes and tiles.
///  2. The optional padding value operand is dropped when
///     `paddingIsNotNeeded` holds.
///  3. When `inferStaticShape` refines a shape, `tensor.cast` ops are inserted
///     on the operands, the op is updated in place, and a trailing cast back
///     to the original result type is created for the existing users.
/// Returns success() if a rewrite was applied and failure() otherwise.
5934 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5935 // TODO: Support Memref PackOp. Temporarily return failure.
5936 if (!packOp.hasPureTensorSemantics())
5937 return failure();
5938
5939 // Fold a pack(unpack(x)) to x.
5940 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5941 if (unPackOp.getSourceType() == packOp.getDestType() &&
5942 !packOp.getPaddingValue() &&
5943 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5944 haveSameTiles(packOp, unPackOp)) {
5945 rewriter.replaceOp(packOp, unPackOp.getSource());
5946 return success();
5947 }
5948 }
5949
5950 // Fold optional PaddingValue operand away if padding is not needed.
5951 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5952 rewriter.startOpModification(packOp);
5953 packOp.getPaddingValueMutable().clear();
5954 rewriter.finalizeOpModification(packOp);
5955 return success();
5956 }
5957
5958 // Insert tensor.cast ops if static shape inference is available.
5959 SmallVector<int64_t> srcShape, destShape;
5960 if (inferStaticShape(packOp, srcShape, destShape)) {
5961 Location loc = packOp.getLoc();
5962 Value source = packOp.getSource();
 // Cast the source only if its shape was actually refined.
5963 if (srcShape != packOp.getSourceType().getShape()) {
5964 auto newSrcType = packOp.getSourceType().clone(srcShape);
5965 source =
5966 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5967 }
5968 Value dest = packOp.getDest();
 // Remember the original result type so users can be given a cast back to it.
5969 ShapedType originalResultType = packOp.getDestType();
5970 bool needUpdateDestType = (destShape != originalResultType.getShape());
5971 if (needUpdateDestType) {
5972 auto newDestType = packOp.getDestType().clone(destShape);
5973 dest =
5974 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5975 }
 // Update operands and result type in place; the result type follows `dest`.
5976 rewriter.modifyOpInPlace(packOp, [&] {
5977 packOp.getSourceMutable().assign(source);
5978 packOp.getDestMutable().assign(dest);
5979 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5980 });
5981 // Insert a cast if needed
5982 if (needUpdateDestType) {
5983 rewriter.setInsertionPointAfter(packOp);
5984 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5985 packOp.getResult());
 // The new cast itself must keep using the refined result.
5986 rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
5987 }
5988 return success();
5989 }
5990
5991 return failure();
5992}
5993
5994template <typename PackOrUnpackOp>
5995static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
5996 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5997 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5998 "Function meant for pack/unpack");
5999 // This is a pad if packing only adds ones and we don't transpose dimensions.
6000
6001 // Check that we are not transposing any dimensions.
6002 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
6003 int64_t numPackedDims = innerDimsPos.size();
6004 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
6005 if (orderedDims != innerDimsPos) {
6006 // Dimensions don't happen in order.
6007 return false;
6008 }
6009
6010 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
6011 int64_t packedRank = packedTensorType.getRank();
6012 // At this point we know that we are taking numPackedDims outer
6013 // dimensions and pushing them all the way as the inner most dimensions.
6014 // What's left on the outer most dimensions is, in this order:
6015 // - the factor of the packed dimensions, then
6016 // - the untouched dimensions
6017 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
6018 // if all the dimensions that bubble outerward are ones.
6019 // Therefore check that all the dimensions but the numPackedDims inner most
6020 // ones are ones.
6021 return llvm::all_of(
6022 llvm::seq<int64_t>(0, packedRank - numPackedDims),
6023 [&packedShape](int64_t i) { return packedShape[i] == 1; });
6024}
6025
6026bool PackOp::isLikePad() {
6027 auto packedTensorType =
6028 llvm::cast<ShapedType>((*this)->getResultTypes().front());
6029 return isLikePadUnPad(*this, packedTensorType);
6030}
6031
6032::mlir::LogicalResult
6033PackOp::fold(FoldAdaptor adaptor,
6035 if (!hasPureTensorSemantics())
6036 return failure();
6037 std::optional<Attribute> paddingValue;
6038 if (auto pad = adaptor.getPaddingValue())
6039 paddingValue = pad;
6040 if (OpFoldResult reshapedSource = reshapeConstantSource(
6041 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6042 cast<TensorType>(getDestType()), paddingValue)) {
6043 results.push_back(reshapedSource);
6044 return success();
6045 }
6046 return failure();
6047}
6048
6049/// Folds a tensor.cast op into a consuming PackOp op if the
6050/// `tensor.cast` has source that is more static than the consuming op.
6051///
6052/// Example:
6053/// ```mlir
6054/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
6055/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
6056/// ```
6057///
6058/// folds into:
6059///
6060/// ```mlir
6061/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
6062/// ```
6065
6066 LogicalResult matchAndRewrite(PackOp op,
6067 PatternRewriter &rewriter) const override {
6068 // TODO: Support Memref PackOp. Temporarily return failure.
6069 if (!op.hasPureTensorSemantics())
6070 return failure();
6071
6073 return failure();
6074
6075 SmallVector<Type> newResultTypes(op->getResultTypes());
6076 SmallVector<Value> newOperands =
6078
6079 // Get the updated mixed-tile-sizes attribute.
6080 SmallVector<OpFoldResult> newMixedTileSizes =
6081 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
6082 if (llvm::any_of(newMixedTileSizes, isZeroInteger))
6083 return failure();
6084
6085 // Clone op.
6086 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6087 // this point. However, in practice, we use them for things that we'd like
6088 // to preserve. Implement a better abstraction.
6089 PackOp newOp =
6090 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6091 op.getInnerDimsPos(), newMixedTileSizes,
6092 op.getPaddingValue(), op.getOuterDimsPerm());
6093 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6094
6095 // Replace op.
6096 Value oldResult = op.getResult();
6097 Value newResult = newOp.getResult();
6099 (newResult.getType() != oldResult.getType())
6100 ? tensor::CastOp::create(rewriter, op->getLoc(),
6101 oldResult.getType(), newResult)
6102 : newResult;
6103
6104 rewriter.replaceOp(op, {replacement});
6105
6106 return success();
6107 }
6108};
6109
6110//===----------------------------------------------------------------------===//
6111// UnPackOp
6112//===----------------------------------------------------------------------===//
6113
6114void UnPackOp::getAsmResultNames(
6115 function_ref<void(Value, StringRef)> setNameFn) {
6116 if (!getResults().empty())
6117 setNameFn(getResult(), "unpack");
6118}
6119
6120// Custom parser for UnPackOp that handles the memref/tensor case distinction
6121ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6122 OpAsmParser::UnresolvedOperand source, dest;
6124 SmallVector<int64_t> staticTiles;
6125 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6126 Type sourceType, destType, resultType;
6127
6128 if (parser.parseOperand(source))
6129 return failure();
6130
6131 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6132 if (parser.parseEqual())
6133 return failure();
6134
6135 SmallVector<int64_t> outerDimsPermVec;
6137 int64_t value;
6138 if (parser.parseInteger(value))
6139 return failure();
6140 outerDimsPermVec.push_back(value);
6141 return success();
6142 }))
6143 return failure();
6144 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6145 }
6146
6147 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6148 return failure();
6149
6150 SmallVector<int64_t> innerDimsPosVec;
6152 int64_t value;
6153 if (parser.parseInteger(value))
6154 return failure();
6155 innerDimsPosVec.push_back(value);
6156 return success();
6157 }))
6158 return failure();
6159 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6160
6161 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6162 return failure();
6163
6164 DenseI64ArrayAttr staticTilesAttr;
6165 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6166 return failure();
6167 for (auto val : staticTilesAttr.asArrayRef())
6168 staticTiles.push_back(val);
6169
6170 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6171 return failure();
6172
6173 if (parser.parseOptionalAttrDict(result.attributes))
6174 return failure();
6175
6176 if (parser.parseColon() || parser.parseType(sourceType))
6177 return failure();
6178
6179 bool hasArrow = succeeded(parser.parseOptionalArrow());
6180 if (hasArrow) {
6181 if (parser.parseType(destType))
6182 return failure();
6183 }
6184
6185 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6186 if (!hasArrow) {
6187 return parser.emitError(parser.getCurrentLocation(),
6188 "pack/unpack requires '->' and destination type");
6189 }
6190
6191 if (!isMemRef)
6192 resultType = destType;
6193
6194 if (parser.resolveOperand(source, sourceType, result.operands) ||
6195 parser.resolveOperand(dest, destType, result.operands))
6196 return failure();
6197
6198 if (!dynamicTiles.empty() &&
6199 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6200 result.operands))
6201 return failure();
6202
6203 result.addAttribute("static_inner_tiles",
6204 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6205 result.addAttribute("inner_dims_pos", innerDimsPos);
6206 if (outerDimsPerm)
6207 result.addAttribute("outer_dims_perm", outerDimsPerm);
6208
6209 SmallVector<int32_t> segmentSizes = {
6210 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6211 result.addAttribute("operandSegmentSizes",
6212 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6213
6214 if (!isMemRef)
6215 result.addTypes(resultType);
6216
6217 return success();
6218}
6219
6220void UnPackOp::print(OpAsmPrinter &p) {
6221 p << " " << getSource();
6222
6223 if (!getOuterDimsPerm().empty()) {
6224 p << " outer_dims_perm = [";
6225 llvm::interleaveComma(getOuterDimsPerm(), p);
6226 p << "]";
6227 }
6228
6229 p << " inner_dims_pos = [";
6230 llvm::interleaveComma(getInnerDimsPos(), p);
6231 p << "]";
6232
6233 p << " inner_tiles = ";
6234 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6235
6236 p << " into " << getDest();
6237
6238 p.printOptionalAttrDict((*this)->getAttrs(),
6239 {"static_inner_tiles", "inner_dims_pos",
6240 "outer_dims_perm", "operandSegmentSizes"});
6241
6242 p << " : " << getSource().getType();
6243 p << " -> " << getDest().getType();
6244}
6245
6246LogicalResult
6247UnPackOp::reifyResultShapes(OpBuilder &builder,
6248 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
6249 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
6250}
6251
6252DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
6253 return getDimAndTileMappingImpl(*this);
6254}
6255
6256SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6257 return getMixedTilesImpl(*this);
6258}
6259
6260SmallVector<int64_t> UnPackOp::getStaticTiles() {
6261 return getStaticTilesImpl(*this);
6262}
6263
6264ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6265 ShapedType destType = getDestType();
6266 int64_t destRank = destType.getRank();
6267 return getSourceType().getShape().take_front(destRank);
6268}
6269
6270SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6271 auto innerDimsPos = getInnerDimsPos();
6272 SmallVector<int64_t> outerDims(getAllOuterDims());
6273 SmallVector<int64_t> res;
6274
6275 // Recover the original order of the outer dims.
6276 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6277 invertPermutationVector(outerDimPermInv);
6278 if (!outerDimPermInv.empty())
6279 applyPermutationToVector(outerDims, outerDimPermInv);
6280
6281 // Collect the outer dims corresponding to the tilled inner dims.
6282 for (auto index : innerDimsPos)
6283 res.push_back(outerDims[index]);
6284
6285 return res;
6286}
6287
6288LogicalResult UnPackOp::verify() {
6289 return commonVerifierPackAndUnPackOp(*this);
6290}
6291
6292Speculation::Speculatability UnPackOp::getSpeculatability() {
6293 if (!hasPureTensorSemantics())
6295 // See PackOp::getSpeculatability.
6298
6300}
6301
6302void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6303 Value dest, ArrayRef<int64_t> innerDimsPos,
6304 ArrayRef<OpFoldResult> innerTiles,
6305 ArrayRef<int64_t> outerDimsPerm) {
6306 assert(innerDimsPos.size() == innerTiles.size() &&
6307 "number of tile sizes specified must match the specified number of "
6308 "original dimensions to be tiled");
6309 SmallVector<int64_t> staticTileSizes;
6310 SmallVector<Value> dynamicTileSizes;
6311 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
6312 build(builder, state, dest.getType(), source, dest,
6313 outerDimsPerm.empty() ? nullptr
6314 : builder.getDenseI64ArrayAttr(outerDimsPerm),
6315 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
6316 builder.getDenseI64ArrayAttr(staticTileSizes));
6317}
6318
6319Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6320 Value source,
6321 ArrayRef<OpFoldResult> innerTileSizes,
6322 ArrayRef<int64_t> innerDimsPos,
6323 ArrayRef<int64_t> outerDimsPerm) {
6324 AffineExpr sym0, sym1;
6325 bindSymbols(b.getContext(), sym0, sym1);
6326 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6327 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6328 };
6329
6330 SmallVector<OpFoldResult> mixedSizes;
6331 auto srcType = llvm::cast<RankedTensorType>(source.getType());
6332 for (auto i :
6333 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6334 if (srcType.isDynamicDim(i))
6335 mixedSizes.push_back(
6336 tensor::DimOp::create(b, loc, source, i).getResult());
6337 else
6338 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6339 }
6340 if (!outerDimsPerm.empty()) {
6342 mixedSizes, invertPermutationVector(outerDimsPerm));
6343 }
6344
6345 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6346 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6347
6348 auto elemType = srcType.getElementType();
6349 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6350}
6351
6352UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
6353 Value transposedSource,
6354 ArrayRef<int64_t> innerPermutation,
6355 ArrayRef<int64_t> outerPermutation) {
6356 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
6357 *this, innerPermutation, outerPermutation);
6358 return UnPackOp::create(b, loc, transposedSource, getDest(),
6359 metadata.innerDimsPos, metadata.innerTiles,
6360 metadata.outerDimsPerm);
6361}
6362
6363/// Returns true if the `srcShape` or `destShape` is different from the one in
6364/// `op` and populates each with the inferred static shape.
6365static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
6366 SmallVectorImpl<int64_t> &destShape) {
6367 bool changeNeeded = false;
6368 srcShape.assign(op.getSourceType().getShape().begin(),
6369 op.getSourceType().getShape().end());
6370 destShape.assign(op.getDestType().getShape().begin(),
6371 op.getDestType().getShape().end());
6372 llvm::SmallSetVector<int64_t, 4> innerDims;
6373 innerDims.insert_range(op.getInnerDimsPos());
6374 SmallVector<int64_t> inverseOuterDimsPerm;
6375 if (!op.getOuterDimsPerm().empty())
6376 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
6377 int destRank = op.getDestRank();
6378 for (auto i : llvm::seq<int64_t>(0, destRank)) {
6379 if (innerDims.contains(i))
6380 continue;
6381 int64_t srcPos = i;
6382 int64_t destPos = i;
6383 if (!inverseOuterDimsPerm.empty())
6384 srcPos = inverseOuterDimsPerm[destPos];
6385 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6386 ShapedType::isDynamic(destShape[destPos])) {
6387 continue;
6388 }
6389 int64_t size = srcShape[srcPos];
6390 if (ShapedType::isDynamic(size))
6391 size = destShape[destPos];
6392 srcShape[srcPos] = size;
6393 destShape[destPos] = size;
6394 changeNeeded = true;
6395 }
6396 return changeNeeded;
6397}
6398
/// Canonicalizes `unPackOp`. Only ops with pure tensor semantics are handled;
/// the rewrites below are tried in order:
///  1. unpack(pack(x)) is replaced by x when the types, attributes and tiles
///     match and the pack has no padding value. If the source is defined by a
///     PackOp that does not qualify, the function returns failure() right away
///     and the later rewrites are not attempted.
///  2. If the destination is produced by a DestinationStyleOpInterface op, the
///     destination is replaced by that op's corresponding init operand.
///  3. If the only user is a tensor.extract_slice accepted by
///     `canFoldSliceOp`, the slice is applied to the destination instead and
///     the slice op is replaced by the unpack.
///  4. When `inferStaticShape` refines a shape, a new UnPackOp is created on
///     `tensor.cast`-ed operands and its result is cast back to the original
///     result type.
/// Returns success() if a rewrite was applied and failure() otherwise.
6399 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6400 PatternRewriter &rewriter) {
6401 // TODO: Support Memref UnPackOp. Temporarily return failure.
6402 if (!unPackOp.hasPureTensorSemantics())
6403 return failure();
6404
6405 /// unpack(pack(x)) -> x
6406 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6407 if (packOp.getSourceType() != unPackOp.getDestType())
6408 return failure();
6409 if (packOp.getPaddingValue() ||
6410 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
6411 !haveSameTiles(packOp, unPackOp))
6412 return failure();
6413 rewriter.replaceOp(unPackOp, packOp.getSource());
6414 return success();
6415 }
6416 /// unpack(destinationStyleOp(x)) -> unpack(x)
6417 if (auto dstStyleOp =
6418 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
 // The dest is a result of `dstStyleOp`; use the init with the same index.
6419 auto destValue = cast<OpResult>(unPackOp.getDest());
6420 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6421 rewriter.modifyOpInPlace(unPackOp,
6422 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6423 return success();
6424 }
6425 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
6426 if (unPackOp->hasOneUse()) {
6427 auto extractSliceUser =
6428 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6429 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
 // Build the sliced destination before the unpack, restoring the insertion
 // point when the guard goes out of scope.
6430 OpBuilder::InsertionGuard g(rewriter);
6431 rewriter.setInsertionPoint(unPackOp);
6432 auto newDest = tensor::ExtractSliceOp::create(
6433 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6434 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6435 extractSliceUser.getMixedStrides());
6436 rewriter.modifyOpInPlace(unPackOp, [&]() {
6437 unPackOp.setDpsInitOperand(0, newDest);
6438 unPackOp.getResult().setType(newDest.getType());
6439 });
6440 rewriter.replaceOp(extractSliceUser, unPackOp);
6441 return success();
6442 }
6443 }
6444
6445 // Insert tensor.cast ops if static shape inference is available.
6446 SmallVector<int64_t> srcShape, destShape;
6447 if (inferStaticShape(unPackOp, srcShape, destShape)) {
6448 Location loc = unPackOp.getLoc();
6449 Value source = unPackOp.getSource();
 // Cast each operand only if its shape was actually refined.
6450 if (srcShape != unPackOp.getSourceType().getShape()) {
6451 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6452 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6453 unPackOp.getSource());
6454 }
6455 Value dest = unPackOp.getDest();
6456 if (destShape != unPackOp.getDestType().getShape()) {
6457 auto newDestType = unPackOp.getDestType().clone(destShape);
6458 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6459 unPackOp.getDest());
6460 }
6461 UnPackOp newOp = UnPackOp::create(
6462 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6463 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
 // Cast back so existing users still see the original result type.
6464 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6465 unPackOp, unPackOp.getResult().getType(), newOp.getResult());
6466 return success();
6467 }
6468
6469 return failure();
6470}
6471
6472bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6473 // Rank-reduced folding is not supported.
6474 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6475 return false;
6476 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6477 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6478 return false;
6479 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6480 SmallVector<int64_t> outerShapeWithoutTranspose =
6482 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
6483 for (auto [pos, tileSize] :
6484 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6485 areOuterDimsTiled[pos] = true;
6486 if (unpackedTypeAfterFold.isDynamicDim(pos))
6487 return false;
6488 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6489 return false;
6490 if (ShapedType::isDynamic(tileSize))
6491 return false;
6492 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6493 unpackedTypeAfterFold.getDimSize(pos);
6494 if (paddingSize >= tileSize)
6495 return false;
6496 }
6497 // extract_slice must not affect dimensions that are not being unpacked
6498 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6499 if (areOuterDimsTiled[pos])
6500 continue;
6501 int64_t dim = outerShapeWithoutTranspose[pos];
6502 if (ShapedType::isDynamic(dim))
6503 return false;
6504 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6505 return false;
6506 }
6507 return true;
6508}
6509
6510bool UnPackOp::isLikeUnPad() {
6511 ShapedType packedTensorType = getSourceType();
6512 return isLikePadUnPad(*this, packedTensorType);
6513}
6514
6515::mlir::LogicalResult
6516UnPackOp::fold(FoldAdaptor adaptor,
6517 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6518 // TODO: Support Memref UnPackOp. Temporarily return failure.
6519 if (!hasPureTensorSemantics())
6520 return failure();
6521
6522 if (OpFoldResult reshapedSource = reshapeConstantSource(
6523 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6524 cast<TensorType>(getResult().getType()))) {
6525 results.push_back(reshapedSource);
6526 return success();
6527 }
6528 return failure();
6529}
6530
6531/// Folds a tensor.cast op into a consuming UnPackOp op if the
6532/// `tensor.cast` has source that is more static than the consuming op.
6533///
6534/// Example:
6535/// ```mlir
6536/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6537/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6538/// ```
6539///
6540/// folds into:
6541///
6542/// ```mlir
6543/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6544/// ```
6545struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6546 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6547
6548 LogicalResult matchAndRewrite(UnPackOp op,
6549 PatternRewriter &rewriter) const override {
6550 // TODO: Support Memref UnPackOp. Temporarily return failure.
6551 if (!op.hasPureTensorSemantics())
6552 return failure();
6553
6555 return failure();
6556
6557 SmallVector<Type> newResultTypes(op->getResultTypes());
6558 SmallVector<Value> newOperands =
6560 Value sourceTensor = newOperands[0];
6561
6562 // Get the updated mixed-tile-sizes attribute.
6564 rewriter, sourceTensor.getType(), op.getMixedTiles());
6565
6566 // Clone op.
6567 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6568 // this point. However, in practice, we use them for things that we'd like
6569 // to preserve. Implement a better abstraction.
6570 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6571 newOperands[1], op.getInnerDimsPos(),
6572 newMixedTileSizes, op.getOuterDimsPerm());
6573 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6574
6575 // Replace op.
6576 Value oldResult = op.getResult();
6577 Value newResult = newOp.getResult();
6579 (newResult.getType() != oldResult.getType())
6580 ? tensor::CastOp::create(rewriter, op->getLoc(),
6581 oldResult.getType(), newResult)
6582 : newResult;
6583
6584 rewriter.replaceOp(op, {replacement});
6585
6586 return success();
6587 }
6588};
6589
6590//===----------------------------------------------------------------------===//
6591// BatchReduceMatmulOp
6592//===----------------------------------------------------------------------===//
6593SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6595 utils::IteratorType::reduction, utils::IteratorType::parallel,
6596 utils::IteratorType::parallel, utils::IteratorType::reduction};
6597}
6598
6599SmallVector<AffineMap>
6600BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6601 AffineExpr d0, d1, d2, d3;
6602 SmallVector<AffineMap> indexingMaps;
6603 bindDims(context, d0, d1, d2, d3);
6604 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6605 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6606 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6607 return indexingMaps;
6608}
6609
6610bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6611 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6612 if (!maps)
6613 return false;
6614 if (maps.size() != 3)
6615 return false;
6616 auto positions = getAffineResultPositions(maps);
6617 if (failed(positions))
6618 return false;
6619 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6620 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6621 (*positions)[2] == SmallVector<int64_t>{1, 2};
6622}
6623unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6624
6625std::string BatchReduceMatmulOp::getLibraryCallName() {
6626 return generateLibraryCallName(getOperation());
6627}
6628
6629/// Check if the op has broadcast and/or transpose semantic. Returns true if
6630/// the user defined indexing maps are not equal to default map.
6631bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6632 SmallVector<AffineMap, 3> defaultMaps =
6633 getDefaultIndexingMaps(this->getContext());
6634 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6635 return defaultMaps != explicitMaps;
6636}
6637
6638/// Returns true if the given bcastMap map is a valid broadcast map. A valid
6639/// broadcast map must include K dimension.
6640/// TODO: Strict inclusion of K dimension in the broadcast map is not
6641/// necessary for both input matrices simultaneously. We can relax this
6642/// condition to have K dimension for one input matrix map and infer the K
6643/// dimension for other input matrix map from the one already having K
6644/// dimension.
6645bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6646 bool isLHS) {
6647 assert(bcastMap.getNumResults() < 3 &&
6648 "Expected less than 3 result dim expr.");
6649 bool isValid = false;
6650 enum Indices { batchPos, mPos, nPos, kPos };
6651 if (bcastMap.getNumResults() == 1) {
6652 AffineExpr expr = bcastMap.getResult(0);
6653 isValid = expr.isFunctionOfDim(kPos);
6654 } else if (bcastMap.getNumResults() == 2) {
6655 AffineExpr expr0 = bcastMap.getResult(0);
6656 AffineExpr expr1 = bcastMap.getResult(1);
6657 isValid =
6658 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6659 expr0.isFunctionOfDim(mPos)) &&
6660 expr1.isFunctionOfDim(kPos))
6661 : ((expr0.isFunctionOfDim(batchPos) &&
6662 expr1.isFunctionOfDim(kPos)) ||
6663 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6664 }
6665 return isValid;
6666}
6667
6668void BatchReduceMatmulOp::regionBuilder(
6669 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6670 function_ref<InFlightDiagnostic()> emitError) {
6671 if (emitError && block.getNumArguments() != 3) {
6672 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6673 << block.getNumArguments();
6674 return;
6675 }
6676 assert(block.getNumArguments() == 3 &&
6677 "BatchReduceMatmulOp regionBuilder expects 3 args");
6678 RegionBuilderHelper helper(b, block);
6679 SmallVector<Value> yields;
6680
6681 auto toType = block.getArgument(2).getType();
6682 Value castValA =
6683 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6684 Value castValB =
6685 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6686 Value mulVal =
6687 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
6688 if (!castValA || !castValB || !mulVal)
6689 return;
6690 Value addVal =
6691 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6692 if (!addVal)
6693 return;
6694 yields.push_back(addVal);
6695 helper.yieldOutputs(yields);
6696}
6697
6698ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6699 OperationState &result) {
6700 SmallVector<Attribute, 3> indexingMapsAttr;
6701 Attribute mapAttr;
6702 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6703 if (parser.parseEqual())
6704 return failure();
6705 if (parser.parseLSquare())
6706 return failure();
6707
6708 do {
6709 if (parser.parseAttribute(mapAttr))
6710 return failure();
6711 if (!isa<AffineMapAttr>(mapAttr)) {
6712 return parser.emitError(parser.getCurrentLocation(),
6713 "expected affine map attribute");
6714 }
6715 indexingMapsAttr.push_back(mapAttr);
6716
6717 if (parser.parseOptionalComma())
6718 break;
6719 } while (true);
6720
6721 if (parser.parseRSquare())
6722 return failure();
6723 }
6724 // Initialize indexingMaps, if not supplied explicitly.
6725 if (indexingMapsAttr.empty()) {
6726 indexingMapsAttr = llvm::map_to_vector(
6727 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6728 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6729 }
6730 result.addAttribute("indexing_maps",
6731 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6732 return ::parseNamedStructuredOp(parser, result,
6733 BatchReduceMatmulOp::getNumRegionArgs(),
6734 BatchReduceMatmulOp::getRegionBuilder());
6735}
6736
6737void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6738 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6739 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6740 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6741
6742 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6743 p << " indexing_maps = [";
6744 llvm::interleaveComma(getIndexingMaps(), p,
6745 [&](Attribute attr) { p.printAttribute(attr); });
6746 p << "]";
6747 }
6748
6749 SmallVector<StringRef, 3> elidedAttrs = {
6750 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6751 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6752 elidedAttrs);
6753}
6754
6755/// Verify the user defined indexing maps.
6756LogicalResult BatchReduceMatmulOp::verify() {
6757 // Verification of pure batch_reduce_matmul is handled by
6758 // verifyStructuredOpInterface().
6759 if (!hasUserDefinedMaps())
6760 return success();
6761
6762 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6764 return failure();
6765 }
6766 return success();
6767}
6768LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6769 SmallVectorImpl<OpFoldResult> &) {
6770 return memref::foldMemRefCast(*this);
6771}
6772void BatchReduceMatmulOp::getEffects(
6773 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6774 &effects) {
6775 if (hasPureTensorSemantics())
6776 return;
6777 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6778}
6779
6780Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6781 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6782}
6783
6784} // namespace linalg
6785} // namespace mlir
6786
6787//===----------------------------------------------------------------------===//
6788// LinalgDialect
6789//===----------------------------------------------------------------------===//
6790
6791void LinalgDialect::getCanonicalizationPatterns(
6792 RewritePatternSet &results) const {
6793 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6794 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6795}
6796
6797Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6798 Attribute value, Type type,
6799 Location loc) {
6800 return arith::ConstantOp::materialize(builder, value, type, loc);
6801}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static std::optional< TypedAttr > getScalarConstantAttrFromDenseSplat(Value input)
Definition LinalgOps.cpp:95
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:60
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:392
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:369
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:462
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
result_iterator result_begin()
Definition Operation.h:438
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
unsigned getNumOperands()
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_type_range getOperandTypes()
Definition Operation.h:422
result_iterator result_end()
Definition Operation.h:439
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, ArrayRef< OpFoldResult > mixedTiles)
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:241
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Rewrite a broadcast of a dense splat constant into a dense splat constant of the broadcast output sha...
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Rewrite a transpose of a dense splat constant into a dense splat constant of the transposed output sh...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override