MLIR 23.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
33#include "mlir/IR/Matchers.h"
40
41#include "llvm/ADT/DenseMap.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SetOperations.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/Support/InterleavedRange.h"
50#include "llvm/Support/LogicalResult.h"
51#include "llvm/Support/MathExtras.h"
52#include "llvm/Support/raw_ostream.h"
53#include <cassert>
54#include <optional>
55
56using namespace mlir;
57using namespace mlir::linalg;
58
59/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
61 int64_t dim) {
62 auto type = cast<ShapedType>(v.getType());
63 if (!type.isDynamicDim(dim))
64 return builder.getIndexAttr(type.getDimSize(dim));
65
66 return getAsOpFoldResult(
68 .Case([&](RankedTensorType t) -> Value {
69 return tensor::DimOp::create(builder, loc, v, dim);
70 })
71 .Case([&](MemRefType t) -> Value {
72 return memref::DimOp::create(builder, loc, v, dim);
73 }));
74}
75
76/// Returns a memref.subview or a tensor.extract_slice based on the type of the
77/// `source`.
81 ArrayRef<OpFoldResult> strides) {
83 .Case([&](RankedTensorType t) -> Operation * {
84 return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
85 strides);
86 })
87 .Case([&](MemRefType type) -> Operation * {
88 return memref::SubViewOp::create(b, loc, source, offsets, sizes,
89 strides);
90 })
91 .Default([&](Type t) -> Operation * { return nullptr; });
92}
93
94static std::optional<TypedAttr>
96 DenseElementsAttr splatAttr;
98 if (!splatAttr || !splatAttr.isSplat())
99 return std::nullopt;
100
101 return splatAttr.getSplatValue<TypedAttr>();
102}
103
104//===----------------------------------------------------------------------===//
105// Helper functions
106//===----------------------------------------------------------------------===//
107
109 int64_t dim) {
110 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
111 return b.createOrFold<memref::DimOp>(loc, source, dim);
112 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
113 return b.createOrFold<tensor::DimOp>(loc, source, dim);
114 llvm_unreachable("Expected MemRefType or TensorType");
115}
116
118 int64_t dim) {
119 auto shapedType = llvm::cast<ShapedType>(source.getType());
120 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
121 return createOrFoldDimOp(b, loc, source, dim);
122 return b.getIndexAttr(shapedType.getDimSize(dim));
123}
124
125//===----------------------------------------------------------------------===//
126// Support for named Linalg ops defined in ods-gen.
127//===----------------------------------------------------------------------===//
128
132
133/// Fills the region of a structured operation using the provided
134/// `regionBuilder`. The method is used by both named structured ops created by
135/// ods-gen and by manually defined C++ ops. It is called by both builders and
136/// parsers and creates a block with arguments corresponding to the elemental
137/// types of `inputTypes` and `outputTypes`.
138static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
139 TypeRange inputTypes, TypeRange outputTypes,
142 RegionBuilderFn regionBuilder) {
143 SmallVector<Type, 8> argTypes;
145 for (auto containers : {inputTypes, outputTypes}) {
146 for (auto t : containers) {
147 argTypes.push_back(
148 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
149
150 // TODO: Pass in a proper location here.
151 argLocs.push_back(opBuilder.getUnknownLoc());
152 }
153 }
154
155 // RAII.
156 OpBuilder::InsertionGuard guard(opBuilder);
157 Block *body =
158 opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
159
160 opBuilder.setInsertionPointToStart(body);
161 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
162 regionBuilder(b, *body, attrs, emitError);
163
164 // indexing_maps is an auto-generated method.
165
166 // iterator_types is an auto-generated method.
167}
168
169/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
170/// The result types are derived automatically if `resultTensorTypes` is none.
171/// The body of the operation is filled using `regionBuilder`. All ods-gen
172/// created structured operations use the method to implement their builders.
174 std::optional<TypeRange> resultTensorTypes,
175 ValueRange inputs, ValueRange outputs,
176 ArrayRef<NamedAttribute> attributes,
177 RegionBuilderFn regionBuilder) {
178 // Derive the result types if needed.
179 SmallVector<Type> derivedResultTypes =
180 resultTensorTypes.value_or(TypeRange());
181 if (!resultTensorTypes)
182 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
183 llvm::IsaPred<RankedTensorType>);
184
185 state.addOperands(inputs);
186 state.addOperands(outputs);
187 state.addTypes(derivedResultTypes);
188
189 state.addAttributes(attributes);
190 state.addAttribute(
191 "operandSegmentSizes",
192 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
193 static_cast<int32_t>(outputs.size())}));
194
195 // Create and fill the region of the structured operation.
196 Region &region = *state.addRegion();
197 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
198 state.attributes.getAttrs(), /*emitError=*/{},
199 regionBuilder);
200}
201
203 std::optional<TypeRange> resultTensorTypes,
204 ValueRange inputs, ValueRange outputs,
205 ArrayRef<NamedAttribute> attributes,
206 RegionBuilderFn regionBuilder,
207 ArrayRef<AffineMap> defaultIndexingMaps) {
208 // If indexing maps are not provided, apply the default ones.
209 if (none_of(attributes, [](NamedAttribute attr) {
210 return attr.getName() == "indexing_maps";
211 })) {
212 SmallVector<Attribute, 3> indexingMapsAttrVal;
213 indexingMapsAttrVal = llvm::map_to_vector(
214 defaultIndexingMaps,
215 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
216 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
217 }
218 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
219 attributes, regionBuilder);
220}
221
223 std::optional<TypeRange> resultTensorTypes,
224 ValueRange inputs, ValueRange outputs,
225 ArrayRef<NamedAttribute> attributes,
226 RegionBuilderFn regionBuilder,
227 ArrayRef<AffineMap> defaultIndexingMaps) {
228 // If indexing maps are not provided, apply the default ones.
229 if (none_of(attributes, [](NamedAttribute attr) {
230 return attr.getName() == "indexing_maps";
231 })) {
232 SmallVector<Attribute, 4> indexingMapsAttrVal;
233 indexingMapsAttrVal = llvm::map_to_vector(
234 defaultIndexingMaps,
235 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
236 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
237 }
238 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
239 attributes, regionBuilder);
240}
241
243 std::optional<TypeRange> resultTensorTypes,
244 ValueRange inputs, ValueRange outputs,
245 ArrayRef<NamedAttribute> attributes,
246 RegionBuilderFn regionBuilder,
247 ArrayRef<AffineMap> indexingMaps) {
248 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
249 SmallVector<Attribute, 4> indexingMapsAttrVal;
250 indexingMapsAttrVal =
251 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
252 return AffineMapAttr::get(map);
253 });
254 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
255 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
256 attributes, regionBuilder);
257}
258
259/// Common parsing used for both named structured ops created by ods-gen and by
260/// manually defined C++ ops. Does not handle regions.
261static ParseResult
263 SmallVectorImpl<Type> &inputTypes,
264 SmallVectorImpl<Type> &outputTypes,
265 bool addOperandSegmentSizes = true) {
266 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
268 outputsOperands;
269
270 if (succeeded(parser.parseOptionalLess())) {
271 if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
272 return failure();
273 }
274 attrsLoc = parser.getCurrentLocation();
275 if (parser.parseOptionalAttrDict(result.attributes))
276 return failure();
277
278 if (succeeded(parser.parseOptionalKeyword("ins"))) {
279 if (parser.parseLParen())
280 return failure();
281
282 inputsOperandsLoc = parser.getCurrentLocation();
283 if (parser.parseOperandList(inputsOperands) ||
284 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
285 return failure();
286 }
287
288 if (succeeded(parser.parseOptionalKeyword("outs"))) {
289 outputsOperandsLoc = parser.getCurrentLocation();
290 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
291 parser.parseColonTypeList(outputTypes) || parser.parseRParen())
292 return failure();
293 }
294
295 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
296 result.operands) ||
297 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
298 result.operands))
299 return failure();
300
301 if (addOperandSegmentSizes) {
302 // This is a bit complex because we're trying to be backward compatible with
303 // operation syntax that mix the inherent attributes and the discardable
304 // ones in the same dictionary. If the properties are used, we append the
305 // operandSegmentSizes there directly. Otherwise we append it to the
306 // discardable attributes dictionary where it is handled by the generic
307 // Operation::create(...) method.
308 if (result.propertiesAttr) {
309 NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
310 attrs.append("operandSegmentSizes",
312 {static_cast<int32_t>(inputsOperands.size()),
313 static_cast<int32_t>(outputsOperands.size())}));
314 result.propertiesAttr = attrs.getDictionary(parser.getContext());
315 } else {
316 result.addAttribute("operandSegmentSizes",
318 {static_cast<int32_t>(inputsOperands.size()),
319 static_cast<int32_t>(outputsOperands.size())}));
320 }
321 }
322 if (!result.propertiesAttr) {
323 std::optional<RegisteredOperationName> info =
324 result.name.getRegisteredInfo();
325 if (info) {
326 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
327 return parser.emitError(attrsLoc)
328 << "'" << result.name.getStringRef() << "' op ";
329 })))
330 return failure();
331 }
332 }
333 return success();
334}
335
337 ValueRange outputs) {
338 if (!inputs.empty())
339 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
340 if (!outputs.empty())
341 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
342}
343
344//===----------------------------------------------------------------------===//
345// Specific parsing and printing for named structured ops created by ods-gen.
346//===----------------------------------------------------------------------===//
347
349 OpAsmParser &parser, Region &region, unsigned numRegionArgs,
350 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
351 RegionBuilderFn regionBuilder, SMLoc loc) {
352 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
353 return parser.emitError(
354 parser.getCurrentLocation(),
355 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
356 "region expects {0} args, got {1}",
357 numRegionArgs, inputTypes.size() + outputTypes.size()));
358 }
359
360 OpBuilder opBuilder(parser.getContext());
361 ParseResult result = success();
363 opBuilder, region, inputTypes, outputTypes, attrs,
364 [&]() {
365 result = failure();
366 return parser.emitError(loc);
367 },
368 regionBuilder);
369 return result;
370}
371
372static ParseResult
374 SmallVectorImpl<Type> &resultTypes) {
375 if (parser.parseOptionalArrowTypeList(resultTypes))
376 return failure();
377 return success();
378}
379
380static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
382 unsigned numRegionArgs,
383 RegionBuilderFn regionBuilder) {
384 // TODO: Enable when ods-gen supports captures.
385 SmallVector<Type, 1> inputTypes, outputTypes;
386 SMLoc loc = parser.getCurrentLocation();
387 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
388 return failure();
389
390 // Parse optional attributes.
391 if (parser.parseOptionalAttrDict(result.attributes))
392 return failure();
393
394 // TODO: consider merging results parsing into region parsing.
395 // Need to wait for declarative assembly resolution to decide.
396 SmallVector<Type, 1> outputTensorsTypes;
397 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
398 return failure();
399 result.addTypes(outputTensorsTypes);
400
401 std::unique_ptr<Region> region = std::make_unique<Region>();
402 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
403 outputTypes, result.attributes.getAttrs(),
404 regionBuilder, loc))
405 return failure();
406 result.addRegion(std::move(region));
407
408 return success();
409}
410
412 TypeRange resultTypes) {
413 if (resultTypes.empty())
414 return;
415 p.printOptionalArrowTypeList(resultTypes);
416}
417
419 ValueRange inputs, ValueRange outputs,
420 ArrayRef<StringRef> elidedAttrs = {}) {
421 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
422
423 // Printing is shared with generic ops, except for the region and
424 // attributes.
425 printCommonStructuredOpParts(p, inputs, outputs);
426
427 // Results printing.
429
430 // Region is elided.
431}
432
433//===----------------------------------------------------------------------===//
434// Region builder helper.
435// TODO: Move this to a utility library.
436// The public methods on this class are referenced directly from generated code.
437// Helper build the unary, binary, and type conversion functions defined by the
438// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
439// class.
440//
441// Implementations of the math functions must be polymorphic over numeric types,
442// internally performing necessary casts. If the function application makes no
443// sense, then the only recourse is to assert and return nullptr. This can be
444// extended later if it becomes possible to fail construction of the region. The
445// invariant should be enforced at a higher level.
446//
447// TODO: These helpers are currently type polymorphic over the class of integer
448// and floating point types, but they will not internally cast within bit
449// widths of a class (mixed precision such as i8->i32) or across classes
450// (i.e. mixed float and integer). Many such combinations are ambiguous or need
451// to be handled with care and work is being considered to extend the op
452// language to make such cases explicit. In the mean-time, violating this will
453// fail verification, which is deemed acceptable.
454//===----------------------------------------------------------------------===//
455
456namespace {
457
458class RegionBuilderHelper {
459public:
460 RegionBuilderHelper(OpBuilder &builder, Block &block)
461 : builder(builder), block(block) {}
462
463 // Build the unary functions defined by OpDSL.
464 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
465 function_ref<InFlightDiagnostic()> emitError = {}) {
466 if (!isFloatingPoint(arg)) {
467 if (emitError) {
468 emitError() << "unsupported non numeric type";
469 return nullptr;
470 }
471 llvm_unreachable("unsupported non numeric type");
472 }
473 OpBuilder::InsertionGuard g(builder);
474 builder.setInsertionPointToEnd(&block);
475 switch (unaryFn) {
476 case UnaryFn::exp:
477 return math::ExpOp::create(builder, arg.getLoc(), arg);
478 case UnaryFn::log:
479 return math::LogOp::create(builder, arg.getLoc(), arg);
480 case UnaryFn::abs:
481 return math::AbsFOp::create(builder, arg.getLoc(), arg);
482 case UnaryFn::ceil:
483 return math::CeilOp::create(builder, arg.getLoc(), arg);
484 case UnaryFn::floor:
485 return math::FloorOp::create(builder, arg.getLoc(), arg);
486 case UnaryFn::negf:
487 return arith::NegFOp::create(builder, arg.getLoc(), arg);
488 case UnaryFn::reciprocal: {
489 Attribute oneAttr = builder.getOneAttr(arg.getType());
490 auto one = arith::ConstantOp::create(builder, arg.getLoc(),
491 ::cast<TypedAttr>(oneAttr));
492 return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
493 }
494 case UnaryFn::round:
495 return math::RoundOp::create(builder, arg.getLoc(), arg);
496 case UnaryFn::sqrt:
497 return math::SqrtOp::create(builder, arg.getLoc(), arg);
498 case UnaryFn::rsqrt:
499 return math::RsqrtOp::create(builder, arg.getLoc(), arg);
500 case UnaryFn::square:
501 return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
502 case UnaryFn::tanh:
503 return math::TanhOp::create(builder, arg.getLoc(), arg);
504 case UnaryFn::erf:
505 return math::ErfOp::create(builder, arg.getLoc(), arg);
506 }
507 if (emitError) {
508 emitError() << "unsupported unary function";
509 return nullptr;
510 }
511 llvm_unreachable("unsupported unary function");
512 }
513
514 // Build the binary functions defined by OpDSL.
515 // If emitError is provided, an error will be emitted if the operation is not
516 // supported and a nullptr will be returned, otherwise an assertion will be
517 // raised.
518 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
519 function_ref<InFlightDiagnostic()> emitError = {}) {
520 bool allComplex = isComplex(arg0) && isComplex(arg1);
521 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
522 bool allInteger = isInteger(arg0) && isInteger(arg1);
523 bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
524 arg1.getType().getIntOrFloatBitWidth() == 1;
525 if (!allComplex && !allFloatingPoint && !allInteger) {
526 if (emitError) {
527 emitError()
528 << "Cannot build binary Linalg operation: expects allComplex, "
529 "allFloatingPoint, or allInteger, got "
530 << arg0.getType() << " and " << arg1.getType();
531 return nullptr;
532 }
533 llvm_unreachable("unsupported non numeric type");
534 }
535 OpBuilder::InsertionGuard g(builder);
536 builder.setInsertionPointToEnd(&block);
537 switch (binaryFn) {
538 case BinaryFn::add:
539 if (allComplex)
540 return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
541 if (allFloatingPoint)
542 return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
543 if (allBool)
544 return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
545 return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
546 case BinaryFn::sub:
547 if (allComplex)
548 return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
549 if (allFloatingPoint)
550 return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
551 if (allBool) {
552 if (emitError) {
553 emitError() << "unsupported operation: sub with bools";
554 return nullptr;
555 }
556 llvm_unreachable("unsupported operation: sub with bools");
557 }
558 return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
559 case BinaryFn::mul:
560 if (allComplex)
561 return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
562 if (allFloatingPoint)
563 return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
564 if (allBool)
565 return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
566 return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
567 case BinaryFn::div:
568 if (allComplex)
569 return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
570 if (allFloatingPoint)
571 return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
572 if (allBool) {
573 if (emitError) {
574 emitError() << "unsupported operation: div with bools";
575 return nullptr;
576 }
577 llvm_unreachable("unsupported operation: div with bools");
578 }
579 return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
580 case BinaryFn::div_unsigned:
581 if (!allInteger || allBool) {
582 if (emitError) {
583 emitError() << "unsupported operation: unsigned div not on uint";
584 return nullptr;
585 }
586 llvm_unreachable("unsupported operation: unsigned div not on uint");
587 }
588 return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
589 case BinaryFn::max_signed:
590 assert(!allComplex);
591 if (allFloatingPoint)
592 return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
593 return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
594 case BinaryFn::min_signed:
595 assert(!allComplex);
596 if (allFloatingPoint)
597 return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
598 return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
599 case BinaryFn::max_unsigned:
600 assert(!allComplex);
601 if (!allInteger || allBool) {
602 if (emitError) {
603 emitError() << "unsupported operation: unsigned max not on uint";
604 return nullptr;
605 }
606 llvm_unreachable("unsupported operation: unsigned max not on uint");
607 }
608 return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
609 case BinaryFn::min_unsigned:
610 assert(!allComplex);
611 if (!allInteger || allBool) {
612 if (emitError) {
613 emitError() << "unsupported operation: unsigned min not on uint";
614 return nullptr;
615 }
616 llvm_unreachable("unsupported operation: unsigned min not on uint");
617 }
618 return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
619 case BinaryFn::powf:
620 assert(allFloatingPoint);
621 return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
622 }
623 if (emitError) {
624 emitError() << "unsupported binary function";
625 return nullptr;
626 }
627 llvm_unreachable("unsupported binary function");
628 }
629
630 // Build the ternary functions defined by OpDSL.
631 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
632 function_ref<InFlightDiagnostic()> emitError = {}) {
633 OpBuilder::InsertionGuard g(builder);
634 builder.setInsertionPointToEnd(&block);
635 switch (ternaryFn) {
636 case TernaryFn::select:
637 return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
638 }
639 if (emitError) {
640 emitError() << "unsupported ternary function";
641 return nullptr;
642 }
643 llvm_unreachable("unsupported ternary function");
644 }
645
646 // Build the type functions defined by OpDSL.
647 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
648 function_ref<InFlightDiagnostic()> emitError = {}) {
649 switch (typeFn) {
650 case TypeFn::cast_signed:
651 return cast(toType, operand, false);
652 case TypeFn::cast_unsigned:
653 return cast(toType, operand, true);
654 }
655 if (emitError) {
656 emitError() << "unsupported type conversion function";
657 return nullptr;
658 }
659 llvm_unreachable("unsupported type conversion function");
660 }
661
662 void yieldOutputs(ValueRange values) {
663 OpBuilder::InsertionGuard g(builder);
664 builder.setInsertionPointToEnd(&block);
665 Location loc = builder.getUnknownLoc();
666 YieldOp::create(builder, loc, values);
667 }
668
669 Value constant(const std::string &value) {
670 OpBuilder::InsertionGuard g(builder);
671 builder.setInsertionPointToEnd(&block);
672 Location loc = builder.getUnknownLoc();
673 Attribute valueAttr = parseAttribute(value, builder.getContext());
674 return arith::ConstantOp::create(builder, loc,
675 ::cast<TypedAttr>(valueAttr));
676 }
677
678 Value index(int64_t dim) {
679 OpBuilder::InsertionGuard g(builder);
680 builder.setInsertionPointToEnd(&block);
681 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
682 }
683
684 Type getIntegerType(unsigned width) {
685 return IntegerType::get(builder.getContext(), width);
686 }
687
688 Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
689 Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
690
691private:
692 // Generates operations to cast the given operand to a specified type.
693 // If the cast cannot be performed, a warning will be issued and the
694 // operand returned as-is (which will presumably yield a verification
695 // issue downstream).
696 Value cast(Type toType, Value operand, bool isUnsignedCast) {
697 OpBuilder::InsertionGuard g(builder);
698 builder.setInsertionPointToEnd(&block);
699 auto loc = operand.getLoc();
700 if (isa<UnknownLoc>(loc)) {
701 if (operand.getDefiningOp())
702 loc = operand.getDefiningOp()->getLoc();
703 else if (operand.getParentBlock() &&
704 operand.getParentBlock()->getParentOp())
705 loc = operand.getParentBlock()->getParentOp()->getLoc();
706 }
707 return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
708 }
709
710 bool isComplex(Value value) {
711 return llvm::isa<ComplexType>(value.getType());
712 }
713 bool isFloatingPoint(Value value) {
714 return llvm::isa<FloatType>(value.getType());
715 }
716 bool isInteger(Value value) {
717 return llvm::isa<IntegerType>(value.getType());
718 }
719
720 OpBuilder &builder;
721 Block &block;
722};
723
724} // namespace
725
726//===----------------------------------------------------------------------===//
727// CopyOp
728//===----------------------------------------------------------------------===//
729
730namespace {
731
732struct EraseSelfCopy : OpRewritePattern<CopyOp> {
733 using OpRewritePattern<CopyOp>::OpRewritePattern;
734 LogicalResult matchAndRewrite(CopyOp copyOp,
735 PatternRewriter &rewriter) const override {
736 if (copyOp.getInputs() != copyOp.getOutputs())
737 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
738 if (copyOp.hasPureBufferSemantics())
739 rewriter.eraseOp(copyOp);
740 else
741 rewriter.replaceOp(copyOp, copyOp.getInputs());
742
743 return success();
744 }
745};
746
747} // namespace
748
749void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
750 MLIRContext *context) {
751 results.add<EraseSelfCopy>(context);
752}
753
754//===----------------------------------------------------------------------===//
755// FillOp
756//===----------------------------------------------------------------------===//
757
758namespace {
759
760/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
761///
762/// For such op chains, we can create new linalg.fill ops with the result
763/// type of the tensor.expand/collapse_shape op.
764template <typename TensorReshapeOp>
765struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
766 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
767 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
768 PatternRewriter &rewriter) const override {
769 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
770 if (!oldFill)
771 return failure();
772
773 Location loc = oldFill.getLoc();
774 TensorReshapeOp newInit;
775 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
776
777 newInit = TensorReshapeOp::create(
778 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
779 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
780 reshapeOp.getStaticOutputShape());
781 } else {
782 newInit = TensorReshapeOp::create(
783 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
784 reshapeOp.getReassociation());
785 }
786 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
787 ValueRange{newInit});
788 return success();
789 }
790};
791
792/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
793/// filling value are the same.
794struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
796
797 LogicalResult matchAndRewrite(tensor::PadOp padOp,
798 PatternRewriter &rewriter) const override {
799 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
800 if (!fillOp)
801 return failure();
802
803 // We can only fold if the padding value is the same as the original
804 // filling value.
805 Value padValue = padOp.getConstantPaddingValue();
806 if (!padValue || fillOp.value() != padValue)
807 return failure();
808
809 ReifiedRankedShapedTypeDims reifiedShape;
810 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
811 return rewriter.notifyMatchFailure(
812 padOp, "failed to reify tensor.pad op result shape");
813
814 auto emptyTensor =
815 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
816 padOp.getResultType().getElementType());
817 Value replacement =
818 FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
819 ValueRange{emptyTensor})
820 .getResult(0);
821 if (replacement.getType() != padOp.getResultType()) {
822 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
823 padOp.getResultType(), replacement);
824 }
825 rewriter.replaceOp(padOp, replacement);
826 return success();
827 }
828};
829
830/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
831/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
832/// filling value are the same.
833struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
835
836 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
837 PatternRewriter &rewriter) const override {
838 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
839 if (!srcPadOp)
840 return failure();
841
842 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
843 return failure();
844
845 // Walk back the tensor.insert_slice chain and find the first destination
846 // value at the start of the chain.
847 Value firstDest = insertOp.getDest();
848 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
849 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
850 return failure();
851
852 // Make sure the range of values accessed are disjoint. Without this, we
853 // cannot fold tensor.pad away.
854 bool disjoint = false;
855 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
856 // If the dimension has dynamic offset/size, we cannot guarantee
857 // disjoint. So just skip it.
858 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
859 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
860 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
861 continue;
862
863 // Get the range start and end, inclusively for both.
864 int64_t prevStart = prevOp.getStaticOffset(i);
865 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
866 prevOp.getStaticStride(i);
867 int64_t nextStart = insertOp.getStaticOffset(i);
868 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
869 insertOp.getStaticStride(i);
870 if (prevEnd < nextStart || nextEnd < prevStart) {
871 disjoint = true;
872 break;
873 }
874 }
875
876 if (!disjoint)
877 break;
878 firstDest = prevOp.getDest();
879 }
880
881 // Check whether the first destination is a fill op. For overlapped cases,
882 // this also cannot be true.
883 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
884 if (!dstFillOp)
885 return failure();
886
887 // We can only fold if the padding value is the same as the original
888 // filling value.
889 Value padValue = srcPadOp.getConstantPaddingValue();
890 if (!padValue || dstFillOp.value() != padValue)
891 return failure();
892
893 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
894 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
895
896 Location loc = insertOp.getLoc();
897 MLIRContext *context = getContext();
898
899 AffineExpr sym0, sym1;
900 bindSymbols(context, sym0, sym1);
901 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
902
903 // Calculate the new offsets for the insert. It should be the old offsets
904 // plus low padding sizes.
905 SmallVector<OpFoldResult, 4> newOffsets;
906 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
907 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
908 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
909 }
910
911 RankedTensorType srcPadType = srcPadOp.getSourceType();
912 SmallVector<OpFoldResult, 4> newSizes;
913 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
914 if (srcPadType.isDynamicDim(i)) {
915 newSizes.push_back(
916 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
917 .getResult());
918 } else {
919 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
920 }
921 }
922
923 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
924 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
925 newSizes, insertOp.getMixedStrides());
926 return success();
927 }
928};
929
930/// Fold tensor.extract(linalg.fill(<input>)) into <input>
931struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
932public:
933 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
934
935 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
936 PatternRewriter &rewriter) const override {
937 // See if tensor input of tensor.extract op is the result of a linalg.fill
938 // op.
939 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
940 if (!fillOp)
941 return failure();
942
943 // Get scalar input operand of linalg.fill op.
944 Value extractedScalar = fillOp.getInputs()[0];
945
946 // Replace tensor.extract op with scalar value used to fill the tensor.
947 rewriter.replaceOp(extractOp, extractedScalar);
948 return success();
949 }
950};
951
952/// Folds pack(fill) into a single fill op if
953/// 1. The pack op does not have padding value, or
954/// 2. The filled value and padding value are the same.
955static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
956 linalg::PackOp packOp) {
957 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
958 if (!fillOp)
959 return failure();
960
961 if (auto paddingValue = packOp.getPaddingValue())
962 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
963 return failure();
964
965 Value packOpDest = packOp.getDest();
966 if (!packOpDest.hasOneUse())
967 return failure();
968
969 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
970 packOp.getDest());
971}
972
973/// Wrapper pattern that applies foldFillPackIntoFillOp method.
974struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
975public:
976 FoldFillWithPack(MLIRContext *context)
977 : OpRewritePattern<linalg::PackOp>(context) {}
978
979 LogicalResult matchAndRewrite(linalg::PackOp packOp,
980 PatternRewriter &rewriter) const override {
981 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
982 if (failed(fillOp))
983 return failure();
984 rewriter.replaceOp(packOp, fillOp.value().result());
985 return success();
986 }
987};
988
989/// Fold fill with copy.
990struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
991 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
992
993 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
994 PatternRewriter &rewriter) const override {
995 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
996 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
997 fillOp.getInputs(),
998 copyOp.getOutputs());
999 return success();
1000 }
1001 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
1002 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
1003 fillOp.getOutputs());
1004 return success();
1005 }
1006 return failure();
1007 }
1008};
1009
1010/// Fold fill with transpose.
1011struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1012 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1013
1014 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1015 PatternRewriter &rewriter) const override {
1016 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1017 rewriter.replaceOpWithNewOp<FillOp>(
1018 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1019 transposeOp.getDpsInitOperand(0)->get());
1020 return success();
1021 }
1022 return failure();
1023 }
1024};
1025
1026/// Fold a concat with all elements being fills of the same value
1027/// into a fill of the concat result shape.
1028struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
1030
1031 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1032 PatternRewriter &rewriter) const override {
1033 auto concatOperands = concatOp.getInputs();
1034 if (concatOperands.empty()) {
1035 return failure();
1036 }
1037
1038 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1039 if (!firstFillOp) {
1040 return failure();
1041 }
1042 // Prefetch the fill value.
1043 OpFoldResult firstFillVal =
1044 getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
1045 // Collect all the outs values for the fill operations.
1046 SmallVector<Value> allOuts;
1047 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1048
1049 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
1050 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1051 if (!fillOp) {
1052 return false;
1053 }
1054
1055 OpFoldResult fillVal =
1056 getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
1057 if (fillVal != firstFillVal)
1058 return false;
1059
1060 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1061 return true;
1062 };
1063 if (!llvm::all_of(concatOperands.drop_front(),
1064 isDefinedByCompatibleFillOp)) {
1065 return rewriter.notifyMatchFailure(
1066 concatOp, "not all operands are defined by a compatible fill op");
1067 }
1068
1069 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1070 concatOp.getDim(), allOuts);
1071 rewriter.replaceOpWithNewOp<linalg::FillOp>(
1072 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
1073 return success();
1074 }
1075};
1076
1077} // namespace
1078
1079void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1080 MLIRContext *context) {
1081 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1082 FoldFillWithPack, FoldFillWithPad,
1083 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1084 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1085 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1086}
1087
1088//===----------------------------------------------------------------------===//
1089// GenericOp
1090//===----------------------------------------------------------------------===//
1091
// NOTE(review): the line holding this helper's return type and name is not
// visible in this listing; other functions in this file call it as
// `buildGenericRegion(builder, loc, region, inputs, outputs, bodyBuild)`.
//
// Creates a new block at the end of `region` with one argument per value in
// `inputs` followed by `outputs`, then calls `bodyBuild` with that block's
// arguments so the caller can populate the body.
1093 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1094 ValueRange outputs,
1095 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1096 SmallVector<Type, 4> blockArgTypes;
1097 SmallVector<Location, 4> blockArgLocs;
// Inputs come first, then outputs; each contributes one block argument.
1098 for (ValueRange container : {inputs, outputs}) {
1099 for (Value v : container) {
1100 Type t = v.getType();
// Memref and ranked-tensor operands map to an argument of their element
// type; any other type is used unchanged.
1101 blockArgTypes.push_back(
1102 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
// The block argument reuses the location of the operand it stands for.
1103 blockArgLocs.push_back(v.getLoc());
1104 }
1105 }
1106
// The guard is scoped to this function, so the insertion-point change made by
// `createBlock` is presumably undone on return — confirm against OpBuilder.
1107 OpBuilder::InsertionGuard guard(builder);
1108 Block *bodyBlock =
1109 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1110 bodyBuild(builder, loc, bodyBlock->getArguments());
1111}
1112
1113void GenericOp::getAsmBlockArgumentNames(Region &region,
1114 OpAsmSetValueNameFn setNameFn) {
1115 for (Value v : getRegionInputArgs())
1116 setNameFn(v, "in");
1117 for (Value v : getRegionOutputArgs())
1118 setNameFn(v, "out");
1119}
1120
1121void GenericOp::build(
1122 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1123 ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1124 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1125 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1126 ArrayRef<NamedAttribute> attributes) {
1127 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1128 iteratorTypes, doc, libraryCall);
1129 result.addAttributes(attributes);
1130 if (bodyBuild)
1131 buildGenericRegion(builder, result.location, *result.regions.front(),
1132 inputs, outputs, bodyBuild);
1133}
1134
1135void GenericOp::build(
1136 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1137 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1138 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1139 StringRef libraryCall,
1140 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1141 ArrayRef<NamedAttribute> attributes) {
1142 build(builder, result, resultTensorTypes, inputs, outputs,
1143 builder.getAffineMapArrayAttr(indexingMaps),
1144 builder.getArrayAttr(llvm::map_to_vector(
1145 iteratorTypes,
1146 [&](utils::IteratorType iter) -> mlir::Attribute {
1147 return IteratorTypeAttr::get(builder.getContext(), iter);
1148 })),
1149 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1150 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1151 bodyBuild, attributes);
1152}
1153
1154void GenericOp::build(
1155 OpBuilder &builder, OperationState &result, ValueRange inputs,
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1158 StringRef libraryCall,
1159 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1160 ArrayRef<NamedAttribute> attributes) {
1161 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1162 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1163}
1164
1165void GenericOp::build(
1166 OpBuilder &builder, OperationState &result, ValueRange inputs,
1167 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1169 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1172 /*doc=*/"",
1173 /*libraryCall=*/"", bodyBuild, attributes);
1174}
1175
1176void GenericOp::build(
1177 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1178 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1179 ArrayRef<utils::IteratorType> iteratorTypes,
1180 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1181 ArrayRef<NamedAttribute> attributes) {
1182 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1183 iteratorTypes,
1184 /*doc=*/"",
1185 /*libraryCall=*/"", bodyBuild, attributes);
1186}
1187
1188void GenericOp::print(OpAsmPrinter &p) {
1189 p << " ";
1190
1191 // Print extra attributes.
1192 auto genericAttrNames = linalgTraitAttrNames();
1193
1194 llvm::StringSet<> genericAttrNamesSet;
1195 genericAttrNamesSet.insert_range(genericAttrNames);
1196 SmallVector<NamedAttribute, 8> genericAttrs;
1197 for (auto attr : (*this)->getAttrs()) {
1198 if (attr.getName() == getIteratorTypesAttrName()) {
1199 auto iteratorTypes =
1200 llvm::cast<ArrayAttr>(attr.getValue())
1201 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1202 // Convert IteratorType enums into the string representation. This is
1203 // needed, because tests still use the old format when 'iterator_types'
1204 // attribute is represented as an array of strings.
1205 // TODO: Remove this conversion once tests are fixed.
1206 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1207 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1208 return StringAttr::get(getContext(), stringifyIteratorType(t));
1209 });
1210
1211 genericAttrs.emplace_back(
1212 getIteratorTypesAttrName(),
1213 ArrayAttr::get(getContext(), iteratorTypeNames));
1214 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1215 genericAttrs.push_back(attr);
1216 }
1217 }
1218 if (!genericAttrs.empty()) {
1219 auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
1220 p << genericDictAttr;
1221 }
1222
1223 // Printing is shared with named ops, except for the region and attributes
1224 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1225
1226 genericAttrNames.push_back("operandSegmentSizes");
1227 genericAttrNamesSet.insert(genericAttrNames.back());
1228
1229 bool hasExtraAttrs = false;
1230 for (NamedAttribute n : (*this)->getAttrs()) {
1231 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1232 break;
1233 }
1234 if (hasExtraAttrs) {
1235 p << " attrs = ";
1236 p.printOptionalAttrDict((*this)->getAttrs(),
1237 /*elidedAttrs=*/genericAttrNames);
1238 }
1239
1240 // Print region.
1241 if (!getRegion().empty()) {
1242 p << ' ';
1243 p.printRegion(getRegion());
1244 }
1245
1246 // Print results.
1247 printNamedStructuredOpResults(p, getResultTensors().getTypes());
1248}
1249
1250ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1251 DictionaryAttr dictAttr;
1252 // Parse the core linalg traits that must check into a dictAttr.
1253 // The name is unimportant as we will overwrite result.attributes.
1254 // The core linalg traits must contain the information necessary to pass the
1255 // verifier.
1256 llvm::SMLoc attributeLocation = parser.getCurrentLocation();
1257 if (parser.parseAttribute(dictAttr, "_", result.attributes))
1258 return failure();
1259 result.attributes.assign(dictAttr.getValue().begin(),
1260 dictAttr.getValue().end());
1261
1262 // Convert array of string into an array of IteratorType enums. This is
1263 // needed, because tests still use the old format when 'iterator_types'
1264 // attribute is represented as an array of strings.
1265 // TODO: Remove this conversion once tests are fixed.
1266 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1267 result.attributes.get(getIteratorTypesAttrName(result.name)));
1268 if (!iteratorTypes) {
1269 return parser.emitError(attributeLocation)
1270 << "expected " << getIteratorTypesAttrName(result.name)
1271 << " array attribute";
1272 }
1273
1274 SmallVector<Attribute> iteratorTypeAttrs;
1275
1276 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1277 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1278 if (!maybeIteratorType.has_value())
1279 return parser.emitError(parser.getCurrentLocation())
1280 << "unexpected iterator_type (" << s << ")";
1281
1282 iteratorTypeAttrs.push_back(
1283 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
1284 }
1285 result.attributes.set(getIteratorTypesAttrName(result.name),
1286 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
1287
1288 // Parsing is shared with named ops, except for the region.
1289 SmallVector<Type, 1> inputTypes, outputTypes;
1290 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
1291 return failure();
1292
1293 // Optional attributes may be added.
1294 if (succeeded(parser.parseOptionalKeyword("attrs")))
1295 if (failed(parser.parseEqual()) ||
1296 failed(parser.parseOptionalAttrDict(result.attributes)))
1297 return failure();
1298
1299 std::unique_ptr<Region> region = std::make_unique<Region>();
1300 if (parser.parseRegion(*region, {}))
1301 return failure();
1302 result.addRegion(std::move(region));
1303
1304 // Generic ops may specify that a subset of its outputs are tensors. Such
1305 // outputs are specified in the result type.
1306 // TODO: may need to move output parsing before region parsing.
1307 // Need to wait for declarative assembly resolution to decide.
1308 SmallVector<Type, 1> outputTensorsTypes;
1309 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
1310 return failure();
1311 result.addTypes(outputTensorsTypes);
1312
1313 return success();
1314}
1315
// NOTE(review): the opening lines of this definition (return type, name and
// the type of `effects`) are not visible in this listing. The `getEffects`
// methods in this file call it as `getGenericEffectsImpl(effects, linalgOp)`.
//
// Appends the memory effects of `linalgOp` to `effects`:
//  * a read for every input operand of memref type;
//  * for every init operand of memref type, a write, preceded by a read when
//    the payload uses the value coming from that operand.
// Operands that are not memrefs contribute no effect.
1318 &effects,
1319 LinalgOp linalgOp) {
1320 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1321 if (!llvm::isa<MemRefType>(operand.getType()))
1322 continue;
1323 effects.emplace_back(
1324 MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1325 /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1326 }
1327
1328 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1329 if (!llvm::isa<MemRefType>(operand.get().getType()))
1330 continue;
// The init is only read when the payload consumes its incoming value.
1331 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1332 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1333 /*effectOnFullRegion=*/true,
// NOTE(review): the last argument of the call above is not visible in this
// listing; presumably the same resource as for the input reads — confirm.
1335 }
1336 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1337 /*effectOnFullRegion=*/true,
// NOTE(review): the last argument of the call above is not visible in this
// listing either — confirm against the original source.
1339 }
1340}
1341
1342void GenericOp::getEffects(
1343 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1344 &effects) {
1345 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1346}
1347
// NOTE(review): the signature of this helper and both of its return
// statements are not visible in this listing. The `getSpeculatability`
// methods in this file call it as `getGenericSpeculatabilityImpl(linalgOp)`
// and return its result as a `Speculation::Speculatability`.
1350 // Operands with value semantics are speculatable, while operands with memory
1351 // semantics are not.
1352 if (!linalgOp.hasPureTensorSemantics())
// NOTE(review): the statement guarded by the condition above is missing here.
1354 // The body of the op can still have speculation in its region.
// NOTE(review): the final return statement is missing here.
1356}
1357
1358Speculation::Speculatability GenericOp::getSpeculatability() {
1359 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1360}
1361
1362namespace {
1363
1364/// Remove linalg operations that are just copying the values from inputs to
1365/// results. In the memref case, the operation must be copying to and from the
1366/// same value. Requirements are:
1367/// 1) All iterator types are parallel
1368/// 2) The body contains just a yield operation with the yielded values being
1369/// the arguments corresponding to the operands.
///
/// What the pattern itself checks: all indexing maps are equal, the body
/// holds a single operation that is a `linalg.yield`, and every yielded value
/// is an argument of the body block. With buffer semantics the op is erased;
/// with tensor semantics every result is replaced by the operand matching the
/// yielded block argument, converted to the result type when the types
/// differ. Ops with mixed semantics are not handled.
1370template <typename OpTy>
1371struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1372 using OpRewritePattern<OpTy>::OpRewritePattern;
1373
1374 LogicalResult matchAndRewrite(OpTy linalgOp,
1375 PatternRewriter &rewriter) const override {
1376 // All indexing maps must be equal. It follows that they are permutations.
1377 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1378 return failure();
1379
1380 // Check that the body of the linalg operation is just a linalg.yield
1381 // operation.
1382 Block &body = linalgOp->getRegion(0).front();
1383 if (!llvm::hasSingleElement(body))
1384 return failure();
1385 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1386 if (!yieldOp)
1387 return failure();
1388
1389 // In the buffer case, we need to check exact buffer equality.
1390 if (linalgOp.hasPureBufferSemantics()) {
1391 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1392 linalgOp.getDpsInputOperand(0)->get() !=
1393 linalgOp.getDpsInitOperand(0)->get()) {
1394 return rewriter.notifyMatchFailure(
1395 linalgOp, "expected single input and output to be the same value");
1396 }
1397
// The yielded value must be an argument of this body block; anything else is
// rejected with the "fill-like" diagnostic below.
1398 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1399 if (!yieldArg || yieldArg.getOwner() != &body) {
1400 return rewriter.notifyMatchFailure(linalgOp,
1401 "cannot fold fill-like op");
1402 }
1403
1404 rewriter.eraseOp(linalgOp);
1405 return success();
1406 }
1407
1408 if (!linalgOp.hasPureTensorSemantics()) {
1409 return rewriter.notifyMatchFailure(
1410 linalgOp, "mixed semantics is not supported yet");
1411 }
1412
1413 // Get the argument number of the returned values. That is the operand
1414 // number to use for replacing uses of this operation.
1415 SmallVector<Value> returnedArgs;
1416 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1417 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1418 if (!yieldArg || yieldArg.getOwner() != &body)
1419 return failure();
// Block argument i is looked up as operand i of the op.
1420 unsigned argumentNumber = yieldArg.getArgNumber();
1421 Value returnedArg = linalgOp->getOperand(argumentNumber);
1422 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1423 // The input can have a different type than the result, e.g. a dynamic
1424 // input dimension can be turned into a static output dimension.
1425 Type returnType = returnedArg.getType();
1426 if (returnType != resultType) {
1427 // Distinguish between sparse conversion or dense tensor casting.
1428 // TODO: unify the two ops?
// NOTE(review): the `if` condition selecting the sparse conversion below is
// not visible in this listing — confirm against the original source.
1431 returnedArg = sparse_tensor::ConvertOp::create(
1432 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1433 else {
1434 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1435 resultType))
1436 return failure();
1437 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1438 resultType, returnedArg);
1439 }
1440 }
1441 returnedArgs.push_back(returnedArg);
1442 }
1443
// One replacement value is required per result of the op.
1444 if (returnedArgs.size() != linalgOp->getNumResults())
1445 return failure();
1446 rewriter.replaceOp(linalgOp, returnedArgs);
1447 return success();
1448 }
1449};
1450
1451} // namespace
1452
1453void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1454 MLIRContext *context) {
1455 results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1456}
1457
1458LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1459 return memref::foldMemRefCast(*this);
1460}
1461
1462//===----------------------------------------------------------------------===//
1463// MapOp
1464//===----------------------------------------------------------------------===//
1465
// Parses the part of the assembly shared by destination-style ops: the
// `ins`/`outs` lists (without recording operand segment sizes), followed by
// the attributes required by the op (through the optional `parseAttrsFn`
// hook) and an optional attribute dictionary. Every output of ranked tensor
// type is also added as a result type.
//
// NOTE(review): one line of the parameter list is not visible in this
// listing; the body uses `parser` and `result`, and `MapOp::parse` calls this
// function as `parseDstStyleOp(parser, result)`.
1466static ParseResult parseDstStyleOp(
1468 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1469 nullptr) {
1470 // Parse `ins` and `outs`.
1471 SmallVector<Type, 4> inputTypes, outputTypes;
1472 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1473 /*addOperandSegmentSizes=*/false))
1474 return failure();
1475
1476 // Add result types.
// Only ranked tensor outputs produce a result; other output types do not.
1477 for (Type outputType : outputTypes) {
1478 if (llvm::isa<RankedTensorType>(outputType))
1479 result.addTypes(outputType);
1480 }
1481
1482 // Parse required attributes.
1483 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1484 return failure();
1485
1486 // Parse optional attributes.
1487 if (parser.parseOptionalAttrDict(result.attributes))
1488 return failure();
1489 return success();
1490}
1491
1492void MapOp::getAsmBlockArgumentNames(Region &region,
1493 OpAsmSetValueNameFn setNameFn) {
1494 for (Value v : getRegionInputArgs())
1495 setNameFn(v, "in");
1496 for (Value v : getRegionOutputArgs())
1497 setNameFn(v, "init");
1498}
1499
1500void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1501 if (!getResults().empty())
1502 setNameFn(getResults().front(), "mapped");
1503}
1504
1505void MapOp::build(
1506 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1507 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1508 ArrayRef<NamedAttribute> attributes) {
1509 build(builder, result, TypeRange{}, inputs, init);
1510 result.addAttributes(attributes);
1511
1512 // Add output types for `RankedTensorType` output arguments.
1513 Type initType = init.getType();
1514 if (llvm::isa<RankedTensorType>(initType))
1515 result.addTypes(initType);
1516
1517 if (bodyBuild)
1518 buildGenericRegion(builder, result.location, *result.regions.front(),
1519 inputs, /*outputs=*/{init}, bodyBuild);
1520}
1521
// NOTE(review): the first line of this definition (return type, name and the
// leading parameters) is not visible in this listing. The body uses `parser`
// and `result`, and `MapOp::parse` calls it as
// `addBodyWithPayloadOp(parser, result, payloadOpName, payloadOpAttrs,
// operands, initFirst, mapInit)`.
//
// Adds a region with a single block to `result`. The block gets one argument
// per entry of `operands`, typed with that operand's element type. The block
// contains one op named `payloadOpName` carrying `payloadOpAttrs`, followed
// by a yield of the payload op's results.
1523 const OperationName &payloadOpName,
1524 const NamedAttrList &payloadOpAttrs,
1525 ArrayRef<Value> operands,
1526 bool initFirst = false, bool mapInit = true) {
1527 OpBuilder b(parser.getContext());
1528 Region *body = result.addRegion();
1529 Block &block = body->emplaceBlock();
1530 b.setInsertionPointToStart(&block);
// Every operand is cast to a ShapedType here.
1531 for (auto &operand : operands) {
1532 block.addArgument(
1533 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1534 b.getUnknownLoc());
1535 }
1536 SmallVector<Value> payloadOpOperands;
1537 // If initFirst flag is enabled, we consider init as the first position of
1538 // payload operands.
1539 if (initFirst) {
// The last block argument goes first (only when `mapInit` is set), followed
// by all remaining block arguments in order.
1540 if (mapInit)
1541 payloadOpOperands.push_back(block.getArguments().back());
1542 for (const auto &arg : block.getArguments().drop_back())
1543 payloadOpOperands.push_back(arg);
1544 } else {
// Block arguments in order; the last one is left out when `mapInit` is false.
1545 payloadOpOperands = {block.getArguments().begin(),
1546 block.getArguments().end() - int(!mapInit)};
1547 }
1548
// The payload op has a single result whose type is the element type of the
// last operand recorded in `result`.
1549 Operation *payloadOp = b.create(
1550 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1551 payloadOpOperands,
1552 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1553 .getElementType()},
1554 payloadOpAttrs);
1555 YieldOp::create(b, result.location, payloadOp->getResults());
1556}
1557
1558ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1559 std::optional<OperationName> payloadOpName;
1560 NamedAttrList payloadOpAttrs;
1561 if (succeeded(parser.parseOptionalLBrace())) {
1562 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1563 if (failed(operationName))
1564 return failure();
1565 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1566 return failure();
1567 payloadOpName = operationName.value();
1568 if (parser.parseRBrace())
1569 return failure();
1570 }
1571
1572 if (parseDstStyleOp(parser, result))
1573 return failure();
1574
1575 if (payloadOpName.has_value()) {
1576 if (!result.operands.empty())
1577 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1578 payloadOpAttrs, ArrayRef(result.operands), false,
1579 false);
1580 else
1581 result.addRegion();
1582 } else {
1583 SmallVector<OpAsmParser::Argument> regionArgs;
1584 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1585 /*allowType=*/true, /*allowAttrs=*/true)) {
1586 return failure();
1587 }
1588 Region *body = result.addRegion();
1589 if (parser.parseRegion(*body, regionArgs))
1590 return failure();
1591 }
1592 return success();
1593}
1594
1595static bool canUseShortForm(Block *body, bool initFirst = false,
1596 bool mapInit = true) {
1597 // `intFirst == true` implies that we want to map init arg
1598 if (initFirst && !mapInit)
1599 return false;
1600 // Check if the body can be printed in short form. The following 4 conditions
1601 // must be satisfied:
1602
1603 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1604 if (body->getOperations().size() != 2)
1605 return false;
1606 Operation &payload = body->getOperations().front();
1607
1608 // 2) The payload op must have the same number of operands as the number of
1609 // block arguments.
1610 if (payload.getNumOperands() == 0 ||
1611 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1612 return false;
1613
1614 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1615 // must be the first operand of the payload op, otherwise, the operands
1616 // must match the block arguments in order.
1617 if (initFirst) {
1618 // check init
1619 if (payload.getOperands().back() != body->getArgument(0))
1620 return false;
1621 // check rest
1622 for (const auto &[operand, bbArg] :
1623 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1624 if (bbArg != operand)
1625 return false;
1626 }
1627 } else {
1628 for (const auto &[operand, bbArg] :
1629 llvm::zip(payload.getOperands(),
1630 body->getArguments().drop_back(int(!mapInit)))) {
1631 if (bbArg != operand)
1632 return false;
1633 }
1634 }
1635
1636 // 4) The `yield` operand must be the result of the payload op.
1637 auto yieldOp = cast<YieldOp>(body->getTerminator());
1638 return yieldOp.getNumOperands() == 1 &&
1639 yieldOp.getOperand(0).getDefiningOp() &&
1640 yieldOp.getOperand(0).getDefiningOp() == &payload;
1641}
1642
1643static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1644 SmallVector<StringRef> elidedAttrs;
1645 std::string attrToElide;
1646 p << " { " << payloadOp->getName().getStringRef();
1647 for (const auto &attr : payloadOp->getAttrs()) {
1648 auto fastAttr =
1649 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1650 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1651 attrToElide = attr.getName().str();
1652 elidedAttrs.push_back(attrToElide);
1653 break;
1654 }
1655 }
1656 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1657 p << " }";
1658}
1659
1660void MapOp::print(OpAsmPrinter &p) {
1661 Block *mapper = getBody();
1662 bool useShortForm =
1663 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
1664 if (useShortForm) {
1665 printShortForm(p, &mapper->getOperations().front());
1666 }
1667
1668 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1669 p.printOptionalAttrDict((*this)->getAttrs());
1670
1671 if (!useShortForm) {
1672 // Print region if the payload op was not detected.
1673 p.increaseIndent();
1674 p.printNewline();
1675 p << "(";
1676 llvm::interleaveComma(mapper->getArguments(), p,
1677 [&](auto arg) { p.printRegionArgument(arg); });
1678 p << ") ";
1679
1680 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1681 p.decreaseIndent();
1682 }
1683}
1684
1685LogicalResult MapOp::verify() {
1686 auto *bodyBlock = getBody();
1687 auto blockArgs = bodyBlock->getArguments();
1688
1689 // Checks if the number of `inputs` + `init` match the arity of the `mapper`
1690 // region.
1691 if (getInputs().size() + 1 != blockArgs.size())
1692 return emitOpError() << "expects number of operands to match the arity of "
1693 "mapper, but got: "
1694 << getInputs().size() + 1 << " and "
1695 << blockArgs.size();
1696
1697 // The parameters of mapper should all match the element type of inputs.
1698 for (const auto &[bbArgType, inputArg] :
1699 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1700 auto inputElemType =
1701 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1702 if (bbArgType != inputElemType) {
1703 return emitOpError() << "expected element type of input " << inputElemType
1704 << " to match bbArg type " << bbArgType;
1705 }
1706 }
1707
1708 // The shape of each input must match the shape of the output.
1709 auto outputShape = getInit().getType().getShape();
1710 for (Type inputArgType : TypeRange{getInputs()}) {
1711 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1712 if (inputElemShape != outputShape) {
1713 return emitOpError() << "expected shape of input (" << inputElemShape
1714 << ") to match shape of output (" << outputShape
1715 << ")";
1716 }
1717 }
1718
1719 return success();
1720}
1721
1722SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1723 int64_t rank = getInit().getType().getRank();
1724 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1725}
1726
1727ArrayAttr MapOp::getIndexingMaps() {
1728 Builder builder(getContext());
1729 int64_t rank = getInit().getType().getRank();
1730 int64_t numIndexingMaps = getOperands().size();
1731 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1732 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1733}
1734
1735void MapOp::getEffects(
1736 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1737 &effects) {
1738 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1739}
1740
1741Speculation::Speculatability MapOp::getSpeculatability() {
1742 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1743}
1744
1745//===----------------------------------------------------------------------===//
1746// ReduceOp
1747//===----------------------------------------------------------------------===//
1748
1749void ReduceOp::getAsmBlockArgumentNames(Region &region,
1750 OpAsmSetValueNameFn setNameFn) {
1751 for (Value v : getRegionInputArgs())
1752 setNameFn(v, "in");
1753 for (Value v : getRegionOutputArgs())
1754 setNameFn(v, "init");
1755}
1756
1757void ReduceOp::getAsmResultNames(
1758 function_ref<void(Value, StringRef)> setNameFn) {
1759 if (!getResults().empty())
1760 setNameFn(getResults().front(), "reduced");
1761}
1762
1763void ReduceOp::build(
1764 OpBuilder &builder, OperationState &result, ValueRange inputs,
1765 ValueRange inits, ArrayRef<int64_t> dimensions,
1766 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1767 ArrayRef<NamedAttribute> attributes) {
1768 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1769 result.addAttributes(attributes);
1770
1771 // Add output types for `RankedTensorType` output arguments.
1772 for (Value init : inits) {
1773 Type initType = init.getType();
1774 if (llvm::isa<RankedTensorType>(initType))
1775 result.addTypes(initType);
1776 }
1777
1778 if (bodyBuild)
1779 buildGenericRegion(builder, result.location, *result.regions.front(),
1780 inputs, inits, bodyBuild);
1781}
1782
1783SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1784 int64_t inputRank =
1785 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1786 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1787 utils::IteratorType::parallel);
1788 for (int64_t reductionDim : getDimensions())
1789 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1790 return iteratorTypes;
1791}
1792
1793ArrayAttr ReduceOp::getIndexingMaps() {
1794 int64_t inputRank =
1795 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1796 SmallVector<AffineMap> affineMaps(
1797 getNumDpsInputs(),
1799 AffineMap resultMap =
1801 .dropResults(getDimensions());
1802 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1803 affineMaps.push_back(resultMap);
1804 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1805}
1806
1807void ReduceOp::getEffects(
1808 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1809 &effects) {
1810 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1811}
1812
1813Speculation::Speculatability ReduceOp::getSpeculatability() {
1814 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1815}
1816
1817static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1818 NamedAttrList &attributes,
1819 StringRef attributeName) {
1820 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1821 return failure();
1822
1823 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1824 return success();
1825}
1826
1827ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1828 std::optional<OperationName> payloadOpName;
1829 NamedAttrList payloadOpAttrs;
1830 if (succeeded(parser.parseOptionalLBrace())) {
1831 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1832 if (failed(operationName))
1833 return failure();
1834 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1835 return failure();
1836 payloadOpName = operationName.value();
1837 if (parser.parseRBrace())
1838 return failure();
1839 }
1840
1841 if (parseDstStyleOp(
1842 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1843 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1844 }))
1845 return failure();
1846
1847 if (payloadOpName.has_value()) {
1848 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1849 ArrayRef(result.operands), /*initFirst=*/true);
1850 } else {
1851 SmallVector<OpAsmParser::Argument> regionArgs;
1852 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1853 /*allowType=*/true, /*allowAttrs=*/true)) {
1854 return failure();
1855 }
1856
1857 Region *body = result.addRegion();
1858 if (parser.parseRegion(*body, regionArgs))
1859 return failure();
1860 }
1861
1862 return success();
1863}
1864
1865static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1866 ArrayRef<int64_t> attributeValue) {
1867 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1868}
1869
1870void ReduceOp::print(OpAsmPrinter &p) {
1871 Block *mapper = getBody();
1872 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1873 if (useShortForm) {
1874 printShortForm(p, &mapper->getOperations().front());
1875 }
1876
1877 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1878 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1879 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1880 if (!useShortForm) {
1881 // Print region if the payload op was not detected.
1882 p.increaseIndent();
1883 p.printNewline();
1884 p << "(";
1885 llvm::interleaveComma(mapper->getArguments(), p,
1886 [&](auto arg) { p.printRegionArgument(arg); });
1887 p << ") ";
1888
1889 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1890 p.decreaseIndent();
1891 }
1892}
1893
1894LogicalResult ReduceOp::verify() {
1895 ArrayRef<int64_t> dimensionsRef = getDimensions();
1896
1897 // The ReduceOp uses `SameVariadicOperandSize`, which requires equal numbers
1898 // of inputs and inits. Detect a mismatch early: when they differ, the
1899 // ODS-generated getInputs()/getInits() accessors compute each group's size
1900 // via floordiv of the total operand count, producing incorrect slices that
1901 // would cause out-of-bounds accesses below.
1902 if (getInputs().size() != static_cast<size_t>(getNumDpsInputs()))
1903 return emitOpError()
1904 << "expected equal number of inputs and outputs (required by "
1905 "SameVariadicOperandSize), got "
1906 << getNumDpsInputs() << " input(s) and " << getNumDpsInits()
1907 << " output(s)";
1908
1909 if (getInputs().empty())
1910 return emitOpError() << "expected at least one input";
1911
1912 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1913 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1914 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1915 return emitOpError() << "expects all inputs to have the same shapes. "
1916 "Shape at input-index "
1917 << i
1918 << " is not equal to the shape at input-index 0.";
1919 }
1920 }
1921 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1922 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1923 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1924 return emitOpError() << "expects all outputs to have the same shapes. "
1925 "Shape at output-index "
1926 << i
1927 << " is not equal to the shape at output-index 0.";
1928 }
1929 }
1930 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1931 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1932
1933 DenseSet<int64_t> dimensionsToReduce;
1934 for (int64_t dimension : dimensionsRef) {
1935 if (dimension < 0 || dimension >= inputType.getRank()) {
1936 return emitOpError()
1937 << "dimensions for reduction should be in the range [0, "
1938 << inputType.getRank() - 1 << "].";
1939 }
1940 dimensionsToReduce.insert(dimension);
1941 }
1942
1943 auto inputDims = inputType.getShape();
1944 auto initDims = initType.getShape();
1945
1946 // Input dimensions that will be left after the reduction.
1947 SmallVector<int64_t> reducedInputDims;
1948 for (const auto &en : llvm::enumerate(inputDims)) {
1949 if (!dimensionsToReduce.count(en.index()))
1950 reducedInputDims.push_back(en.value());
1951 }
1952
1953 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1954 return emitOpError() << "number of dimensions after reduction "
1955 << reducedInputDims.size()
1956 << " doesn't match the init rank "
1957 << initType.getRank();
1958 }
1959
1960 if (reducedInputDims != initDims)
1961 return emitOpError() << "init dimensions [" << initDims
1962 << "] doesn't match input dimensions after reduction ["
1963 << reducedInputDims << "]";
1964
1965 Block *block = getBody();
1966 if (block->getNumArguments() != this->getNumOperands())
1967 return emitOpError()
1968 << "mismatching number of operands and block arguments";
1969
1970 // Check that the first block arguments match the element type of the inputs.
1971 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1972 Type inputElementType =
1973 llvm::cast<ShapedType>(input.getType()).getElementType();
1974 if (inputElementType != bbArg.getType())
1975 return emitOpError()
1976 << "input element type " << inputElementType
1977 << " does not match corresponding block argument type "
1978 << bbArg.getType();
1979 }
1980
1981 // Check that the last block arguments match the element type of the outputs.
1982 for (auto [output, bbArg] : llvm::zip(
1983 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1984 auto outputElementType =
1985 llvm::cast<ShapedType>(output.getType()).getElementType();
1986 if (outputElementType != bbArg.getType())
1987 return emitOpError()
1988 << "output element type " << outputElementType
1989 << " does not match corresponding block argument type "
1990 << bbArg.getType();
1991 }
1992 return success();
1993}
1994
1995//===----------------------------------------------------------------------===//
1996// TransposeOp
1997//===----------------------------------------------------------------------===//
1998
1999static void buildIdentityRegion(OpBuilder &builder, Location loc,
2000 Region &region, ValueRange inputs,
2001 ValueRange outputs) {
2002 buildGenericRegion(builder, loc, region, inputs, outputs,
2003 [](OpBuilder &b, Location loc, ValueRange args) {
2004 if (!args.empty())
2005 linalg::YieldOp::create(b, loc, args[0]);
2006 });
2007}
2008
2009void TransposeOp::build(::mlir::OpBuilder &builder,
2010 ::mlir::OperationState &result, Value input, Value init,
2011 DenseI64ArrayAttr permutation,
2012 ArrayRef<NamedAttribute> attributes) {
2013 result.addOperands(input);
2014 result.addOperands(init);
2015 result.addAttribute(getPermutationAttrName(result.name), permutation);
2016 result.addAttributes(attributes);
2017
2018 // Add output types for `RankedTensorType` output arguments.
2019 Type initType = init.getType();
2020 if (llvm::isa<RankedTensorType>(initType))
2021 result.addTypes(initType);
2022
2023 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2024 init);
2025}
2026
2027void TransposeOp::build(::mlir::OpBuilder &builder,
2028 ::mlir::OperationState &result, Value input, Value init,
2029 ArrayRef<int64_t> permutation,
2030 ArrayRef<NamedAttribute> attributes) {
2031 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2032 attributes);
2033}
2034
2035ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2037 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2038 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2039 })))
2040 return failure();
2041
2042 OpBuilder builder(parser.getContext());
2043 buildIdentityRegion(builder, result.location, *result.addRegion(),
2044 /*inputs=*/result.operands,
2045 /*outputs=*/{});
2046 return success();
2047}
2048
2049void TransposeOp::getAsmResultNames(
2050 function_ref<void(Value, StringRef)> setNameFn) {
2051 if (!getResults().empty())
2052 setNameFn(getResults().front(), "transposed");
2053}
2054
2055void TransposeOp::print(OpAsmPrinter &p) {
2056 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2057 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2058 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2059}
2060
2061LogicalResult TransposeOp::verify() {
2062 ArrayRef<int64_t> permutationRef = getPermutation();
2063
2064 if (!isPermutationVector(permutationRef))
2065 return emitOpError("permutation is not valid");
2066
2067 auto inputType = getInput().getType();
2068 auto initType = getInit().getType();
2069
2070 int64_t rank = inputType.getRank();
2071
2072 if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
2073 "init")))
2074 return failure();
2075
2076 if (rank != static_cast<int64_t>(permutationRef.size()))
2077 return emitOpError() << "size of permutation " << permutationRef.size()
2078 << " does not match the argument rank " << rank;
2079
2080 auto inputDims = inputType.getShape();
2081 auto initDims = initType.getShape();
2082
2083 for (int64_t i = 0; i < rank; ++i) {
2084 int64_t inputDim = inputDims[permutationRef[i]];
2085 int64_t initDim = initDims[i];
2086
2087 if (inputDim != initDim) {
2088 return emitOpError() << "dim(result, " << i << ") = " << initDim
2089 << " doesn't match dim(input, permutation[" << i
2090 << "]) = " << inputDim;
2091 }
2092 }
2093
2094 return success();
2095}
2096
2097SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2098 int64_t rank = getInit().getType().getRank();
2099 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2100}
2101
2102ArrayAttr TransposeOp::getIndexingMaps() {
2103 Builder builder(getContext());
2104 int64_t rank = getInit().getType().getRank();
2105 return builder.getAffineMapArrayAttr(
2107 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2108 builder.getMultiDimIdentityMap(rank)});
2109}
2110
2111void TransposeOp::getEffects(
2112 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2113 &effects) {
2114 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2115}
2116
2117Speculation::Speculatability TransposeOp::getSpeculatability() {
2118 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2119}
2120
2121LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2122 SmallVectorImpl<OpFoldResult> &result) {
2123 // Only the tensor type is supported.
2124 if (!isa<TensorType>(getInput().getType()))
2125 return failure();
2126
2127 // Single dimension transpose.
2128 if (getPermutation().empty()) {
2129 result.push_back(getInput());
2130 return success();
2131 }
2132 // Identity permutation.
2133 if (isIdentityPermutation(getPermutation())) {
2134 result.push_back(getInput());
2135 return success();
2136 }
2137
2138 return failure();
2139}
2140
2141/// Fold transpose with transpose.
2142struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2143 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2144
2145 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2146 PatternRewriter &rewriter) const override {
2147 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2148 if (!defTransposeOp)
2149 return failure();
2150 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2151 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2152 SmallVector<int64_t> foldedPerms;
2153 foldedPerms.reserve(perms.size());
2154 for (int64_t perm : perms)
2155 foldedPerms.push_back(defPerms[perm]);
2156
2157 rewriter.replaceOpWithNewOp<TransposeOp>(
2158 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2159 foldedPerms);
2160 return success();
2161 }
2162};
2163
2164/// Rewrite a transpose of a dense splat constant into a dense splat constant of
2165/// the transposed output shape.
2166struct FoldTransposeSplatConstant : OpRewritePattern<linalg::TransposeOp> {
2167 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2168
2169 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2170 PatternRewriter &rewriter) const override {
2171 if (!transposeOp.hasPureTensorSemantics())
2172 return failure();
2173
2174 auto splatValue =
2175 getScalarConstantAttrFromDenseSplat(transposeOp.getInput());
2176 if (!splatValue.has_value())
2177 return failure();
2178
2179 auto resultType =
2180 cast<RankedTensorType>(transposeOp.getResult()[0].getType());
2181
2182 auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
2183 rewriter.replaceOpWithNewOp<arith::ConstantOp>(transposeOp, resultType,
2184 resultAttr);
2185 return success();
2186 }
2187};
2188
2189/// This pattern canonicalize transpose by swapping the order of
2190/// broadcast and transpose:
2191/// transpose(broadcast(input)) -> broadcast(transpose(input))
2192struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2193 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2194
2195 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2196 PatternRewriter &rewriter) const override {
2197 Value input = transposeOp.getInput();
2198 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2199 if (!input.hasOneUse() || !broadcastOp)
2200 return failure();
2201
2202 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2203 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2204
2205 // Get new perms and new dimensions.
2206 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2208 SmallVector<int64_t> resultDimensions;
2209 unsigned dimensionSize = dimensions.size();
2210 for (unsigned i = 0; i < dimensionSize; ++i)
2211 resultDimensions.push_back(invertPerm[dimensions[i]]);
2212
2213 // Create transpose result.
2214 Value broadcastInput = broadcastOp.getInput();
2215 Location loc = transposeOp.getLoc();
2216 MLIRContext *ctx = transposeOp.getContext();
2218 auto broadcastInputTy =
2219 mlir::cast<RankedTensorType>(broadcastInput.getType());
2220 unsigned inputRank = broadcastInputTy.getRank();
2221 for (unsigned i = 0; i < inputRank; ++i) {
2222 if (broadcastInputTy.isDynamicDim(i)) {
2223 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2224 ->getResult(0));
2225 } else {
2226 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2227 broadcastInputTy.getDimSize(i)));
2228 }
2229 }
2230 SmallVector<OpFoldResult> transposeResultShapes =
2231 applyPermutation(dims, resultPerms);
2232 Value transposeInit = tensor::EmptyOp::create(
2233 rewriter, transposeOp.getLoc(), transposeResultShapes,
2234 broadcastInputTy.getElementType());
2235
2236 // Create broadcast(transpose(input)).
2237 Value transposeResult =
2238 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2239 transposeInit, resultPerms)
2240 ->getResult(0);
2241 rewriter.replaceOpWithNewOp<BroadcastOp>(
2242 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2243 return success();
2244 }
2245};
2246
2247void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2248 MLIRContext *context) {
2249 results.add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
2250 SwapTransposeWithBroadcast>(context);
2251}
2252
2253//===----------------------------------------------------------------------===//
2254// BroadcastOp
2255//===----------------------------------------------------------------------===//
2256
2257void BroadcastOp::build(::mlir::OpBuilder &builder,
2258 ::mlir::OperationState &result, Value input, Value init,
2259 DenseI64ArrayAttr dimensions,
2260 ArrayRef<NamedAttribute> attributes) {
2261 result.addOperands(input);
2262 result.addOperands(init);
2263 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2264 result.addAttributes(attributes);
2265
2266 // Add output types for `RankedTensorType` output arguments.
2267 Type initType = init.getType();
2268 if (llvm::isa<RankedTensorType>(initType))
2269 result.addTypes(initType);
2270
2271 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2272 init);
2273}
2274
2275void BroadcastOp::build(::mlir::OpBuilder &builder,
2276 ::mlir::OperationState &result, Value input, Value init,
2277 ArrayRef<int64_t> dimensions,
2278 ArrayRef<NamedAttribute> attributes) {
2279 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2280 attributes);
2281}
2282
2283ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2285 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2286 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2287 })))
2288 return failure();
2289
2290 OpBuilder builder(parser.getContext());
2291 buildIdentityRegion(builder, result.location, *result.addRegion(),
2292 /*inputs=*/result.operands,
2293 /*outputs=*/{});
2294 return success();
2295}
2296
2297void BroadcastOp::getAsmResultNames(
2298 function_ref<void(Value, StringRef)> setNameFn) {
2299 if (!getResults().empty())
2300 setNameFn(getResults().front(), "broadcasted");
2301}
2302
2303void BroadcastOp::print(OpAsmPrinter &p) {
2304 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2305 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2306 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2307}
2308
2309LogicalResult BroadcastOp::verify() {
2310 ArrayRef<int64_t> dimensionsRef = getDimensions();
2311
2312 auto inputType = getInput().getType();
2313 auto initType = getInit().getType();
2314
2315 int64_t inputRank = inputType.getRank();
2316 int64_t initRank = initType.getRank();
2317
2318 auto inputShape = inputType.getShape();
2319 auto initShape = initType.getShape();
2320
2321 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2322 return emitOpError() << "input rank plus added dimensions does not "
2323 "match init rank. input rank: "
2324 << inputRank
2325 << ", dimensions size: " << dimensionsRef.size()
2326 << ", init rank: " << initRank;
2327
2328 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2329 if (dim < 0 || dim >= initRank)
2330 return emitOpError() << "dimension " << idx
2331 << " is out of range. expected range: [0, "
2332 << initRank - 1 << "], got: " << dim;
2333 }
2334
2335 // Mapping from input dims to init dims.
2336 SmallVector<int64_t> dimMap;
2337 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2338 if (!llvm::is_contained(dimensionsRef, dim))
2339 dimMap.push_back(dim);
2340 }
2341
2342 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2343 // This dimensions is mapped from the input. Init and input dims should
2344 // match.
2345 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2346 return emitOpError() << "input dim " << inputDimIdx
2347 << " should match init dim " << initDimIdx
2348 << ". input: " << inputShape[inputDimIdx]
2349 << ", init: " << initShape[initDimIdx];
2350 }
2351
2352 return success();
2353}
2354
2355SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2356 int64_t rank = getInit().getType().getRank();
2357 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2358}
2359
2360ArrayAttr BroadcastOp::getIndexingMaps() {
2361 Builder builder(getContext());
2362 int64_t rank = getInit().getType().getRank();
2363 return builder.getAffineMapArrayAttr(
2364 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2365 builder.getMultiDimIdentityMap(rank)});
2366}
2367
2368void BroadcastOp::getEffects(
2369 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2370 &effects) {
2371 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2372}
2373
2374Speculation::Speculatability BroadcastOp::getSpeculatability() {
2375 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2376}
2377
2378/// Fold back-to-back broadcasts together.
2379struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2380 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2381
2382 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2383 PatternRewriter &rewriter) const override {
2384 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2385 if (!defBroadcastOp)
2386 return failure();
2387 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2388 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2389 SmallVector<int64_t> foldedDims(dimensions);
2390 Value init = broadcastOp.getInit();
2391 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2392 // Mapping from input dims to init dims.
2393 SmallVector<int64_t> dimMap;
2394 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2395 if (!llvm::is_contained(dimensions, dim))
2396 dimMap.push_back(dim);
2397 }
2398 for (auto dim : defDimensions)
2399 foldedDims.push_back(dimMap[dim]);
2400
2401 llvm::sort(foldedDims);
2402 rewriter.replaceOpWithNewOp<BroadcastOp>(
2403 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2404 return success();
2405 }
2406};
2407
2408/// Rewrite a broadcast of a dense splat constant into a dense splat constant of
2409/// the broadcast output shape.
2410struct FoldBroadcastSplatConstant : OpRewritePattern<linalg::BroadcastOp> {
2411 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2412
2413 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2414 PatternRewriter &rewriter) const override {
2415 if (!broadcastOp.hasPureTensorSemantics())
2416 return failure();
2417
2418 auto splatValue =
2419 getScalarConstantAttrFromDenseSplat(broadcastOp.getInput());
2420
2421 if (!splatValue.has_value())
2422 return failure();
2423
2424 auto resultType =
2425 cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
2426 if (!resultType.hasStaticShape())
2427 return rewriter.notifyMatchFailure(broadcastOp,
2428 "result type has dynamic shape");
2429
2430 auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
2431 rewriter.replaceOpWithNewOp<arith::ConstantOp>(broadcastOp, resultType,
2432 resultAttr);
2433 return success();
2434 }
2435};
2436
2437void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2438 MLIRContext *context) {
2439 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
2440 FoldBroadcastSplatConstant>(context);
2441}
2442
2443//===----------------------------------------------------------------------===//
2444// YieldOp
2445//===----------------------------------------------------------------------===//
2446
2447void linalg::YieldOp::print(OpAsmPrinter &p) {
2448 if (getNumOperands() > 0)
2449 p << ' ' << getOperands();
2450 p.printOptionalAttrDict((*this)->getAttrs());
2451 if (getNumOperands() > 0)
2452 p << " : " << getOperandTypes();
2453}
2454
2455ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2456 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2457 SmallVector<Type, 2> types;
2458 SMLoc loc = parser.getCurrentLocation();
2459 return failure(parser.parseOperandList(opInfo) ||
2460 parser.parseOptionalAttrDict(result.attributes) ||
2461 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2462 parser.resolveOperands(opInfo, types, loc, result.operands));
2463}
2464
2465// Check the operand number and types must match the element types of the
2466// LinalgOp interface's shaped operands.
2467static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2468 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2469 return op.emitOpError("expected number of yield values (")
2470 << op.getNumOperands()
2471 << ") to match the number of inits / outs operands of the enclosing "
2472 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2473
2474 for (OpOperand &opOperand : op->getOpOperands()) {
2475 OpOperand *outputOperand =
2476 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2477 Type elementType = outputOperand->get().getType();
2478 if (isa<MemRefType, RankedTensorType>(elementType))
2479 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2480 if (opOperand.get().getType() != elementType)
2481 return op.emitOpError("type of yield operand ")
2482 << (opOperand.getOperandNumber() + 1) << " ("
2483 << opOperand.get().getType() << ") doesn't match "
2484 << "the element type of the enclosing linalg.generic op ("
2485 << elementType << ")";
2486 }
2487 return success();
2488}
2489
2490LogicalResult linalg::YieldOp::verify() {
2491 auto *parentOp = (*this)->getParentOp();
2492 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2493 return emitOpError("expected single non-empty parent region");
2494
2495 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2496 return verifyYield(*this, linalgOp);
2497
2498 return emitOpError("expected parent op with LinalgOp interface");
2499}
2500
2501//===----------------------------------------------------------------------===//
2502// IndexOp
2503//===----------------------------------------------------------------------===//
2504
2505LogicalResult IndexOp::verify() {
2506 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2507 if (!linalgOp)
2508 return emitOpError("expected parent op with LinalgOp interface");
2509 if (linalgOp.getNumLoops() <= getDim())
2510 return emitOpError("expected dim (")
2511 << getDim() << ") to be lower than the number of loops ("
2512 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2513 return success();
2514}
2515
2516OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2517 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2518 // Bail out if `linalg.index` does not have a proper parent yet at this
2519 // point, e.g., when calling `createOrFold` during IR construction in
2520 // `genericOp::build`.
2521 if (!linalgOp)
2522 return OpFoldResult{};
2523
2524 // Index of unit dims is always 0.
2525 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2526 uint64_t dim = getDim();
2527 assert(dim < loopBounds.size() && "Dim is out of bounds");
2528 if (loopBounds[dim] == 1)
2529 return IntegerAttr::get(IndexType::get(getContext()), 0);
2530
2531 return OpFoldResult{};
2532}
2533
2534/////// Operations corresponding to library calls defined with Tablegen ////////
2535
2536#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2537
2538#define GET_OP_CLASSES
2539#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2540
2541#define GET_OP_CLASSES
2542#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2543#define GET_OP_CLASSES
2544#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2545
2546AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2547 unsigned rank,
2548 MLIRContext *context) {
2549 if (maybeMap)
2550 return *maybeMap;
2551 if (rank == 0)
2552 return AffineMap::get(context);
2553 return AffineMap::getMultiDimIdentityMap(rank, context);
2554}
2555
2557mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2558 MLIRContext *context) {
2560 res.reserve(num);
2561 for (unsigned i = 0; i < num; ++i)
2562 res.push_back(getAffineDimExpr(startIdx++, context));
2563 return res;
2564}
2565
2568 auto rangeA = llvm::make_range(a.begin(), a.end());
2569 auto rangeB = llvm::make_range(b.begin(), b.end());
2570 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2571 return llvm::to_vector<4>(concatRanges);
2572}
2573
2574static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2575 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2576 ss << "view";
2577 for (auto size : memref.getShape())
2578 if (size < 0)
2579 ss << "sx";
2580 else
2581 ss << size << "x";
2582 if (failed(appendMangledType(ss, memref.getElementType())))
2583 return failure();
2584 if (auto as = memref.getMemorySpace()) {
2585 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2586 ss << "as" << attr.getInt();
2587 else
2588 return failure();
2589 }
2590 return success();
2591 }
2592 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2593 ss << "vector";
2594 llvm::interleave(
2595 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2596 if (failed(appendMangledType(ss, vec.getElementType())))
2597 return failure();
2598 return success();
2599 }
2601 ss << t;
2602 return success();
2603 }
2604 return failure();
2605}
2606
2608 assert(isa<LinalgOp>(op));
2609 std::string name(op->getName().getStringRef().str());
2610 std::string fun = "";
2611 for (NamedAttribute kv : op->getAttrs()) {
2612 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2613 fun = stringifyEnum(ufa.getValue()).str() + "_";
2614 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2615 fun = stringifyEnum(bfa.getValue()).str() + "_";
2616 }
2617 }
2618 name.reserve(128);
2619 llvm::replace(name, '.', '_');
2620 llvm::raw_string_ostream ss(name);
2621 ss << "_" << fun;
2622 for (Type t : op->getOperandTypes()) {
2623 if (failed(appendMangledType(ss, t)))
2624 return std::string();
2625 ss << "_";
2626 }
2627 name.pop_back();
2628 return name;
2629}
2630
2631//===----------------------------------------------------------------------===//
2632// Canonicalizers and Folders.
2633//===----------------------------------------------------------------------===//
2634
2635namespace {
// Erases a LinalgOp as soon as one of its memref operands has a dimension of
// static size 0. Tensor operands are deliberately ignored (see below).
2636struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
 // NOTE(review): embedded line 2637 is absent from this listing — presumably
 // the inherited-constructor `using` declaration, as in the sibling pattern
 // InferStaticShapeOfOperands; confirm against the original file.
2638
 // Returns success() after erasing `op` when a zero-sized memref operand is
 // found, failure() otherwise.
2639 LogicalResult matchAndRewrite(LinalgOp op,
2640 PatternRewriter &rewriter) const override {
2641 for (OpOperand &opOperand : op->getOpOperands()) {
2642 // Linalg "inputs" may be either tensor or memref type.
2643 // tensor<0xelt_type> is a convention that may not always mean
2644 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2645 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2646 if (!mt)
2647 continue;
2648 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2649 rewriter.eraseOp(op);
2650 return success();
2651 }
2652 }
2653 return failure();
2654 }
2655};
2656
2657/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2658/// result that is more static than the linalg op.
2659struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2660 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2661
2662 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2663 PatternRewriter &rewriter) const override {
2664 if (!tensor::canFoldIntoProducerOp(castOp))
2665 return failure();
2666
2667 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2668 if (!linalgOp)
2669 return failure();
2670
2671 // Cast can be in conditionally reachable region, if which case folding will
2672 // generate invalid code. Only conservatively fold ops in same block for
2673 // now.
2674 if (castOp->getBlock() != linalgOp->getBlock())
2675 return failure();
2676
2677 OpBuilder::InsertionGuard guard(rewriter);
2678 rewriter.setInsertionPoint(linalgOp);
2679
2680 Location loc = linalgOp.getLoc();
2681 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2682 unsigned resultNumber = resultValue.getResultNumber();
2683 auto resultType =
2684 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2685 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2686 // going from a more dynamic shape to a less dynamic shape. If the producer
2687 // for this cast, i.e. producer of the out operand, is also an operation
2688 // that folds with tensor.cast consumer (like this pattern), the cast will
2689 // continue to propagate as far up the stack as it can go.
2690 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2691 Value newOperand =
2692 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2693 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2694 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2695 linalgOp.getDpsInits().end());
2696 outputOperands[resultNumber] = newOperand;
2697 newOperands.append(outputOperands.begin(), outputOperands.end());
2698
2699 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2700 linalgOp->result_type_end());
2701 resultTypes[resultNumber] = resultType;
2702 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2703
2704 // Create a tensor.cast operation back to the original type.
2705 Value castBack = tensor::CastOp::create(
2706 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2707
2708 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2709 results[resultNumber] = castBack;
2710 rewriter.replaceOp(linalgOp, results);
2711 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2712 return success();
2713 }
2714};
2715
2716/// For each of the operand in `operands` this function maps the static sizes of
2717/// dimensions to their affine dim expressions.
2718static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2719 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2720 for (OpOperand &opOperand : operands) {
2721 if (linalgOp.isScalar(&opOperand))
2722 continue;
2723 Value src = opOperand.get();
2724 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2725 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2726
2727 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2728 // `tensor.cast` operation and source of the cast operation has a static
2729 // shape, then assign it to the `sourceShape`.
2730 auto *parentOp = src.getDefiningOp();
2731 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2732 if (parentOp) {
2733 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2734 Value castSource = castOp.getSource();
2735 auto castSourceType =
2736 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2737 if (castSourceType && castSourceType.hasStaticShape())
2738 sourceShape = castSourceType.getShape();
2739 }
2740 }
2741
2742 // If the source shape's dimension has a static shape, map the affine dim
2743 // expression to the known static size.
2744 for (unsigned i = 0; i < sourceShape.size(); i++) {
2745 if (sourceType.isDynamicDim(i))
2746 continue;
2747 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2748 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2749 }
2750 }
2751}
2752
2753/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2754/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2755/// their result types is stored in `resultTypes`. If `opOperand` requires no
2756/// change then `changeNeeded` is false and same operand is added in the
2757/// `newOperands` list.
2758static void createNewOperandWithStaticSizes(
2759 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2760 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2761 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2762 bool &changeNeeded) {
2763 Value src = opOperand->get();
2764 newOperands.push_back(src);
2765 if (linalgOp.isScalar(opOperand))
2766 return;
2767 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2768 Type resultType = sourceType;
2769 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2770 resultTypes.push_back(resultType);
2771 return;
2772 }
2773 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2774 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2775 SmallVector<int64_t> newShape;
2776 // If operand is updated with new shape, `newOperandNeeded` will be
2777 // true.
2778 bool newOperandNeeded = false;
2779 for (unsigned i = 0; i < sourceShape.size(); i++) {
2780 int64_t dimShape = sourceShape[i];
2781 AffineExpr dimExpr = sourceMap.getResult(i);
2782 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2783 newShape.push_back(dimShape);
2784 continue;
2785 }
2786 // Dimension has a dynamic shape and corresponding affine dim
2787 // expression is present in the map. So assign the size for the
2788 // given affine dim expression to the dimension.
2789 newShape.push_back(affineExprToSize[dimExpr]);
2790 newOperandNeeded = true;
2791 }
2792 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2793 sourceType.getEncoding());
2794 if (newOperandNeeded) {
2795 changeNeeded = true;
2796 // Get the new operand value given its size and element type by
2797 // casting it.
2798 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2799 unsigned index = opOperand->getOperandNumber();
2800 newOperands[index] = newOperand;
2801 }
2802 if (linalgOp.isDpsInit(opOperand))
2803 resultTypes.push_back(resultType);
2804}
2805
2806/// Static shapes for the operands can be inferred if any one of the operands
2807/// have a static shape. This can be done by referring to the affine dim
2808/// expressions for the operand.
2809struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2810 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2811
2812 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2813 PatternRewriter &rewriter) const override {
2814 if (!linalgOp.hasPureTensorSemantics())
2815 return failure();
2816
2817 // Maps must be projected permutations.
2818 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2819 return !map.isProjectedPermutation();
2820 }))
2821 return failure();
2822
2823 // Maps affine dim expressions to the static size of that dimension.
2824 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2825 Location loc = linalgOp.getLoc();
2826
2827 // For each of the affine dim expression, check if the size is known. If
2828 // known add that in the map.
2829 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2830
2831 SmallVector<Value> newOperands;
2832 SmallVector<Type> resultTypes;
2833
2834 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2835 // change in their types.
2836 bool changeNeeded = false;
2837 newOperands.reserve(linalgOp->getNumOperands());
2838 resultTypes.reserve(linalgOp.getNumDpsInits());
2839
2840 // Iterate over all the operands and update the static sizes.
2841 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2842 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2843 affineExprToSize, linalgOp, newOperands,
2844 resultTypes, changeNeeded);
2845 }
2846
2847 // If the generic op has all the required static information, no
2848 // canonicalization needed.
2849 if (!changeNeeded)
2850 return failure();
2851
2852 // Clone op.
2853 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2854 SmallVector<Value> replacements;
2855 replacements.reserve(newOp->getNumResults());
2856 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2857 Value newResult = std::get<1>(it);
2858 Value oldResult = std::get<0>(it);
2859 Type newType = newResult.getType();
2860 Type oldType = oldResult.getType();
2861 replacements.push_back(
2862 (newType != oldType)
2863 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2864 : newResult);
2865 }
2866 rewriter.replaceOp(linalgOp, replacements);
2867 return success();
2868 }
2869};
2870
2871} // namespace
2872
2873// All named ops canonicalizers and folders are auto-generated in the
2874// .cpp.inc.
2875
2876//===----------------------------------------------------------------------===//
2877// SoftmaxOp
2878//===----------------------------------------------------------------------===//
2879
2880LogicalResult SoftmaxOp::verify() {
2881 ShapedType inputType = getInputOperandType();
2882 ShapedType outputType = getOutputOperandType();
2883
2884 ArrayRef<int64_t> inputShape = inputType.getShape();
2885 ArrayRef<int64_t> outputShape = outputType.getShape();
2886 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2887 return emitOpError("incompatible output shape");
2888
2889 int64_t inputRank = getInputOperandRank();
2890 int64_t dimension = getDimension();
2891 if ((dimension < 0) || (dimension >= inputRank))
2892 return emitOpError("incorrect dimension specified");
2893
2894 return success();
2895}
2896
2897SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2898 int64_t operandRank = getInputOperandRank();
2899 SmallVector<Range> loopBounds(operandRank);
2900 Location loc = getLoc();
2901 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2902 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2903 Value source = getInput();
2904 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2905 loopBounds[dim].offset = zero;
2906 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2907 loopBounds[dim].stride = one;
2908 }
2909 return loopBounds;
2910}
2911
2912SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2913 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2914 utils::IteratorType::parallel);
2915 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2916 return iteratorTypes;
2917}
2918
2919FailureOr<TilingResult>
2920SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2921 ArrayRef<OpFoldResult> offsets,
2922 ArrayRef<OpFoldResult> sizes) {
2923 int64_t rank = getInputOperandRank();
2924 auto oneAttr = builder.getI64IntegerAttr(1);
2925 SmallVector<OpFoldResult> strides(rank, oneAttr);
2926 SmallVector<Value> tiledOperands;
2927 Operation *inputSlice =
2928 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2929 if (!inputSlice) {
2930 return emitOpError("failed to compute input slice");
2931 }
2932 tiledOperands.emplace_back(inputSlice->getResult(0));
2933 Operation *outputSlice =
2934 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2935 if (!outputSlice) {
2936 return emitOpError("failed to compute output slice");
2937 }
2938 tiledOperands.emplace_back(outputSlice->getResult(0));
2939
2940 SmallVector<Type, 4> resultTypes;
2941 if (hasPureTensorSemantics())
2942 resultTypes.push_back(tiledOperands[1].getType());
2943 Operation *tiledOp =
2944 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2945
2946 return TilingResult{
2947 {tiledOp},
2948 SmallVector<Value>(tiledOp->getResults()),
2949 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2950}
2951
2952LogicalResult SoftmaxOp::getResultTilePosition(
2953 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2954 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2955 SmallVector<OpFoldResult> &resultSizes) {
2956 if (resultNumber == 0) {
2957 resultOffsets.assign(offsets.begin(), offsets.end());
2958 resultSizes.assign(sizes.begin(), sizes.end());
2959 return success();
2960 }
2961 return failure();
2962}
2963
2964// cast(dynamic) -> static.
2965LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2966 return memref::foldMemRefCast(*this);
2967}
2968
2969LogicalResult
2970SoftmaxOp::reifyResultShapes(OpBuilder &b,
2971 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2972 SmallVector<OpFoldResult> shapes;
2973 Location loc = getOperation()->getLoc();
2974 IRRewriter rewriter(b);
2975 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2976 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2977 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2978 if (!outputShapedType.isDynamicDim(dim)) {
2979 // Static dim: Return IntegerAttr.
2980 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2981 } else {
2982 // Dynamic dim: Return Value.
2983 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2984 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2985 }
2986 }
2987 reifiedReturnShapes.emplace_back(std::move(shapes));
2988 return success();
2989}
2990
// Reports the memory effects of the op on its memref operands: a read on each
// memref input, and a read plus a write on each memref init. Operands that are
// not memrefs contribute no effect.
// NOTE(review): the final argument line of each `emplace_back` call (embedded
// lines 3000, 3008 and 3011) is absent from this listing — confirm the
// remaining argument against the original file.
2991void SoftmaxOp::getEffects(
2992 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2993 &effects) {
 // `index` is used directly as an operand number of the operation.
 // NOTE(review): presumably the inputs are the leading operands — confirm.
2994 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2995 if (!llvm::isa<MemRefType>(operand.getType()))
2996 continue;
2997 effects.emplace_back(MemoryEffects::Read::get(),
2998 &getOperation()->getOpOperand(index), /*stage=*/0,
2999 /*effectOnFullRegion=*/true,
3001 }
3002
 // Inits are both read and written.
3003 for (OpOperand &operand : getDpsInitsMutable()) {
3004 if (!llvm::isa<MemRefType>(operand.get().getType()))
3005 continue;
3006 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
3007 /*effectOnFullRegion=*/true,
3009 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
3010 /*effectOnFullRegion=*/true,
3012 }
3013}
3014
3015// Helper functions for softmax decomposition.
3016// @{
3017
3018// Helper function to produce the iterator types (reduction or parallel) and
3019// affine maps for the iterators used in the decomposition of softmax.
3020// This method creates:
3021// If allParallel == true:
3022// - iterator type: {parallel, ..., parallel}
3023// - affine maps:
3024// -- identity with inputRank dimensions.
3025// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
3026// where N == inputRank - 1.
3027//
3028// If allParallel == false:
3029// - iterator type at dim(i) == parallel for i != \p dim and
3030// dim(dim) == reduction.
3031// - affine map:
3032// -- identity with inputRank dimensions.
3033// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
3034// where N == inputRank - 1.
// Returns the pair (iterator types, indexing maps) described in the comment
// above: `inputRank` iterator types and two maps, the identity map followed by
// the map that drops dimension `dim`.
// NOTE(review): the line carrying the function name and its leading parameters
// (embedded line 3036) is absent from this listing — confirm the name and the
// types of `builder` and `inputRank` against the original file.
3035static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
3037 int64_t dim, bool allParallel = false) {
3038 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
3039 utils::IteratorType::parallel);
 // Only the iterator type of `dim` depends on `allParallel`; the maps do not.
3040 if (!allParallel)
3041 iteratorTypes[dim] = utils::IteratorType::reduction;
3042 MLIRContext *ctxt = builder.getContext();
3043 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
 // Collect every dim expression except the one for `dim`.
3044 SmallVector<AffineExpr, 2> affineExprs;
3045 for (int i = 0; i < inputRank; i++) {
3046 if (i != dim)
3047 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
3048 }
3049 auto reductionMap =
3050 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
3051 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
3052 return std::make_tuple(iteratorTypes, indexingMaps);
3053}
3054
3055// Helper function to produce a linalg.generic that computes a reduction on
3056// dimension \p dim with the operation type \p T.
3057template <typename T>
3058static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
3059 int64_t dim) {
3060 auto inputType = cast<ShapedType>(input.getType());
3061 ArrayRef<int64_t> inputShape = inputType.getShape();
3062 int64_t inputRank = inputShape.size();
3063 auto [iteratorTypes, indexingMaps] =
3064 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
3065 assert(indexingMaps.size() == 2 &&
3066 "We should have two maps: 1 for the input, 1 for the output");
3067 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3068
3069 auto genericOp = linalg::GenericOp::create(
3070 builder, loc, output.getType(), input, output, indexingMaps,
3071 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
3072 Value result = T::create(b, loc, args[0], args[1]);
3073 linalg::YieldOp::create(b, loc, result);
3074 });
3075 return genericOp.getResult(0);
3076}
3077
3078/// Produce a linalg generic that computes the second step of the softmax
3079/// decomposition: res = exp(input - max), where \p max is the max of \p input
3080/// on dimension \p dim.
3081static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3082 Value max, Value output, int64_t dim) {
3083 auto inputType = cast<ShapedType>(input.getType());
3084 ArrayRef<int64_t> inputShape = inputType.getShape();
3085 int64_t inputRank = inputShape.size();
3086 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3087 builder, inputRank, dim, /*allParallel=*/true);
3088 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3089 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3090 // Add the affine map for the output argument.
3091 indexingMaps.push_back(indexingMaps[0]);
3092 auto genericOp = linalg::GenericOp::create(
3093 builder, loc, input.getType(), ValueRange{input, max}, output,
3094 indexingMaps, iteratorTypes,
3095 [&](OpBuilder &b, Location loc, ValueRange args) {
3096 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3097 Value result = math::ExpOp::create(b, loc, diff);
3098 linalg::YieldOp::create(b, loc, result);
3099 });
3100 return genericOp.getResult(0);
3101}
3102
3103/// Produce a linalg generic that computes the final step of the softmax
3104/// decomposition.
3105/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3106/// yield n / d
3107/// }
3108static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3109 Value denominator, Value output, int64_t dim) {
3110 auto inputType = cast<ShapedType>(numerator.getType());
3111 ArrayRef<int64_t> inputShape = inputType.getShape();
3112 int64_t inputRank = inputShape.size();
3113 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3114 builder, inputRank, dim, /*allParallel=*/true);
3115 assert(indexingMaps.size() == 2 &&
3116 "We should have one map for each input (2)");
3117 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3118 // Add the affine map for the output tensor.
3119 indexingMaps.push_back(indexingMaps[0]);
3120 auto genericOp = linalg::GenericOp::create(
3121 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3122 output, indexingMaps, iteratorTypes,
3123 [&](OpBuilder &b, Location loc, ValueRange args) {
3124 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3125 linalg::YieldOp::create(b, loc, result);
3126 });
3127 return genericOp.getResult(0);
3128}
3129// @} End helper functions for softmax decomposition.
3130
3131/// Given an N-dimensional tensor x, this method converts
3132/// softmax(x) to the following sequence of operations:
3133///
3134/// 1. Compute the max of x along dimension d. This results
3135/// in a N-1 dimensional tensor m.
3136/// m = max(x, dim = d)
3137///
3138/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3139/// a N dimensional tensor z.
3140/// z = exp(x - m)
3141///
3142/// 3. Compute the sum of z along dimension d. This results in
3143/// a N-1 dimensional tensor l.
3144/// l = sum(z, dim = d)
3145///
3146/// 4. Divide z and l. This gives the N-dimensional softmax.
3147/// softmax = z / l
3148///
3149FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3150 OpBuilder::InsertionGuard guard(b);
3151 b.setInsertionPoint(*this);
3152 Location loc = getLoc();
3153 Value input = getInput();
3154 ShapedType inputType = getInputOperandType();
3155 Type elementType = inputType.getElementType();
3156 int64_t reductionDim = getDimension();
3157 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3158 Value output = getOutput();
3159 dims.erase(dims.begin() + reductionDim);
3160 // Step 1: Compute max along dim.
3161 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3162 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3163 elementType, b, loc,
3164 /*useOnlyFiniteValue=*/true);
3165 Value neutralForMaxFInit =
3166 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3167 .result();
3168 Value max =
3169 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3170
3171 // Step 2: Subtract max from input and exponentiate.
3172 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3173
3174 // Step 3: Compute sum along dim.
3175 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3176 b, loc, /*useOnlyFiniteValue=*/true);
3177 Value zeroInit =
3178 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3179 Value denominator =
3180 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3181
3182 // Step 4: Compute softmax.
3183 Value result =
3184 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3185 return SmallVector<Value>{result};
3186}
3187
3188//===----------------------------------------------------------------------===//
3189// WinogradFilterTransformOp
3190//===----------------------------------------------------------------------===//
3191
3192LogicalResult WinogradFilterTransformOp::verify() {
3193 auto filterType = cast<ShapedType>(getFilter().getType());
3194 ArrayRef<int64_t> filterShape = filterType.getShape();
3195 int64_t filterH = filterShape[getFilterHDim()];
3196 int64_t filterW = filterShape[getFilterWDim()];
3197 WinogradConv2DFmr fmr = getFmr();
3198 int64_t m, r;
3199 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3200
3201 if (filterH != r && filterH != 1)
3202 return emitOpError("expect filter height either equals to r or 1");
3203 if (filterW != r && filterW != 1)
3204 return emitOpError("expect filter width either equals to r or 1");
3205 if (filterH == 1 && filterW == 1)
3206 return emitOpError("expect either filter height or width equals to r");
3207
3208 SmallVector<int64_t> expectedOutputShape;
3209 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3210 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3211 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3212 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3213
3214 auto outputType = cast<ShapedType>(getOutput().getType());
3215 ArrayRef<int64_t> outputShape = outputType.getShape();
3216 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3217 return emitOpError("the output shape is not expected");
3218 }
3219 return success();
3220}
3221
3222SmallVector<Range>
3223WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3224 Location loc = getLoc();
3225 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3226 IntegerAttr oneAttr = builder.getIndexAttr(1);
3227 Value filter = getFilter();
3228 int64_t filterRank = getFilterOperandRank();
3229 SmallVector<Range> loopBounds(filterRank);
3230 for (unsigned dim = 0; dim < filterRank; ++dim) {
3231 loopBounds[dim].offset = zeroAttr;
3232 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3233 loopBounds[dim].stride = oneAttr;
3234 }
3235 return loopBounds;
3236}
3237
3238SmallVector<utils::IteratorType>
3239WinogradFilterTransformOp::getLoopIteratorTypes() {
3240 int64_t filterRank = getFilterOperandRank();
3241 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3242 utils::IteratorType::parallel);
3243 return iteratorTypes;
3244}
3245
3246LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3247 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3248 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3249 SmallVector<OpFoldResult> &resultSizes) {
3250 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3251 ShapedType filterType = getFilterOperandType();
3252 ArrayRef<int64_t> filterShape = filterType.getShape();
3253 int64_t filterH = filterShape[getFilterHDim()];
3254 int64_t filterW = filterShape[getFilterWDim()];
3255 WinogradConv2DFmr fmr = getFmr();
3256 int64_t m, r;
3257 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3258 int64_t alpha = m + r - 1;
3259 int64_t alphaH = filterH != 1 ? alpha : 1;
3260 int64_t alphaW = filterW != 1 ? alpha : 1;
3261 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3262 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3263
3264 resultOffsets.append(
3265 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3266 resultSizes.append(
3267 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3268
3269 return success();
3270}
3271
3272/// Implement tiling for winograd_filter_transform
3273/// The input of winograd_filter_transform is (F, KH, KW, C).
3274/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
3275/// Users can specify the tile sizes of F and C.
3276/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3277/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3278FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3279 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3280 ArrayRef<OpFoldResult> sizes) {
3281 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3282 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3283 ShapedType filterType = getFilterOperandType();
3284 ArrayRef<int64_t> filterShape = filterType.getShape();
3285 int64_t filterH = filterShape[getFilterHDim()];
3286 int64_t filterW = filterShape[getFilterWDim()];
3287 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3288 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3289 SmallVector<Value> tiledOperands;
3290 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3291
3292 sliceOffsets.append(
3293 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3294 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3295 sizes[getFilterCDim()]});
3296 int64_t filterRank = getFilterOperandRank();
3297 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3298 Location loc = getLoc();
3299 auto filterSlice = tensor::ExtractSliceOp::create(
3300 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3301 tiledOperands.emplace_back(filterSlice);
3302
3303 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3304 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3305 resultSizes)))
3306 return failure();
3307
3308 int64_t outputRank = getOutputOperandRank();
3309 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3310 auto outputSlice = tensor::ExtractSliceOp::create(
3311 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3312 tiledOperands.emplace_back(outputSlice);
3313
3314 SmallVector<Type> resultTypes;
3315 resultTypes.push_back(tiledOperands[1].getType());
3316 Operation *tiledOp =
3317 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3318
3319 return TilingResult{
3320 {tiledOp},
3321 SmallVector<Value>(tiledOp->getResults()),
3322 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3323}
3324
3325//===----------------------------------------------------------------------===//
3326// WinogradInputTransformOp
3327//===----------------------------------------------------------------------===//
3328
3329LogicalResult WinogradInputTransformOp::verify() {
3330 auto inputType = cast<ShapedType>(getInput().getType());
3331 ArrayRef<int64_t> inputShape = inputType.getShape();
3332 int64_t inputH = inputShape[getInputHDim()];
3333 int64_t inputW = inputShape[getInputWDim()];
3334 WinogradConv2DFmr fmr = getFmr();
3335 int64_t m, r;
3336 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3337 int64_t tileSize = m + r - 1;
3338
3339 auto outputType = cast<ShapedType>(getOutput().getType());
3340 ArrayRef<int64_t> outputShape = outputType.getShape();
3341 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3342 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3343
3344 SmallVector<int64_t> expectedOutputShape(6, inputH);
3345 if (ShapedType::isDynamic(inputH)) {
3346 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3347 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3348 } else {
3349 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3350 expectedOutputShape[getOutputTileHDim()] =
3351 leftTransform ? (inputH - (r - 1)) / m : inputH;
3352 }
3353 if (ShapedType::isDynamic(inputW)) {
3354 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3355 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3356 } else {
3357 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3358 expectedOutputShape[getOutputTileWDim()] =
3359 rightTransform ? (inputW - (r - 1)) / m : inputW;
3360 }
3361 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3362 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3363
3364 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3365 return emitOpError("the output shape is not expected");
3366 }
3367 return success();
3368}
3369
3370SmallVector<Range>
3371WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3372 Location loc = getLoc();
3373 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3374 IntegerAttr oneAttr = builder.getIndexAttr(1);
3375 Value output = getOutput();
3376 int64_t outputRank = getOutputOperandRank();
3377 SmallVector<Range> loopBounds(outputRank);
3378 for (unsigned dim = 0; dim < outputRank; ++dim) {
3379 loopBounds[dim].offset = zeroAttr;
3380 // alphaH, alphaW, tileH, tileW, N, C
3381 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3382 loopBounds[dim].stride = oneAttr;
3383 }
3384 return loopBounds;
3385}
3386
3387SmallVector<utils::IteratorType>
3388WinogradInputTransformOp::getLoopIteratorTypes() {
3389 int64_t outputRank = getOutputOperandRank();
3390 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3391 utils::IteratorType::parallel);
3392 return iteratorTypes;
3393}
3394
3395LogicalResult WinogradInputTransformOp::getResultTilePosition(
3396 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3397 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3398 SmallVector<OpFoldResult> &resultSizes) {
3399 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3400 ShapedType outputType = getOutputOperandType();
3401 ArrayRef<int64_t> outputShape = outputType.getShape();
3402 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3403 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3404
3405 WinogradConv2DFmr fmr = getFmr();
3406 int64_t m, r;
3407 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3408 int64_t alpha = m + r - 1;
3409 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3410 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3411
3412 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3413 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3414
3415 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3416 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3417 offsets[getOutputCDim()]});
3418 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3419 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3420 sizes[getOutputCDim()]});
3421
3422 return success();
3423}
3424
3425/// Implement tiling for winograd_input_transform
3426/// The input of winograd_input_transform is (N, H, W, C).
3427/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3428/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3429/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3430/// the values for the sizes of tileH, tileW, N, C for one tile.
3431FailureOr<TilingResult>
3432WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3433 ArrayRef<OpFoldResult> offsets,
3434 ArrayRef<OpFoldResult> sizes) {
3435 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3436 WinogradConv2DFmr fmr = getFmr();
3437 int64_t m, r;
3438 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3439
3440 ShapedType outputType = getOutputOperandType();
3441 ArrayRef<int64_t> outputShape = outputType.getShape();
3442 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3443 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3444
3445 Location loc = getLoc();
3446 MLIRContext *context = builder.getContext();
3447 auto identityAffineMap =
3448 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3449 auto offsetAffineMap =
3450 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3451 Value mappedOffsetH = affine::makeComposedAffineApply(
3452 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3453 offsets[getOutputTileHDim()]);
3454 Value mappedOffsetW = affine::makeComposedAffineApply(
3455 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3456 offsets[getOutputTileWDim()]);
3457 auto sizeAffineMap = AffineMap::get(
3458 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3459 Value mappedSizeH = affine::makeComposedAffineApply(
3460 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3461 Value mappedSizeW = affine::makeComposedAffineApply(
3462 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3463
3464 SmallVector<Value> tiledOperands;
3465 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3466
3467 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3468 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3469 sliceOffsets.append(
3470 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3471 OpFoldResult sizeH =
3472 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3473 OpFoldResult sizeW =
3474 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3475 sliceSizes.append(
3476 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3477 int64_t inputRank = getInputOperandRank();
3478 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3479 auto inputSlice = tensor::ExtractSliceOp::create(
3480 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3481 tiledOperands.emplace_back(inputSlice);
3482
3483 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3484 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3485 resultSizes)))
3486 return failure();
3487
3488 int64_t outputRank = getOutputOperandRank();
3489 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3490 auto outputSlice = tensor::ExtractSliceOp::create(
3491 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3492 tiledOperands.emplace_back(outputSlice);
3493
3494 SmallVector<Type> resultTypes;
3495 resultTypes.push_back(tiledOperands[1].getType());
3496 Operation *tiledOp =
3497 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3498
3499 return TilingResult{
3500 {tiledOp},
3501 SmallVector<Value>(tiledOp->getResults()),
3502 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3503}
3504
3505//===----------------------------------------------------------------------===//
3506// WinogradOutputTransformOp
3507//===----------------------------------------------------------------------===//
3508
3509LogicalResult WinogradOutputTransformOp::verify() {
3510 auto valueType = cast<ShapedType>(getValue().getType());
3511 ArrayRef<int64_t> valueShape = valueType.getShape();
3512 int64_t valueH = valueShape[getValueAlphaHDim()];
3513 int64_t valueW = valueShape[getValueAlphaWDim()];
3514 int64_t valueTileH = valueShape[getValueTileHDim()];
3515 int64_t valueTileW = valueShape[getValueTileWDim()];
3516 WinogradConv2DFmr fmr = getFmr();
3517 int64_t m, r;
3518 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3519 bool leftTransform = valueH != 1;
3520 bool rightTransform = valueW != 1;
3521
3522 int64_t outputRank = getOutputOperandRank();
3523 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3524 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3525 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3526 } else {
3527 if (valueH != (leftTransform ? m + r - 1 : 1))
3528 return emitOpError("expect input height equals to input tile size");
3529 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3530 }
3531 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3532 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3533 } else {
3534 if (valueW != (rightTransform ? m + r - 1 : 1))
3535 return emitOpError("expect input width equals to input tile size");
3536 expectedOutputShape[getOutputWDim()] =
3537 (rightTransform ? m : 1) * valueTileW;
3538 }
3539 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3540 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3541
3542 auto outputType = cast<ShapedType>(getOutput().getType());
3543 ArrayRef<int64_t> outputShape = outputType.getShape();
3544 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3545 return emitOpError("the output shape is not expected");
3546 }
3547 return success();
3548}
3549
3550SmallVector<Range>
3551WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3552 Location loc = getLoc();
3553 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3554 IntegerAttr oneAttr = builder.getIndexAttr(1);
3555 Value value = getValue();
3556 int64_t valueRank = getValueOperandRank();
3557 SmallVector<Range> loopBounds(valueRank);
3558 for (unsigned dim = 0; dim < valueRank; ++dim) {
3559 loopBounds[dim].offset = zeroAttr;
3560 // alphaH, alphaW, tileH, tileW, N, F
3561 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3562 loopBounds[dim].stride = oneAttr;
3563 }
3564 return loopBounds;
3565}
3566
3567SmallVector<utils::IteratorType>
3568WinogradOutputTransformOp::getLoopIteratorTypes() {
3569 int64_t valueRank = getValueOperandRank();
3570 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3571 utils::IteratorType::parallel);
3572 return iteratorTypes;
3573}
3574
3575LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3576 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3577 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3578 SmallVector<OpFoldResult> &resultSizes) {
3579 WinogradConv2DFmr fmr = getFmr();
3580 int64_t m, r;
3581 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3582
3583 Location loc = getLoc();
3584 MLIRContext *context = builder.getContext();
3585 auto identityAffineMap =
3586 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3587 auto affineMap =
3588 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3589
3590 ShapedType valueType = getValueOperandType();
3591 ArrayRef<int64_t> valueShape = valueType.getShape();
3592 int64_t valueH = valueShape[0];
3593 int64_t valueW = valueShape[1];
3594 Value mappedOffsetH = affine::makeComposedAffineApply(
3595 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3596 offsets[getValueTileHDim()]);
3597 Value mappedOffsetW = affine::makeComposedAffineApply(
3598 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3599 offsets[getValueTileWDim()]);
3600 Value mappedSizeH = affine::makeComposedAffineApply(
3601 builder, loc, affineMap, sizes[getValueTileHDim()]);
3602 Value mappedSizeW = affine::makeComposedAffineApply(
3603 builder, loc, affineMap, sizes[getValueTileWDim()]);
3604
3605 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3606 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3607 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3608 OpFoldResult sizeH =
3609 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3610 OpFoldResult sizeW =
3611 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3612
3613 resultOffsets.append(
3614 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3615 resultSizes.append(
3616 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3617 return success();
3618}
3619
3620/// Implement tiling for winograd_output_transform
3621/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3622/// F). The output of winograd_output_transform is (N, H, W, F) Users can
3623/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3624/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3625/// for the sizes of tileH, tileW, N, F for one tile.
3626FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3627 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3628 ArrayRef<OpFoldResult> sizes) {
3629 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3630 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3631 Location loc = getLoc();
3632 SmallVector<Value> tiledOperands;
3633 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3634
3635 ShapedType valueType = getValueOperandType();
3636 ArrayRef<int64_t> valueShape = valueType.getShape();
3637 int64_t alphaH = valueShape[getValueAlphaHDim()];
3638 int64_t alphaW = valueShape[getValueAlphaWDim()];
3639 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3640 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3641
3642 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3643 offsets[getValueTileWDim()], offsets[getValueNDim()],
3644 offsets[getValueFDim()]});
3645 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3646 sizes[getValueTileWDim()], sizes[getValueNDim()],
3647 sizes[getValueFDim()]});
3648 int64_t valueRank = getValueOperandRank();
3649 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3650 auto valueSlice = tensor::ExtractSliceOp::create(
3651 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3652 tiledOperands.emplace_back(valueSlice);
3653
3654 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3655 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3656 resultSizes)))
3657 return failure();
3658
3659 int64_t outputRank = getOutputOperandRank();
3660 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3661 auto outputSlice = tensor::ExtractSliceOp::create(
3662 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3663 tiledOperands.emplace_back(outputSlice);
3664
3665 SmallVector<Type> resultTypes;
3666 resultTypes.push_back(tiledOperands[1].getType());
3667 Operation *tiledOp =
3668 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3669
3670 return TilingResult{
3671 {tiledOp},
3672 SmallVector<Value>(tiledOp->getResults()),
3673 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3674}
3675
3676//===----------------------------------------------------------------------===//
3677// LinalgDialect
3678// TODO: Merge with the LinalgDialect block at the bottom
3679//===----------------------------------------------------------------------===//
3680
3681// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3682static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3683 auto explicitRange = subMap.getResults();
3684 auto defaultRange = fullMap.getResults();
3685 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3686 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3687 llvm::set_union(explicitSet, defaultSet);
3688 return explicitSet == defaultSet;
3689}
3690
3691/// Check if the user defined map is valid broadcast map. Here broadcast
3692/// indexing maps are defined in context of corresponding default indexing maps
3693/// for the given Op. This way the check becomes very simple i.e just check the
3694/// number of result dims.
3695/// Returns true if the explictMap is broadcasted with respect to the
3696/// defaultMap.
3697static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3698 return explictMap.getNumResults() < defaultMap.getNumResults();
3699}
3700
3701/// Verifies the broadcast and transpose semantic sepecified by the explicit
3702/// indexing map for the MatmulOp \p op for each operand specified by \p
3703/// opIndex.
3704static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3705 unsigned opIndex) {
3706 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3707 SmallVector<AffineMap, 3> defaultIndexingMaps =
3708 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3709
3710 auto opIndexingMap = opIndexingMaps[opIndex];
3711 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3712 // Check general validity of indexing map results.
3713 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3714 return matmulOp->emitOpError()
3715 << "Unexpected dim expression in map result.";
3716
3717 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3718 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3719 return matmulOp->emitOpError()
3720 << "Invalid broadcast requested, should be (d2).";
3721 }
3722 return success();
3723 }
3724 return success();
3725}
3726
3727// Check general validity of input indexing map of
3728// BatchMatmulOp/BatchReduceMatmulOp.
3729template <typename OpTy>
3730static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3731 AffineMap opIndexingMap,
3732 AffineMap defaultIndexingMap, bool isLHS) {
3733 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3734 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3735 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3736 // Check the result dims are valid.
3737 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3738 return batchVariantMatmulOp->emitOpError()
3739 << "Unexpected result dim expression (outside the set of default "
3740 "result dims).";
3741
3742 // Check for valid number of result dims of input maps.
3743 if (opIndexingMap.getNumResults() > 3)
3744 return batchVariantMatmulOp->emitOpError()
3745 << "no. of result dim expressions exceeds 3.";
3746
3747 auto hasValidBatchDim = [](AffineMap map) {
3748 AffineExpr batchDim = map.getResult(0);
3749 return batchDim.isFunctionOfDim(0);
3750 };
3751
3752 // Check if the requested broadcast is valid.
3753 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3754 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3755 return batchVariantMatmulOp->emitOpError()
3756 << "Invalid broadcast requested.";
3757 } else if (!hasValidBatchDim(opIndexingMap)) {
3758 return batchVariantMatmulOp->emitOpError()
3759 << "Invalid batch dimension expression.";
3760 }
3761 return success();
3762}
3763
3764/// This function checks if the given AffineMap for the output of a
3765/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3766/// dimensions and if the output map result dimensions are valid.
3767template <typename OpTy>
3768static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3769 AffineMap opIndexingMap) {
3770 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3771 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3772 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3773 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3774 opIndexingMap.getNumResults() != 3) {
3775
3776 return batchVariantMatmulOp->emitOpError()
3777 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3778 << ").";
3779 }
3780 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3781 opIndexingMap.getNumResults() != 2) {
3782 return batchVariantMatmulOp->emitOpError()
3783 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3784 << ").";
3785 }
3786
3787 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3788 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3789 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3790 outputMap.getResult(1).isFunctionOfDim(1) &&
3791 outputMap.getResult(2).isFunctionOfDim(2)
3792 : outputMap.getResult(0).isFunctionOfDim(1) &&
3793 outputMap.getResult(1).isFunctionOfDim(2);
3794 };
3795
3796 if (!areValidOutputResultDim(opIndexingMap)) {
3797 return batchVariantMatmulOp->emitOpError()
3798 << "Invalid output map result dimension.";
3799 }
3800
3801 return success();
3802}
3803
3804/// Verifies the broadcast and transpose semantic specified by the explicit
3805/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3806/// specified by opIndex.
3807template <typename OpTy>
3808static LogicalResult
3810 unsigned opIndex) {
3811 SmallVector<AffineMap, 3> opIndexingMaps =
3812 batchVariantMatmulOp.getIndexingMapsArray();
3813 SmallVector<AffineMap, 3> defaultIndexingMaps =
3814 batchVariantMatmulOp.getDefaultIndexingMaps(
3815 batchVariantMatmulOp->getContext());
3816
3817 if (opIndexingMaps.size() != 3)
3818 return batchVariantMatmulOp->emitOpError()
3819 << "Indexing_map attribute must have 3 affine maps.";
3820
3821 auto opIndexingMap = opIndexingMaps[opIndex];
3822 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3823
3824 if (opIndex == 2 &&
3825 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3826 return failure();
3827
3828 if (opIndex != 2 &&
3829 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3830 defaultIndexingMap, opIndex == 0)))
3831 return failure();
3832
3833 return success();
3834}
3835
3836namespace mlir {
3837namespace linalg {
3838
3839std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3840 if (m == 2 && r == 3)
3841 return WinogradConv2DFmr::F_2_3;
3842 if (m == 4 && r == 3)
3843 return WinogradConv2DFmr::F_4_3;
3844 if (m == 2 && r == 5)
3845 return WinogradConv2DFmr::F_2_5;
3846 return std::nullopt;
3847}
3848
3849std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3850 switch (fmr) {
3851 case WinogradConv2DFmr::F_2_3:
3852 return {2, 3};
3853 case WinogradConv2DFmr::F_4_3:
3854 return {4, 3};
3855 case WinogradConv2DFmr::F_2_5:
3856 return {2, 5};
3857 }
3858 llvm_unreachable("Unkown WinogradConv2DFmr");
3859}
3860
3861//===----------------------------------------------------------------------===//
3862// MatMulOp
3863//===----------------------------------------------------------------------===//
3864
3865static FailureOr<SmallVector<SmallVector<int64_t>>>
3868 for (auto map : maps) {
3869 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3870 if (!attr)
3871 return failure();
3873 for (auto result : attr.getAffineMap().getResults()) {
3874 auto dim = dyn_cast<AffineDimExpr>(result);
3875 if (!dim)
3876 return failure();
3877 pos.push_back(dim.getPosition());
3878 }
3879 positions.push_back(pos);
3880 }
3881 return positions;
3882}
3883
3884/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3885SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3886 AffineExpr d0, d1, d2;
3887 SmallVector<AffineMap> indexingMaps;
3888 bindDims(context, d0, d1, d2);
3889 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3890 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3891 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3892 return indexingMaps;
3893}
3894
3895bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3896 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3897 if (!maps)
3898 return false;
3899 if (maps.size() != 3)
3900 return false;
3901 auto positions = getAffineResultPositions(maps);
3902 if (failed(positions))
3903 return false;
3904 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3905 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3906 (*positions)[2] == SmallVector<int64_t>{0, 1};
3907}
3908
3909SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3910 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3911 utils::IteratorType::parallel,
3912 utils::IteratorType::reduction};
3913}
3914
3915unsigned MatmulOp::getNumRegionArgs() { return 3; }
3916
3917std::string MatmulOp::getLibraryCallName() {
3918 return generateLibraryCallName(getOperation());
3919}
3920
3921bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3922
3923/// Check if the op has broadcast and/or transpose semantic. Returns true if
3924/// the user defined indexing maps are not equal to default map.
3925bool MatmulOp::hasUserDefinedMaps() {
3926 SmallVector<AffineMap, 3> defaultMaps =
3927 getDefaultIndexingMaps(this->getContext());
3928 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3929 return defaultMaps != explicitMaps;
3930}
3931
3932/// Implements the block region builder for the MatmulOp. This is called by
3933/// 'fillStructuredOpRegion'.
3934void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3935 ArrayRef<NamedAttribute> attrs,
3936 function_ref<InFlightDiagnostic()> emitError) {
3937 if (emitError && block.getNumArguments() != 3) {
3938 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3939 << block.getNumArguments();
3940 return;
3941 }
3942 assert(block.getNumArguments() == 3 &&
3943 "MatmulOp regionBuilder expects 3 args");
3944 RegionBuilderHelper helper(b, block);
3945 SmallVector<Value> yields;
3946
3947 TypeFn castVal = TypeFn::cast_signed;
3948 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3949 return attr.getName() == "cast";
3950 });
3951 if (castIter != attrs.end()) {
3952 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3953 castVal = attr.getValue();
3954 }
3955
3956 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3957 block.getArgument(0));
3958 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3959 block.getArgument(1));
3960 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3961 if (!value1 || !value2 || !value3)
3962 return;
3963 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3964 value3, emitError);
3965 if (!value4)
3966 return;
3967 yields.push_back(value4);
3968 helper.yieldOutputs(yields);
3969}
3970
3971/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3972/// broadcast map must include K dimension.
3973/// TODO: Strict inclusion of K dimension in the broadcast map is not
3974/// necessary for both input matrices simultaneously. We can relax this
3975/// condition to have K dimension for one input matrix map and infer the K
3976/// dimension for other input matrix map from the one already having K
3977/// dimension.
3978bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3979 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3980 AffineExpr expr = bcastMap.getResult(0);
3981 // Invalid map if the common dimension of matmul not found.
3982 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3983}
3984
3985static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3986 if (parser.parseOptionalKeyword("indexing_maps"))
3987 return ArrayAttr{
3988 nullptr}; // Success in case indexing_maps was not provided.
3989
3990 ArrayAttr arrayAttr;
3991 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3992 return failure();
3993
3994 if (llvm::any_of(arrayAttr,
3995 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3996 return parser.emitError(parser.getCurrentLocation())
3997 << "element of indexing_maps array is not an affine_map";
3998
3999 return arrayAttr;
4000}
4001
4002ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4003 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4004 if (failed(indexingMapsAttr))
4005 return failure();
4006
4007 if (*indexingMapsAttr == nullptr) {
4008 auto indexingMapAttrs = llvm::map_to_vector(
4009 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
4010 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4011 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
4012 }
4013
4014 result.addAttribute("indexing_maps", *indexingMapsAttr);
4015 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
4016 MatmulOp::getRegionBuilder());
4017}
4018
4019void MatmulOp::print(OpAsmPrinter &p) {
4020 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4021 MatmulOp::getDefaultIndexingMaps(getContext()),
4022 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4023 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4024 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4025
4026 std::array<StringRef, 3> elidedAttrs = {
4027 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4028 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4029 elidedAttrs);
4030}
4031
4032/// Verify the user defined indexing maps.
4033LogicalResult MatmulOp::verify() {
4034 // Verification of pure matmul is handled by verifyStructuredOpInterface().
4035 if (!hasUserDefinedMaps())
4036 return success();
4037
4038 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
4039 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
4040 return failure();
4041 }
4042 return success();
4043}
4044
4045LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4046 return memref::foldMemRefCast(*this);
4047}
4048
4049void MatmulOp::getEffects(
4050 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4051 &effects) {
4052 if (hasPureTensorSemantics())
4053 return;
4054 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4055}
4056
4057Speculation::Speculatability MatmulOp::getSpeculatability() {
4058 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4059}
4060
4061SmallVector<AffineMap>
4062MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4063 AffineExpr d0, d1, d2;
4064 MLIRContext *context = builder.getContext();
4065 bindDims(context, d0, d1, d2);
4066 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
4067 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
4068 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4069 return {mapLHS, mapRHS, mapOut};
4070}
4071
4073 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4074 if (!maps)
4075 return false;
4076 if (maps.size() != 3)
4077 return false;
4078 auto positions = getAffineResultPositions(maps);
4079 if (failed(positions))
4080 return false;
4081 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4082 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4083 (*positions)[2] == SmallVector<int64_t>{0, 1};
4084}
4085
4088 ValueRange inputs, ValueRange outputs,
4089 ArrayRef<NamedAttribute> attributes) {
4090 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4091 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4092}
4093
4096 ValueRange inputs, ValueRange outputs,
4097 ArrayRef<NamedAttribute> attributes) {
4098 OperationState state(location, getOperationName());
4099 build(builder, state, inputs, outputs, attributes);
4100 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4101 assert(res && "builder didn't return the right type");
4102 return res;
4103}
4104
4107 TypeRange resultTensorTypes,
4108 ValueRange inputs, ValueRange outputs,
4109 ArrayRef<NamedAttribute> attributes) {
4110 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4111 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4112}
4113
4116 TypeRange resultTensorTypes, ValueRange inputs,
4117 ValueRange outputs,
4118 ArrayRef<NamedAttribute> attributes) {
4119 OperationState state(location, getOperationName());
4120 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4121 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4122 assert(res && "builder didn't return the right type");
4123 return res;
4124}
4125
4128 TypeRange resultTensorTypes,
4129 ValueRange inputs, ValueRange outputs,
4130 Attribute cast,
4131 ArrayRef<NamedAttribute> attributes) {
4132 result.addAttribute("cast", cast);
4133 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4134 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4135}
4136
4139 TypeRange resultTensorTypes, ValueRange inputs,
4140 ValueRange outputs, Attribute cast,
4141 ArrayRef<NamedAttribute> attributes) {
4142 OperationState state(location, getOperationName());
4143 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4144 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4145 assert(res && "builder didn't return the right type");
4146 return res;
4147}
4148
4150 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4152 op->getAttr("indexing_maps"));
4153}
4154
4156MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4157 AffineExpr d0, d1, d2;
4158 MLIRContext *context = builder.getContext();
4159 bindDims(context, d0, d1, d2);
4160 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4161 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4162 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4163 return {mapLHS, mapRHS, mapOut};
4164}
4165
4167 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4168 if (!maps)
4169 return false;
4170 if (maps.size() != 3)
4171 return false;
4172 auto positions = getAffineResultPositions(maps);
4173 if (failed(positions))
4174 return false;
4175 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4176 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4177 (*positions)[2] == SmallVector<int64_t>{0, 1};
4178}
4179
4182 ValueRange inputs, ValueRange outputs,
4183 ArrayRef<NamedAttribute> attributes) {
4184 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4185 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4186}
4187
4190 ValueRange inputs, ValueRange outputs,
4191 ArrayRef<NamedAttribute> attributes) {
4192 OperationState state(location, getOperationName());
4193 build(builder, state, inputs, outputs, attributes);
4194 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4195 assert(res && "builder didn't return the right type");
4196 return res;
4197}
4198
4201 TypeRange resultTensorTypes,
4202 ValueRange inputs, ValueRange outputs,
4203 ArrayRef<NamedAttribute> attributes) {
4204 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4205 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4206}
4207
4210 TypeRange resultTensorTypes, ValueRange inputs,
4211 ValueRange outputs,
4212 ArrayRef<NamedAttribute> attributes) {
4213 OperationState state(location, getOperationName());
4214 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4215 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4216 assert(res && "builder didn't return the right type");
4217 return res;
4218}
4219
4222 TypeRange resultTensorTypes,
4223 ValueRange inputs, ValueRange outputs,
4224 Attribute cast,
4225 ArrayRef<NamedAttribute> attributes) {
4226 result.addAttribute("cast", cast);
4227 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4228 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4229}
4230
4233 TypeRange resultTensorTypes, ValueRange inputs,
4234 ValueRange outputs, Attribute cast,
4235 ArrayRef<NamedAttribute> attributes) {
4236 OperationState state(location, getOperationName());
4237 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4238 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4239 assert(res && "builder didn't return the right type");
4240 return res;
4241}
4242
4244 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4246 op->getAttr("indexing_maps"));
4247}
4248
4250BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4251 AffineExpr d0, d1, d2, d3;
4252 MLIRContext *context = builder.getContext();
4253 bindDims(context, d0, d1, d2, d3);
4254 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4255 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4256 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4257 return {mapLHS, mapRHS, mapOut};
4258}
4259
4261 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4262 if (!maps)
4263 return false;
4264 if (maps.size() != 3)
4265 return false;
4266 auto positions = getAffineResultPositions(maps);
4267 if (failed(positions))
4268 return false;
4269 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4270 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4271 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4272}
4273
4275 OpBuilder &builder, OperationState &result, ValueRange inputs,
4276 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4277 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4278 BatchMatmulOp::getRegionBuilder(),
4279 getDefaultIndexingMaps(builder));
4280}
4281
4284 ValueRange inputs, ValueRange outputs,
4285 ArrayRef<NamedAttribute> attributes) {
4286 OperationState state(location, getOperationName());
4287 build(builder, state, inputs, outputs, attributes);
4288 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4289 assert(res && "builder didn't return the right type");
4290 return res;
4291}
4292
4294 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4295 ValueRange inputs, ValueRange outputs,
4296 ArrayRef<NamedAttribute> attributes) {
4297 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4298 BatchMatmulOp::getRegionBuilder(),
4299 getDefaultIndexingMaps(builder));
4300}
4301
4304 TypeRange resultTensorTypes, ValueRange inputs,
4305 ValueRange outputs,
4306 ArrayRef<NamedAttribute> attributes) {
4307 OperationState state(location, getOperationName());
4308 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4309 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4310 assert(res && "builder didn't return the right type");
4311 return res;
4312}
4313
4315 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4316 ValueRange inputs, ValueRange outputs, Attribute cast,
4317 ArrayRef<NamedAttribute> attributes) {
4318 result.addAttribute("cast", cast);
4319 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4320 BatchMatmulOp::getRegionBuilder(),
4321 getDefaultIndexingMaps(builder));
4322}
4323
4326 TypeRange resultTensorTypes, ValueRange inputs,
4327 ValueRange outputs, Attribute cast,
4328 ArrayRef<NamedAttribute> attributes) {
4329 OperationState state(location, getOperationName());
4330 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4331 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4332 assert(res && "builder didn't return the right type");
4333 return res;
4334}
4335
4337 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4339 op->getAttr("indexing_maps"));
4340}
4341
4343BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4344 AffineExpr d0, d1, d2, d3;
4345 MLIRContext *context = builder.getContext();
4346 bindDims(context, d0, d1, d2, d3);
4347 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4348 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4349 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4350 return {mapLHS, mapRHS, mapOut};
4351}
4352
4354 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4355 if (!maps)
4356 return false;
4357 if (maps.size() != 3)
4358 return false;
4359 auto positions = getAffineResultPositions(maps);
4360 if (failed(positions))
4361 return false;
4362 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4363 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4364 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4365}
4366
4368 OpBuilder &builder, OperationState &result, ValueRange inputs,
4369 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4370 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4371 BatchMatmulOp::getRegionBuilder(),
4372 getDefaultIndexingMaps(builder));
4373}
4374
4377 ValueRange inputs, ValueRange outputs,
4378 ArrayRef<NamedAttribute> attributes) {
4379 OperationState state(location, getOperationName());
4380 build(builder, state, inputs, outputs, attributes);
4381 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4382 assert(res && "builder didn't return the right type");
4383 return res;
4384}
4385
4387 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4388 ValueRange inputs, ValueRange outputs,
4389 ArrayRef<NamedAttribute> attributes) {
4390 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4391 BatchMatmulOp::getRegionBuilder(),
4392 getDefaultIndexingMaps(builder));
4393}
4394
4397 TypeRange resultTensorTypes, ValueRange inputs,
4398 ValueRange outputs,
4399 ArrayRef<NamedAttribute> attributes) {
4400 OperationState state(location, getOperationName());
4401 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4402 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4403 assert(res && "builder didn't return the right type");
4404 return res;
4405}
4406
4408 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4409 ValueRange inputs, ValueRange outputs, Attribute cast,
4410 ArrayRef<NamedAttribute> attributes) {
4411 result.addAttribute("cast", cast);
4412 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4413 BatchMatmulOp::getRegionBuilder(),
4414 getDefaultIndexingMaps(builder));
4415}
4416
4419 TypeRange resultTensorTypes, ValueRange inputs,
4420 ValueRange outputs, Attribute cast,
4421 ArrayRef<NamedAttribute> attributes) {
4422 OperationState state(location, getOperationName());
4423 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4424 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4425 assert(res && "builder didn't return the right type");
4426 return res;
4427}
4428
4430 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4432 op->getAttr("indexing_maps"));
4433}
4434
4435//===----------------------------------------------------------------------===//
4436// ContractOp
4437//===----------------------------------------------------------------------===//
4438
4439SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4440 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4441 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4442 // domains are all the same, and each implements a projected permutation.
4443 // Each iteration space dim must occur for at least one operand and either
4444 // takes part in a contraction/reduction or else has parallel iteration type.
4445 // We have that a dim is a contraction/reduction dim if and only if the dim
4446 // occurs for the output operand. We use this fact for fast inference:
4447 // NB: In case we allow dims to occur solely for one input, the above still
4448 // holds: per the einsum semantics, these are reduction dims as well.
4449 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4450 for (auto result : outAffineMap.getResults()) {
4451 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4452 assert(dimExpr && "affine_map is a projected permutation");
4453 dimsInOutput[dimExpr.getPosition()] = true;
4454 }
4455
4457 for (auto dimOccursInOutput : dimsInOutput)
4458 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4459 : utils::IteratorType::reduction);
4460
4461 return iteratorTypes;
4462}
4463
4464unsigned ContractOp::getNumRegionArgs() { return 3; }
4465
4466/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4467void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4468 ArrayRef<NamedAttribute> attrs,
4469 function_ref<InFlightDiagnostic()> emitError) {
4470 if (emitError && block.getNumArguments() != 3) {
4471 emitError() << "ContractOp regionBuilder expects 3 args, got "
4472 << block.getNumArguments();
4473 return;
4474 }
4475 assert(block.getNumArguments() == 3 &&
4476 "ContractOp regionBuilder expects 3 args");
4477 RegionBuilderHelper helper(b, block);
4478
4479 TypeFn castSignedness = TypeFn::cast_signed;
4480 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4481 return attr.getName() == "cast";
4482 });
4483 if (castIter != attrs.end()) {
4484 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4485 castSignedness = attr.getValue();
4486 }
4487
4488 // TODO: Support fields with operators besides mult & add.
4489 Type outType = block.getArgument(2).getType();
4490 Value lhsAtOutType =
4491 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4492 Value rhsAtOutType =
4493 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4494 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4495 rhsAtOutType, emitError);
4496 if (!productAtOutType)
4497 return;
4498 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4499 productAtOutType, emitError);
4500 if (!result)
4501 return;
4502 helper.yieldOutputs({result});
4503}
4504
4505ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4506 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4507 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4508 return parser.emitError(parser.getCurrentLocation(),
4509 "expected 'indexing_maps' attribute");
4510 result.addAttribute("indexing_maps", *indexingMapsAttr);
4511
4512 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4513 regionBuilder);
4514}
4515
4516void ContractOp::print(OpAsmPrinter &p) {
4517 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4519 p, getOperation(), getInputs(), getOutputs(),
4520 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4521}
4522
4523LogicalResult ContractOp::verify() {
4524 int iterationSpaceDims = -1;
4525 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4526 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4527 // access an input operand (so occurrence count can be at most 2) and
4528 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4529 SmallVector<size_t> inOccurrences;
4530 SmallVector<size_t> outOccurrences;
4531
4532 // A helper so that for each operand's affine_map and type we check that ...
4533 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4534 bool isInput) -> LogicalResult {
4535 // ... the affine_map is a projected permutation;
4536 if (!affineMap.isProjectedPermutation())
4537 return emitError("provided affine_map is not a projected permutation");
4538
4539 // ... the rank of the affine_map's results and corresponding type match;
4540 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4541 if (affineMap.getNumResults() != shapedType.getRank())
4542 return emitError("ranks of shaped operand and results of corresponding "
4543 "affine_map differ");
4544 } else if (affineMap.getNumResults() != 0) {
4545 return emitError("affine_map specifies shaped access while operand has "
4546 "non-shaped type");
4547 }
4548
4549 // ... the rank of the affine_map's domain is the same as those seen prior;
4550 if (iterationSpaceDims == -1) {
4551 iterationSpaceDims = affineMap.getNumDims();
4552 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4553 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4554 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4555 return emitError("iteration spaces of provided affine_maps differ");
4556 }
4557
4558 // ... update counts of dims used to access either an input or the output.
4559 for (AffineExpr affineExpr : affineMap.getResults()) {
4560 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4561 if (!affineDimExpr)
4562 llvm_unreachable("affine_map is a projected permutation");
4563
4564 if (isInput)
4565 inOccurrences[affineDimExpr.getPosition()] += 1;
4566 else
4567 outOccurrences[affineDimExpr.getPosition()] += 1;
4568 }
4569
4570 return success();
4571 };
4572
4573 for (auto &&[affineMap, operandType, isInput] :
4574 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4575 SmallVector<bool>{true, true, false})) {
4576 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4577 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4578 }
4579
4580 bool hasContractingDim = false;
4581 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4582 size_t inOccCount = inOccurrences[dimIndex];
4583 size_t outOccCount = outOccurrences[dimIndex];
4584
4585 // We have a contracting dim if and only if ...
4586 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4587
4588 if (inOccCount == 0 && outOccCount == 0)
4589 return emitError() << "iteration space dim at index " << dimIndex
4590 << " not used to access any operand";
4591
4592 // NB: We disallow a dim which occurs for only one input operand and not
4593 // for the output. In terms of einsum semantics such dims have a
4594 // sensible meaning - namely an additional reduction per each such dim.
4595 // By contrast, the ContractionOpInterface does not know about this
4596 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4597 // while vector.contract's verifier accepts dims of this kind many of
4598 // its lowerings give up on encountering these dims.
4599 // TODO: Remove following once we have comprehensive support for input-only
4600 // reduction dims, at both the linalg- and vector-dialect levels.
4601 if (inOccCount == 1 && outOccCount != 1)
4602 return emitError()
4603 << "iteration space dim at index " << dimIndex
4604 << " is neither a contracting dim nor of parallel iteration type";
4605 }
4606
4607 if (!hasContractingDim)
4608 return emitError("'indexing_maps' do not specify a contracting dimension");
4609
4610 return success();
4611}
4612
4613LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4614 return memref::foldMemRefCast(*this);
4615}
4616
4617void ContractOp::getEffects(
4618 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4619 &effects) {
4620 if (hasPureTensorSemantics())
4621 return;
4622 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4623}
4624
4625Speculation::Speculatability ContractOp::getSpeculatability() {
4626 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4627}
4628
4629//===----------------------------------------------------------------------===//
4630// Implementation of BatchMatmulOp
4631//===----------------------------------------------------------------------===//
4632SmallVector<AffineMap>
4633BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4634 AffineExpr d0, d1, d2, d3;
4635 SmallVector<AffineMap> indexingMaps;
4636 bindDims(context, d0, d1, d2, d3);
4637 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4638 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4639 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4640 return indexingMaps;
4641}
4642
4643bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4644 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4645 if (!maps)
4646 return false;
4647 if (maps.size() != 3)
4648 return false;
4649 auto positions = getAffineResultPositions(maps);
4650 if (failed(positions))
4651 return false;
4652 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4653 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4654 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4655}
4656
4657SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4658 return SmallVector<utils::IteratorType>{
4659 utils::IteratorType::parallel, utils::IteratorType::parallel,
4660 utils::IteratorType::parallel, utils::IteratorType::reduction};
4661}
4662
4663unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4664
4665std::string BatchMatmulOp::getLibraryCallName() {
4666 return generateLibraryCallName(getOperation());
4667}
4668
4669/// Check if the op has broadcast and/or transpose semantic. Returns true if
4670/// the user defined indexing maps are not equal to default map.
4671bool BatchMatmulOp::hasUserDefinedMaps() {
4672 SmallVector<AffineMap, 3> defaultMaps =
4673 getDefaultIndexingMaps(this->getContext());
4674 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4675 return defaultMaps != explicitMaps;
4676}
4677
4678/// Returns true if the given bcastMap map is a valid broadcast map. A valid
4679/// broadcast map must include K dimension.
4680/// TODO: Strict inclusion of K dimension in the broadcast map is not
4681/// necessary for both input matrices simultaneously. We can relax this
4682/// condition to have K dimension for one input matrix map and infer the K
4683/// dimension for other input matrix map from the one already having K
4684/// dimension.
4685bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4686 assert(bcastMap.getNumResults() < 3 &&
4687 "Expected less than 3 result dim expr.");
4688 bool isValid = false;
4689 enum Indices { batchPos, mPos, nPos, kPos };
4690 if (bcastMap.getNumResults() == 1) {
4691 AffineExpr expr = bcastMap.getResult(0);
4692 isValid = expr.isFunctionOfDim(kPos);
4693 } else if (bcastMap.getNumResults() == 2) {
4694 AffineExpr expr0 = bcastMap.getResult(0);
4695 AffineExpr expr1 = bcastMap.getResult(1);
4696 isValid =
4697 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4698 expr0.isFunctionOfDim(mPos)) &&
4699 expr1.isFunctionOfDim(kPos))
4700 : ((expr0.isFunctionOfDim(batchPos) &&
4701 expr1.isFunctionOfDim(kPos)) ||
4702 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4703 }
4704 return isValid;
4705}
4706
4707void BatchMatmulOp::regionBuilder(
4708 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4709 function_ref<InFlightDiagnostic()> emitError) {
4710 if (emitError && block.getNumArguments() != 3) {
4711 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4712 << block.getNumArguments();
4713 return;
4714 }
4715 assert(block.getNumArguments() == 3 &&
4716 "BatchMatmulOp regionBuilder expects 3 args");
4717 RegionBuilderHelper helper(b, block);
4718 SmallVector<Value> yields;
4719
4720 TypeFn castVal = TypeFn::cast_signed;
4721 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4722 return attr.getName() == "cast";
4723 });
4724 if (castIter != attrs.end()) {
4725 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4726 castVal = attr.getValue();
4727 }
4728
4729 auto toType = block.getArgument(2).getType();
4730 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4731 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4732 Value mulVal =
4733 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
4734 if (!castValA || !castValB || !mulVal)
4735 return;
4736 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4737 mulVal, emitError);
4738 if (!addVal)
4739 return;
4740 yields.push_back(addVal);
4741 helper.yieldOutputs(yields);
4742}
4743
4744ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4745 SmallVector<Attribute, 3> indexingMapsAttr;
4746 Attribute mapAttr;
4747 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4748 if (parser.parseEqual())
4749 return failure();
4750
4751 if (parser.parseLSquare())
4752 return failure();
4753
4754 do {
4755 if (parser.parseAttribute(mapAttr))
4756 return failure();
4757 if (!isa<AffineMapAttr>(mapAttr)) {
4758 return parser.emitError(parser.getCurrentLocation(),
4759 "expected affine map attribute");
4760 }
4761 indexingMapsAttr.push_back(mapAttr);
4762
4763 if (parser.parseOptionalComma())
4764 break;
4765 } while (true);
4766
4767 if (parser.parseRSquare())
4768 return failure();
4769 }
4770 // Initialize indexingMaps, if not supplied explicitly.
4771 if (indexingMapsAttr.empty()) {
4772 indexingMapsAttr = llvm::map_to_vector(
4773 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4774 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4775 }
4776 result.addAttribute("indexing_maps",
4777 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4778
4779 return ::parseNamedStructuredOp(parser, result,
4780 BatchMatmulOp::getNumRegionArgs(),
4781 BatchMatmulOp::getRegionBuilder());
4782}
4783
4784void BatchMatmulOp::print(OpAsmPrinter &p) {
4785 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4786 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4787 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4788 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4789 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4790
4791 std::array<StringRef, 3> elidedAttrs = {
4792 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4793 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4794 elidedAttrs);
4795}
4796
4797/// Verify the user defined indexing maps.
4798LogicalResult BatchMatmulOp::verify() {
4799 // Verification of pure batch_matmul is handled by
4800 // verifyStructuredOpInterface().
4801 if (!hasUserDefinedMaps())
4802 return success();
4803
4804 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4806 return failure();
4807 }
4808 return success();
4809}
4810
4811LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4812 SmallVectorImpl<OpFoldResult> &) {
4813 return memref::foldMemRefCast(*this);
4814}
4815
4816void BatchMatmulOp::getEffects(
4817 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4818 &effects) {
4819 if (hasPureTensorSemantics())
4820 return;
4821 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4822}
4823
4824Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4825 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4826}
4827
4828//===----------------------------------------------------------------------===//
4829// ElementwiseOp
4830//===----------------------------------------------------------------------===//
4831//
4832namespace {
4833struct ArityGroupAndKind {
4834 // The enum class {Unary, Binary, Ternary, ..}
4835 ElementwiseArityGroup arityGroup;
4836
4837 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4838 union Kind {
4839 UnaryFn unaryFn;
4840 BinaryFn binaryFn;
4841 TernaryFn ternaryFn;
4842 } kind;
4843};
4844
4845unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4846 return static_cast<unsigned>(arityGroup);
4847}
4848} // namespace
4849
4850static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4851 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4852 constexpr int lastBinary =
4853 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4854 constexpr int lastTernary =
4855 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4856
4857 int val = static_cast<int>(kind);
4858 ArityGroupAndKind result;
4859
4860 if (val < lastUnary) {
4861 result.arityGroup = ElementwiseArityGroup::Unary;
4862 result.kind.unaryFn = static_cast<UnaryFn>(val);
4863 return result;
4864 }
4865 if (val < lastBinary) {
4866 result.arityGroup = ElementwiseArityGroup::Binary;
4867 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4868 return result;
4869 }
4870 if (val >= lastTernary) {
4871 llvm_unreachable("unhandled ElementwiseFn");
4872 }
4873 result.arityGroup = ElementwiseArityGroup::Ternary;
4874 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4875 return result;
4876}
4877
4878SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4879 auto rank = getResultRank();
4880 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4881}
4882
4884ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4885 MLIRContext *context) {
4886 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4887 return SmallVector<AffineMap>(numMaps, map);
4888}
4889
4890ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4891 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4892 Attribute attr;
4893 mlir::linalg::ElementwiseKind elemwiseKindVal;
4894 if (parser.parseKeyword("kind") || parser.parseEqual())
4895 return failure();
4896
4897 if (succeeded(parser.parseAttribute(attr))) {
4898 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4899 if (!elemwiseKindAttr)
4900 return parser.emitError(parser.getCurrentLocation(),
4901 "expected ElementwiseKind attribute");
4902 elemwiseKindVal = elemwiseKindAttr.getValue();
4903 } else {
4904 return parser.emitError(parser.getCurrentLocation(),
4905 "expected operation 'kind' attribute");
4906 }
4907 result.addAttribute(
4908 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4909
4910 // Parse optional `indexing_maps`
4911 SmallVector<Attribute, 3> indexingMapsAttr;
4912 Attribute mapAttr;
4913 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4914 if (parser.parseEqual())
4915 return failure();
4916 if (parser.parseLSquare())
4917 return failure();
4918 do {
4919 if (parser.parseAttribute(mapAttr))
4920 return failure();
4921 if (!isa<AffineMapAttr>(mapAttr))
4922 return parser.emitError(parser.getCurrentLocation(),
4923 "expected affine map attribute");
4924 indexingMapsAttr.push_back(mapAttr);
4925 if (parser.parseOptionalComma())
4926 break;
4927 } while (true);
4928 if (parser.parseRSquare())
4929 return failure();
4930 }
4931 // At this stage of parsing the only way to infer number of region
4932 // args is through op kind, as input output tensors are not parsed yet.
4933 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4934 int numRegionArgs =
4935 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4936 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4937 ElementwiseOp::getRegionBuilder())) {
4938 return parser.emitError(parser.getCurrentLocation(),
4939 "unable to parse elemwise op");
4940 }
4941
4942 // Initialize indexingMaps, if not supplied explicitly.
4943 if (indexingMapsAttr.empty()) {
4944 // We need to infer the numDims of the indexing maps from the output
4945 // type which is already parsed by now.
4946 auto resultType = result.operands[result.operands.size() - 1].getType();
4947 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4948 if (!shapedType)
4949 return parser.emitError(parser.getCurrentLocation(),
4950 "return type needs to be shaped type");
4951 auto numDims = shapedType.getRank();
4952 indexingMapsAttr = llvm::map_to_vector(
4953 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4954 parser.getContext()),
4955 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4956 }
4957
4958 result.addAttribute("indexing_maps",
4959 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4960 return success();
4961}
4962
4963void ElementwiseOp::print(OpAsmPrinter &p) {
4964 p << " kind=";
4965 p.printAttribute(getKindAttr());
4966 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4967 "indexing_maps"};
4968 unsigned arity =
4969 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4970 unsigned numDims = getResultRank();
4971
4972 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4973 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4974 getContext()),
4975 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4976
4977 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4978 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4979
4980 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4981 elidedAttrs);
4982}
4983
4984/// Implements the block region builder for the ElementwiseOp. This is called by
4985/// 'fillStructuredOpRegion'.
4986void ElementwiseOp::regionBuilder(
4987 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4988 function_ref<InFlightDiagnostic()> emitError) {
4989 ElementwiseKind elemwiseKind;
4990 for (auto attr : attrs) {
4991 if (attr.getName() == b.getStringAttr("kind")) {
4992 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4993 assert(kindAttr && "op kind attribute incorrectly set");
4994 elemwiseKind = kindAttr.getValue();
4995 break;
4996 }
4997 }
4998
4999 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
5000 auto arityGroup = groupAndKind.arityGroup;
5001 auto kind = groupAndKind.kind;
5002 if (emitError && block.getNumArguments() !=
5003 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
5004 emitError() << "Elementwise regionBuilder expects "
5005 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
5006 << block.getNumArguments();
5007 return;
5008 }
5009 assert(block.getNumArguments() ==
5010 getArityGroupAsUInt(arityGroup) + 1 /*output*/
5011 && "Elementwise regionBuilder number of block args mismatch");
5012
5013 RegionBuilderHelper helper(b, block);
5014 SmallVector<Value> yields;
5015 Value result;
5016
5017 if (arityGroup == ElementwiseArityGroup::Unary) {
5018 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
5019
5020 } else if (arityGroup == ElementwiseArityGroup::Binary) {
5021 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
5022 block.getArgument(1));
5023
5024 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
5025 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
5026 block.getArgument(1), block.getArgument(2));
5027
5028 } else {
5029 assert(false && "found unhandled category in elemwise");
5030 }
5031
5032 yields.push_back(result);
5033 helper.yieldOutputs(yields);
5034}
5035
5036LogicalResult ElementwiseOp::fold(FoldAdaptor,
5037 SmallVectorImpl<OpFoldResult> &) {
5038 return memref::foldMemRefCast(*this);
5039}
5040
5041void ElementwiseOp::getEffects(
5042 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5043 &effects) {
5044 if (hasPureTensorSemantics())
5045 return;
5046 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5047}
5048
5049Speculation::Speculatability ElementwiseOp::getSpeculatability() {
5050 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5051}
5052
5053//===----------------------------------------------------------------------===//
5054// PackOp/UnPackOp Common
5055//===----------------------------------------------------------------------===//
5056
// Pack/unpack helper working on the leading `unpackedRank` dimensions of the
// packed type (destination for PackOp, source for UnPackOp); when
// `outer_dims_perm` is non-empty its inverse is involved in producing `result`.
// NOTE(review): the embedded listing numbers jump over 5059, 5066 and 5069, so
// the line naming the function, the declaration of `result` and the call that
// consumes the inverted permutation are not visible here — confirm against
// the complete file before relying on this description.
5057template <typename OpTy, typename>
5058SmallVector<int64_t>
5060 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5061 ? packOrUnPack.getDestType()
5062 : packOrUnPack.getSourceType();
5063 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5064 ? packOrUnPack.getSourceType()
5065 : packOrUnPack.getDestType();
5067 packedType.getShape().take_front(unpackedType.getRank()));
5068 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5070 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
5071 }
5072 return result;
5073}
5078
5079// Given the (potentially) updated packed type, `newPackedTy`, generates an
5080// updated mixed-tile-sizes list. For each inner packed dimension that is static
5081// in `newPackedTy`, the tile is set to that static size (replacing SSA values
5082// or mismatched constants). Dynamic packed dimensions preserve the original
5083// tile. The folded tensor type is treated as authoritative for static extents.
5084// Note - packed-type-dim and mixed-tile-size should always match!
// NOTE(review): the listing skips lines 5085-5086, which presumably hold the
// return type, the function name and the leading parameters (`rewriter` and
// `newPackedTy` are used below) — confirm against the complete file.
5087 ArrayRef<OpFoldResult> mixedTiles) {
5088 SmallVector<OpFoldResult> newMixedTileSizes;
// Only the trailing `mixedTiles.size()` dimensions of the packed shape are
// tile dimensions; pair each with its current tile.
5089 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5090 .getShape()
5091 .take_back(mixedTiles.size()),
5092 mixedTiles)) {
5093 int64_t dimSize = std::get<0>(it);
// Dynamic extent: keep the existing tile unchanged.
5094 if (dimSize == ShapedType::kDynamic) {
5095 newMixedTileSizes.push_back(std::get<1>(it));
5096 continue;
5097 }
// Static extent: the tile becomes an index attribute of that size.
5098 newMixedTileSizes.push_back(rewriter.getIndexAttr(dimSize));
5099 }
5100
5101 return newMixedTileSizes;
5102}
5103
// Fills `reifiedReturnShapes` with a single shape holding one entry per
// destination dimension, each obtained from `createFoldedDimOp` on the op's
// destination operand. Always returns success.
// NOTE(review): the listing skips line 5106 (function name and leading
// parameters); presumably this is the `reifyResultShapesImpl` called by
// `PackOp::reifyResultShapes` — confirm.
5104template <typename OpTy>
5105static LogicalResult
5107 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5108 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5109 "applies to only pack or unpack operations");
5110 int64_t destRank = op.getDestRank();
5111 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5112 for (auto dim : llvm::seq<int64_t>(0, destRank))
5113 reifiedReturnShapes[0][dim] =
5114 createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
5115 return success();
5116}
5117
// Builds a map from each position listed in `inner_dims_pos` to the tile
// (static attribute or SSA value) at the same index of the op's mixed tiles.
// NOTE(review): the listing skips line 5119 (return type, name, parameter);
// presumably this is the `getDimAndTileMappingImpl` called by
// `PackOp::getDimAndTileMapping` — confirm.
5118template <typename OpTy>
5120 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5121 "applies to only pack or unpack operations");
5122 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5123 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5124 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5125 assert(tiles.size() == dimsToTile.size() &&
5126 "tiles must match indices of dimension to block");
5127 // bind the dimension `i` with the tile factor.
5128 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5129 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5130 return dimAndTileMapping;
5131}
5132
// Merges the static and dynamic inner tiles into one list: every static entry
// becomes an i64 integer attribute, every non-static entry is replaced by the
// next SSA value from `getInnerTiles()`, in order.
// NOTE(review): the listing skips line 5134 (return type, name, parameter);
// presumably this is the `getMixedTilesImpl` called by
// `PackOp::getMixedTiles` — confirm.
5133template <typename OpTy>
5135 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5136 "applies to only pack or unpack operations");
5137 Builder builder(op);
5138 SmallVector<OpFoldResult> mixedInnerTiles;
// Index of the next dynamic tile operand to consume.
5139 unsigned dynamicValIndex = 0;
5140 for (int64_t staticTile : op.getStaticInnerTiles()) {
5141 if (ShapedType::isStatic(staticTile))
5142 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5143 else
5144 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5145 }
5146 return mixedInnerTiles;
5147}
5148
// Returns the static half produced by `dispatchIndexOpFoldResults` on the
// op's mixed tiles; the dynamic values collected alongside are discarded.
// NOTE(review): the listing skips line 5150 (return type, name, parameter);
// presumably this is the `getStaticTilesImpl` called by
// `PackOp::getStaticTiles` — confirm.
5149template <typename OpTy>
5151 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5152 "applies to only pack or unpack operations");
5153 SmallVector<Value> dynamicTiles;
5154 SmallVector<int64_t> staticTiles;
5155 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5156 return staticTiles;
5157}
5158
5159/// Returns true if `dimsPos` is invalid. It is invalid when:
5160/// a) It contains duplicate.
5161/// b) At least one dimension is out of bound (a valid `dimPos` is >= 0 and
/// < rank).
5162/// c) The number of elements in `dimsPos` is > than `rank`.
// NOTE(review): the listing skips line 5163 (return type, name and the
// `dimsPos` parameter); presumably this is the
// `isInvalidPackingPosSpecification` used by the common pack/unpack verifier
// — confirm.
5164 size_t rank) {
5165 size_t dimsPosSize = dimsPos.size();
5166 if (dimsPosSize > rank)
5167 return true;
// Duplicates collapse in the set, so a size mismatch reveals them.
5168 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5169 if (dimsPosSize != uniqued.size())
5170 return true;
5171 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5172 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5173 });
5174}
5175
5176template <typename OpTy>
5177static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5178 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5179 "applies to only pack or unpack operations");
5180 Operation *op = packOrUnPack.getOperation();
5181
5182 // Return true if we have a zero-value tile.
5183 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5184 return llvm::any_of(tiles, [](OpFoldResult tile) {
5185 return isa<Attribute>(tile) && isZeroInteger(tile);
5186 });
5187 };
5188
5189 // Verify that the source and destination are ranked types.
5190 if (!packOrUnPack.getSourceType().hasRank() ||
5191 !packOrUnPack.getDestType().hasRank())
5192 return op->emitError("expected both source and destination to have rank");
5193
5194 // Verify that the Operation does not have mixed tensor/buffer semantics.
5195 if (!packOrUnPack.hasPureBufferSemantics() &&
5196 !packOrUnPack.hasPureTensorSemantics())
5197 return op->emitError("mixing tensor and buffer semantics is not allowed");
5198 const unsigned numResults = packOrUnPack.getNumResults();
5199 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5200 return op->emitError("expected 1 result, got ") << numResults;
5201 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5202 return op->emitError("expected 0 results, got ") << numResults;
5203
5204 // Verify tiles. Do not allow zero tiles.
5205 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5206 if (hasZeros(mixedTiles))
5207 return op->emitError("invalid zero tile factor");
5208
5209 // Verify inner_dims_pos and outer_dims_perm.
5210 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5211 ? packOrUnPack.getSourceType()
5212 : packOrUnPack.getDestType();
5213 size_t unpackedRank = unpackedType.getRank();
5214 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5215 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5216 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5217 return op->emitError("invalid inner_dims_pos vector");
5218 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5219 return op->emitError("invalid outer_dims_perm vector");
5220 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5221 return op->emitError("outer_dims_perm must be a permutation or empty");
5222
5223 // Tiling factors must be less than or equal to the input rank for pack (or
5224 // output rank for unpack), and must match the number of `inner_dims_pos`.
5225 if (mixedTiles.size() > unpackedRank) {
5226 return op->emitError("tiling factors must be less than or equal to the "
5227 "input rank for pack or output rank for unpack");
5228 }
5229 if (mixedTiles.size() != innerDimsPos.size()) {
5230 return op->emitError(
5231 "tiling factors must equal the number of dimensions to tile");
5232 }
5233
5234 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5235 ? packOrUnPack.getDestType()
5236 : packOrUnPack.getSourceType();
5237 size_t packedRank = packedType.getRank();
5238 // Require output rank to match input rank + number of blocking factors.
5239 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5240 if (expectedPackedRank != packedRank) {
5241 return op->emitError(
5242 "packed rank != (unpacked rank + num tiling factors), got ")
5243 << packedRank << " != " << expectedPackedRank;
5244 }
5245
5246 // Verify result shape is greater than the minimum expected
5247 // by the pack operation, and that the output shape
5248 // represents full tiles.
5249 SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
5250 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5251 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5252 for (auto it : llvm::enumerate(llvm::zip(
5253 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5254 int64_t dimSize = std::get<0>(it.value());
5255 if (Attribute attr =
5256 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5257 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5258 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5259 if (dimSize != staticTileSize)
5260 return op->emitError(
5261 "mismatch in inner tile sizes specified and shaped of "
5262 "tiled dimension in the packed type at index ")
5263 << it.index() << ": got " << dimSize << " != " << staticTileSize;
5264 } else if (!ShapedType::isDynamic(dimSize)) {
5265 return op->emitError("mismatch in inner tile sizes specified at index ")
5266 << it.index() << ": got static shape " << dimSize
5267 << " but dynamic tile size";
5268 }
5269 }
5270 if (failed(
5271 verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
5272 auto elementType = unpackedType.getElementType();
5273 Type expectedType, actualType;
5274 if (packOrUnPack.hasPureTensorSemantics()) {
5275 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5276 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5277 } else {
5278 expectedType = MemRefType::get(expectedPackedShape, elementType);
5279 actualType = MemRefType::get(packedType.getShape(), elementType);
5280 }
5281 return op->emitError("expected ")
5282 << expectedType << " for the packed domain value, got "
5283 << actualType;
5284 }
5285 return success();
5286}
5287
5288namespace {
5289/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5290/// various permutations to the op.
5291// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5292// these. These may or may not become true foldings / canonicalizations
5293// depending on how aggressive we want to be in automatically folding
5294// transposes.
5295struct PackOrUnPackTransposeResult {
// `inner_dims_pos` of the op, after any inner permutation has been applied.
5296 SmallVector<int64_t> innerDimsPos;
// Mixed (static or dynamic) inner tiles, permuted alongside `innerDimsPos`.
5297 SmallVector<OpFoldResult> innerTiles;
// Outer dims permutation, after any outer permutation has been applied.
5298 SmallVector<int64_t> outerDimsPerm;
5299};
5300} // namespace
5301
// Computes the pack/unpack metadata that results from applying
// `innerPermutation` to the inner dims/tiles and `outerPermutation` to the
// outer dims permutation. At least one of the two must be non-empty; an empty
// one leaves the corresponding fields as they are on the op.
// NOTE(review): the listing skips line 5304 (function name and the
// `packOrUnPackOp` parameter); presumably this is the
// `commonPermutationOfPackAndUnPackOp` called by
// `PackOp::createTransposedClone` — confirm.
5302template <typename OpTy>
5303static PackOrUnPackTransposeResult
5305 ArrayRef<int64_t> innerPermutation,
5306 ArrayRef<int64_t> outerPermutation) {
5307 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5308 "applies to only pack or unpack operations");
5309 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5310 "some permutation must be non-empty");
5311 PackOrUnPackTransposeResult metadata;
5312 metadata.innerDimsPos =
5313 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5314 metadata.innerTiles =
5315 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
// The number of outer dims is the rank of the unpacked operand.
5316 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5317 ? packOrUnPackOp.getSourceRank()
5318 : packOrUnPackOp.getDestRank();
// An absent outer_dims_perm is materialized as the identity sequence.
5319 metadata.outerDimsPerm =
5320 packOrUnPackOp.getOuterDimsPerm().empty()
5321 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5322 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5323 if (!innerPermutation.empty()) {
5324 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5325 isPermutationVector(innerPermutation) &&
5326 "invalid inner permutation");
5327 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5328 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5329 }
5330 if (!outerPermutation.empty()) {
5331 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5332 isPermutationVector(outerPermutation) &&
5333 "invalid outer permutation");
5334 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5335 }
5336 return metadata;
5337}
5338
5339//===----------------------------------------------------------------------===//
5340// PackOp
5341//===----------------------------------------------------------------------===//
5342
5343void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5344 if (!getResults().empty())
5345 setNameFn(getResult(), "pack");
5346}
5347
// Custom assembly parser for PackOp. Accepted form, in order:
//   source [`padding_value` `(` value `:` type `)`]
//   [`outer_dims_perm` `=` list] `inner_dims_pos` `=` list
//   `inner_tiles` `=` dynamic-index-list `into` dest [attr-dict]
//   `:` source-type `->` dest-type
// Operands are resolved in the order source, dest, padding value, dynamic
// tiles, matching the `operandSegmentSizes` attribute built at the end. A
// result type is added only when the source type is not a memref.
// NOTE(review): the listing skips lines 5350-5351 (presumably the
// declarations of `dynamicTiles` and `paddingValue`) and lines 5373 and 5388
// (presumably the list-parsing calls whose callbacks and closing `}))` are
// visible below) — confirm against the complete file.
5348ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
5349 OpAsmParser::UnresolvedOperand source, dest;
5352 SmallVector<Type> paddingValueType;
5353 SmallVector<int64_t> staticTiles;
5354 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
5355 Type sourceType, destType, resultType;
5356
5357 if (parser.parseOperand(source))
5358 return failure();
5359
5360 if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
5361 if (parser.parseLParen() ||
5362 parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
5363 parser.parseColon() || parser.parseTypeList(paddingValueType) ||
5364 parser.parseRParen())
5365 return failure();
5366 }
5367
5368 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
5369 if (parser.parseEqual())
5370 return failure();
5371
5372 SmallVector<int64_t> outerDimsPermVec;
5374 int64_t value;
5375 if (parser.parseInteger(value))
5376 return failure();
5377 outerDimsPermVec.push_back(value);
5378 return success();
5379 }))
5380 return failure();
5381 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
5382 }
5383
5384 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
5385 return failure();
5386
5387 SmallVector<int64_t> innerDimsPosVec;
5389 int64_t value;
5390 if (parser.parseInteger(value))
5391 return failure();
5392 innerDimsPosVec.push_back(value);
5393 return success();
5394 }))
5395 return failure();
5396 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
5397
5398 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
5399 return failure();
5400
5401 DenseI64ArrayAttr staticTilesAttr;
5402 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
5403 return failure();
5404 for (auto val : staticTilesAttr.asArrayRef())
5405 staticTiles.push_back(val);
5406
5407 if (parser.parseKeyword("into") || parser.parseOperand(dest))
5408 return failure();
5409
5410 if (parser.parseOptionalAttrDict(result.attributes))
5411 return failure();
5412
5413 if (parser.parseColon() || parser.parseType(sourceType))
5414 return failure();
5415
5416 bool hasArrow = succeeded(parser.parseOptionalArrow());
5417 if (hasArrow) {
5418 if (parser.parseType(destType))
5419 return failure();
5420 }
5421
5422 bool isMemRef = llvm::isa<MemRefType>(sourceType);
// The arrow is parsed as optional above only so that a dedicated diagnostic
// can be emitted here when it is missing.
5423 if (!hasArrow) {
5424 return parser.emitError(parser.getCurrentLocation(),
5425 "pack/unpack requires '->' and destination type");
5426 }
5427
5428 if (!isMemRef)
5429 resultType = destType;
5430
5431 if (parser.resolveOperand(source, sourceType, result.operands) ||
5432 parser.resolveOperand(dest, destType, result.operands))
5433 return failure();
5434
5435 if (!paddingValue.empty() &&
5436 parser.resolveOperands(paddingValue, paddingValueType[0],
5437 result.operands))
5438 return failure();
5439
5440 if (!dynamicTiles.empty() &&
5441 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
5442 result.operands))
5443 return failure();
5444
5445 result.addAttribute("static_inner_tiles",
5446 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
5447 result.addAttribute("inner_dims_pos", innerDimsPos);
5448 if (outerDimsPerm)
5449 result.addAttribute("outer_dims_perm", outerDimsPerm);
5450
// Segment sizes: one source, one dest, then the optional padding value and
// the dynamic tiles.
5451 SmallVector<int32_t> segmentSizes = {
5452 1, 1, static_cast<int32_t>(paddingValue.size()),
5453 static_cast<int32_t>(dynamicTiles.size())};
5454 result.addAttribute("operandSegmentSizes",
5455 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
5456
5457 if (!isMemRef)
5458 result.addTypes(resultType);
5459
5460 return success();
5461}
5462
5463void PackOp::print(OpAsmPrinter &p) {
5464 p << " " << getSource();
5465
5466 if (getPaddingValue()) {
5467 p << " padding_value(" << getPaddingValue() << " : "
5468 << getPaddingValue().getType() << ")";
5469 }
5470
5471 if (!getOuterDimsPerm().empty()) {
5472 p << " outer_dims_perm = [";
5473 llvm::interleaveComma(getOuterDimsPerm(), p);
5474 p << "]";
5475 }
5476
5477 p << " inner_dims_pos = [";
5478 llvm::interleaveComma(getInnerDimsPos(), p);
5479 p << "]";
5480
5481 p << " inner_tiles = ";
5482 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
5483
5484 p << " into " << getDest();
5485
5486 p.printOptionalAttrDict((*this)->getAttrs(),
5487 {"static_inner_tiles", "inner_dims_pos",
5488 "outer_dims_perm", "operandSegmentSizes"});
5489
5490 p << " : " << getSource().getType();
5491 p << " -> " << getDest().getType();
5492}
5493
5494void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5495 Value dest, ArrayRef<int64_t> innerDimsPos,
5496 ArrayRef<OpFoldResult> innerTiles,
5497 std::optional<Value> paddingValue,
5498 ArrayRef<int64_t> outerDimsPerm) {
5499 assert(innerDimsPos.size() == innerTiles.size() &&
5500 "number of tile sizes specified must match the specified number of "
5501 "original dimensions to be tiled");
5502 SmallVector<int64_t> staticTileSizes;
5503 SmallVector<Value> dynamicTileSizes;
5504 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5505 build(builder, state, dest.getType(), source, dest,
5506 paddingValue ? *paddingValue : nullptr,
5507 outerDimsPerm.empty() ? nullptr
5508 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5509 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5510 builder.getDenseI64ArrayAttr(staticTileSizes));
5511}
5512
5513LogicalResult
5514PackOp::reifyResultShapes(OpBuilder &builder,
5515 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5516 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5517}
5518
5519DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5520 return getDimAndTileMappingImpl(*this);
5521}
5522
5523SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5524 return getMixedTilesImpl(*this);
5525}
5526
5527SmallVector<int64_t> PackOp::getStaticTiles() {
5528 return getStaticTilesImpl(*this);
5529}
5530
5531ArrayRef<int64_t> PackOp::getAllOuterDims() {
5532 ShapedType inputType = getSourceType();
5533 int64_t inputRank = inputType.getRank();
5534 return getDestType().getShape().take_front(inputRank);
5535}
5536
5537SmallVector<int64_t> PackOp::getTiledOuterDims() {
5538 auto innerDimsPos = getInnerDimsPos();
5539 SmallVector<int64_t> outerDims(getAllOuterDims());
5540 SmallVector<int64_t> res;
5541
5542 // Recover the original order of the outer dims.
5543 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5544 invertPermutationVector(outerDimPermInv);
5545 if (!outerDimPermInv.empty())
5546 applyPermutationToVector(outerDims, outerDimPermInv);
5547
5548 // Collect the outer dims corresponding to the tilled inner dims.
5549 for (auto index : innerDimsPos)
5550 res.push_back(outerDims[index]);
5551
5552 return res;
5553}
5554
5555bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5556 ArrayRef<int64_t> innerDimsPos,
5557 ArrayRef<int64_t> outputShape,
5558 ArrayRef<int64_t> outerDimsPerm,
5559 ArrayRef<OpFoldResult> innerTiles) {
5560 SmallVector<int64_t> outputTileSizes(
5561 outputShape.take_front(inputShape.size()));
5562 if (!outerDimsPerm.empty()) {
5563 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5564 "expected output and outer_dims_perm to have same size");
5565 applyPermutationToVector(outputTileSizes,
5566 invertPermutationVector(outerDimsPerm));
5567 }
5568 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5569 if (ShapedType::isDynamic(inputShape[pos]))
5570 continue;
5571 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5572 if (!constantTile) {
5573 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5574 (inputShape[pos] % outputTileSizes[pos] != 0))
5575 return true;
5576 } else {
5577 assert(*constantTile != 0 && "static tile size can't be zero");
5578 if (inputShape[pos] % (*constantTile) != 0) {
5579 return true;
5580 }
5581 }
5582 }
5583 return false;
5584}
5585
5586bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5587 ArrayRef<int64_t> innerDimsPos,
5588 ArrayRef<int64_t> outputShape,
5589 ArrayRef<int64_t> outerDimsPerm,
5590 ArrayRef<OpFoldResult> innerTiles) {
5591 SmallVector<int64_t> outputTileSizes(
5592 outputShape.take_front(inputShape.size()));
5593 if (!outerDimsPerm.empty()) {
5594 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5595 "expected output and outer_dims_perm to have same size");
5596 applyPermutationToVector(outputTileSizes,
5597 invertPermutationVector(outerDimsPerm));
5598 }
5599 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5600 if (ShapedType::isDynamic(inputShape[pos]) ||
5601 ShapedType::isDynamic(outputTileSizes[pos]))
5602 return true;
5603 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5604 if (!constantTile)
5605 return true;
5606 assert(*constantTile != 0 && "static tile size can't be zero");
5607 if (inputShape[pos] % (*constantTile) != 0)
5608 return true;
5609 }
5610 return false;
5611}
5612
// Op verifier. After the early-return guard, checks that a padding value, if
// present, has the source element type, and that a padding value is present
// whenever `requirePaddingValue` reports that one is needed.
// NOTE(review): the listing skips line 5614, which holds the condition for the
// first `return failure();` (presumably the call to the common pack/unpack
// verifier) — confirm against the complete file.
5613LogicalResult PackOp::verify() {
5615 return failure();
5616
5617 // Verify padding value, and bail out if the tile does not divide the
5618 // dimension fully. In the case of dynamic tile factors or dimensions, having
5619 // a partial tile is undefined behavior.
5620 auto paddingValue = getPaddingValue();
5621 if (paddingValue &&
5622 paddingValue.getType() != getSourceType().getElementType()) {
5623 return emitOpError("expected padding_value has ")
5624 << getSourceType().getElementType()
5625 << " but got: " << paddingValue.getType();
5626 }
5627
5628 if (!paddingValue &&
5629 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5630 getDestType().getShape(), getOuterDimsPerm(),
5631 getMixedTiles())) {
5632 return emitOpError(
5633 "invalid tile factor or output size provided. Only full tiles are "
5634 "supported when padding_value is not set");
5635 }
5636 return success();
5637}
5638
5639/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5640/// Value's to kDynamic, even if they are arith.constant values.
// NOTE(review): the listing skips lines 5642-5643 (function name, the `ofrs`
// parameter and the declaration of `result`); presumably this is the
// `asShapeWithAnyValueAsDynamic` used by `PackOp::getResultShape` — confirm.
5641static SmallVector<int64_t>
5644 for (auto o : ofrs) {
5645 // Have to do this first, as getConstantIntValue special-cases constants.
5646 if (llvm::dyn_cast_if_present<Value>(o))
5647 result.push_back(ShapedType::kDynamic);
5648 else
// Attribute entries: use the constant if one can be extracted, else dynamic.
5649 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5650 }
5651 return result;
5652}
5653
5654SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5655 ArrayRef<int64_t> innerTileSizes,
5656 ArrayRef<int64_t> innerDimsPos,
5657 ArrayRef<int64_t> outerDimsPerm) {
5658 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5659 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5660 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5661 continue;
5662 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5663 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5664 continue;
5665 }
5666 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5667 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5668 }
5669
5670 // Swap tile loops if outer_dims_perm is available.
5671 if (!outerDimsPerm.empty())
5672 applyPermutationToVector(resultShape, outerDimsPerm);
5673
5674 // Append the inner tile dimensions.
5675 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5676 return resultShape;
5677}
5678
5679SmallVector<OpFoldResult> PackOp::getResultShape(
5680 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5681 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5682 ArrayRef<int64_t> outerDimsPerm) {
5683 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5684
5685 AffineExpr s0, s1;
5686 bindSymbols(builder.getContext(), s0, s1);
5687 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5688 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5689 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5690 builder, loc, ceilDivExpr,
5691 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5692 }
5693 if (!outerDimsPerm.empty())
5694 applyPermutationToVector(resultDims, outerDimsPerm);
5695 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5696
5697 SmallVector<int64_t> resultTypeShape =
5698 inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
5699 asShapeWithAnyValueAsDynamic(innerTileSizes),
5700 innerDimsPos, outerDimsPerm);
5701
5702 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5703 // result type shape says it's a dynamic dim. This is needed as callers may
5704 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5705 // dynamic dims returned by that.
5706 for (unsigned i = 0; i < resultDims.size(); ++i) {
5707 if (ShapedType::isStatic(resultTypeShape[i]))
5708 continue;
5709 resultDims[i] =
5710 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5711 }
5712
5713 return resultDims;
5714}
5715
5716RankedTensorType PackOp::inferPackedTensorType(
5717 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5718 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5719 SmallVector<int64_t> resultShape = inferPackedShape(
5720 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5721 return RankedTensorType::get(resultShape, sourceType.getElementType());
5722}
5723
5724MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5725 ArrayRef<int64_t> innerTileSizes,
5726 ArrayRef<int64_t> innerDimsPos,
5727 ArrayRef<int64_t> outerDimsPerm) {
5728 SmallVector<int64_t> resultShape = inferPackedShape(
5729 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5730 return MemRefType::get(resultShape, sourceType.getElementType());
5731}
5732
5733Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5734 ArrayRef<OpFoldResult> innerTileSizes,
5735 ArrayRef<int64_t> innerDimsPos,
5736 ArrayRef<int64_t> outerDimsPerm) {
5737 AffineExpr dim0, dim1;
5738 bindDims(b.getContext(), dim0, dim1);
5739 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5740 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5741 {v1, v2});
5742 };
5743
5744 SmallVector<OpFoldResult> mixedSizes;
5745 for (auto [index, value] : llvm::enumerate(
5746 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5747 if (ShapedType::isDynamic(value))
5748 mixedSizes.push_back(
5749 tensor::DimOp::create(b, loc, source, index).getResult());
5750 else
5751 mixedSizes.push_back(b.getIndexAttr(value));
5752 }
5753 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5754 int64_t dimPos = std::get<0>(it);
5755 OpFoldResult tileSize = std::get<1>(it);
5756 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5757 }
5758 if (!outerDimsPerm.empty())
5759 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5760
5761 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5762 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5763 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5764}
5765
5766PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5767 ArrayRef<int64_t> innerPermutation,
5768 ArrayRef<int64_t> outerPermutation) {
5769 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5770 *this, innerPermutation, outerPermutation);
5771 Value transposedDest =
5772 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5773 metadata.innerDimsPos, metadata.outerDimsPerm);
5774 return PackOp::create(b, loc, getSource(), transposedDest,
5775 metadata.innerDimsPos, metadata.innerTiles,
5776 getPaddingValue(), metadata.outerDimsPerm);
5777}
5778
// Records memory effects for a pack/unpack op: nothing for pure tensor
// semantics; otherwise, for memref operands only, a read on the source
// operand and a read plus a write on the destination operand.
// NOTE(review): the listing skips lines 5780-5781 (function name and leading
// parameters) and lines 5794, 5798 and 5801 (the final argument of each
// `emplace_back` call); presumably this is the `getPackUnPackEffectsImpl`
// called by the `getEffects` methods — confirm against the complete file.
5779template <typename OpTy>
5782 &effects) {
5783 // No memory effects for pure tensor semantics
5784 if (op.hasPureTensorSemantics())
5785 return;
5786
5787 for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
// Non-memref operands are skipped.
5788 if (!llvm::isa<MemRefType>(opOperand.get().getType()))
5789 continue;
5790
5791 if (&opOperand == &op.getSourceMutable()) {
5792 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5793 /*effectOnFullRegion=*/true,
5795 } else if (&opOperand == &op.getDestMutable()) {
5796 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5797 /*effectOnFullRegion=*/true,
5799 effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
5800 /*effectOnFullRegion=*/true,
5802 }
5803 }
5804}
5805
// Memory effects of PackOp: delegates to `getPackUnPackEffectsImpl`.
// NOTE(review): the listing skips line 5807 (the type of the `effects`
// parameter) — confirm against the complete file.
5806void PackOp::getEffects(
5808 &effects) {
5809 getPackUnPackEffectsImpl(*this, effects);
5810}
5811
// Memory effects of UnPackOp: delegates to `getPackUnPackEffectsImpl`.
// NOTE(review): the listing skips line 5813 (the type of the `effects`
// parameter) — confirm against the complete file.
5812void UnPackOp::getEffects(
5814 &effects) {
5815 getPackUnPackEffectsImpl(*this, effects);
5816}
5817
5818/// Returns true if the tiles and the tiled dims are constant.
// More precisely: returns false as soon as a tile has no constant integer
// value or the matching trailing dimension of the packed type (destination
// for PackOp, source for UnPackOp) is dynamic.
// NOTE(review): the listing skips line 5820 (return type, function name and
// the `op` parameter) — confirm against the complete file.
5819template <typename OpTy>
5821 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5822 "applies to only pack or unpack operations");
5823 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5824 ? op.getDestType()
5825 : op.getSourceType();
5826 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5827 for (auto [dimDest, tile] : llvm::zip(
5828 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5829 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5830 if (!constTileSize || ShapedType::isDynamic(dimDest))
5831 return false;
5832 }
5833 return true;
5834}
5835
// Speculatability of PackOp. The visible conditions distinguish ops without
// pure tensor semantics and ops with a padding value.
// NOTE(review): the listing skips lines 5838, 5840, 5845-5846 and 5848, i.e.
// every `return` statement and the final condition, so the value returned in
// each case cannot be stated from this listing — confirm against the complete
// file.
5836Speculation::Speculatability PackOp::getSpeculatability() {
5837 if (!hasPureTensorSemantics())
5839 if (getPaddingValue())
5841
5842 // The verifier already rejects operations if we can statically prove that the
5843 // sizes of the tiles do not divide perfectly the dimension; thus, check only
5844 // to have constant tiles and tiled inner dimensions.
5847
5849}
5850
5851// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5852// dimensions for pack and unpack.
5853static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5854 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5855 return false;
5856 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5857 return true;
5858 // Outer dims permutation is optional.
5859 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5860 // identity permutation.
5861 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5862 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5863}
5864
5865// Return true if pack and unpack have the same tiles.
5866// Same SSA values or same integer constants.
5867static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5868 auto packTiles = packOp.getMixedTiles();
5869 auto unPackTiles = unPackOp.getMixedTiles();
5870 if (packTiles.size() != unPackTiles.size())
5871 return false;
5872 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5873 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5874 return false;
5875 }
5876 return true;
5877}
5878
5879/// Returns true if the pack op does not need a padding value.
5880static bool paddingIsNotNeeded(PackOp op) {
5881 auto srcType = op.getSourceType();
5882 if (llvm::any_of(op.getInnerDimsPos(),
5883 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5884 return false;
5885 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5886 return false;
5887 return !PackOp::requirePaddingValue(
5888 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5889 op.getOuterDimsPerm(), op.getMixedTiles());
5890}
5891
5892/// Returns true if the `srcShape` or `destShape` is different from the one in
5893/// `packOp` and populates each with the inferred static shape.
5894static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5895 SmallVectorImpl<int64_t> &destShape) {
5896 bool changeNeeded = false;
5897 srcShape.assign(packOp.getSourceType().getShape().begin(),
5898 packOp.getSourceType().getShape().end());
5899 destShape.assign(packOp.getDestType().getShape().begin(),
5900 packOp.getDestType().getShape().end());
5901 llvm::SmallSetVector<int64_t, 4> innerDims;
5902 innerDims.insert_range(packOp.getInnerDimsPos());
5903 SmallVector<int64_t> inverseOuterDimsPerm;
5904 if (!packOp.getOuterDimsPerm().empty())
5905 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5906 int srcRank = packOp.getSourceRank();
5907 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5908 if (innerDims.contains(i))
5909 continue;
5910 int64_t srcPos = i;
5911 int64_t destPos = i;
5912 if (!inverseOuterDimsPerm.empty())
5913 destPos = inverseOuterDimsPerm[srcPos];
5914 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5915 ShapedType::isDynamic(destShape[destPos])) {
5916 continue;
5917 }
5918 int64_t size = srcShape[srcPos];
5919 if (ShapedType::isDynamic(size))
5920 size = destShape[destPos];
5921 srcShape[srcPos] = size;
5922 destShape[destPos] = size;
5923 changeNeeded = true;
5924 }
5925 return changeNeeded;
5926}
5927
5928LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5929 // TODO: Support Memref PackOp. Temporarily return failure.
5930 if (!packOp.hasPureTensorSemantics())
5931 return failure();
5932
5933 // Fold an pack(unpack(x)) to x.
5934 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5935 if (unPackOp.getSourceType() == packOp.getDestType() &&
5936 !packOp.getPaddingValue() &&
5937 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5938 haveSameTiles(packOp, unPackOp)) {
5939 rewriter.replaceOp(packOp, unPackOp.getSource());
5940 return success();
5941 }
5942 }
5943
5944 // Fold optional PaddingValue operand away if padding is not needed.
5945 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5946 rewriter.startOpModification(packOp);
5947 packOp.getPaddingValueMutable().clear();
5948 rewriter.finalizeOpModification(packOp);
5949 return success();
5950 }
5951
5952 // Insert tensor.cast ops if static shape inference is available..
5953 SmallVector<int64_t> srcShape, destShape;
5954 if (inferStaticShape(packOp, srcShape, destShape)) {
5955 Location loc = packOp.getLoc();
5956 Value source = packOp.getSource();
5957 if (srcShape != packOp.getSourceType().getShape()) {
5958 auto newSrcType = packOp.getSourceType().clone(srcShape);
5959 source =
5960 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5961 }
5962 Value dest = packOp.getDest();
5963 ShapedType originalResultType = packOp.getDestType();
5964 bool needUpdateDestType = (destShape != originalResultType.getShape());
5965 if (needUpdateDestType) {
5966 auto newDestType = packOp.getDestType().clone(destShape);
5967 dest =
5968 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5969 }
5970 rewriter.modifyOpInPlace(packOp, [&] {
5971 packOp.getSourceMutable().assign(source);
5972 packOp.getDestMutable().assign(dest);
5973 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5974 });
5975 // Insert a cast if needed
5976 if (needUpdateDestType) {
5977 rewriter.setInsertionPointAfter(packOp);
5978 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5979 packOp.getResult());
5980 rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
5981 }
5982 return success();
5983 }
5984
5985 return failure();
5986}
5987
5988template <typename PackOrUnpackOp>
5989static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
5990 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5991 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5992 "Function meant for pack/unpack");
5993 // This is a pad if packing only adds ones and we don't transpose dimensions.
5994
5995 // Check that we are not transposing any dimensions.
5996 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5997 int64_t numPackedDims = innerDimsPos.size();
5998 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5999 if (orderedDims != innerDimsPos) {
6000 // Dimensions don't happen in order.
6001 return false;
6002 }
6003
6004 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
6005 int64_t packedRank = packedTensorType.getRank();
6006 // At this point we know that we are taking numPackedDims outer
6007 // dimensions and pushing them all the way as the inner most dimensions.
6008 // What's left on the outer most dimensions is, in this order:
6009 // - the factor of the packed dimensions, then
6010 // - the untouched dimensions
6011 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
6012 // if all the dimensions that bubble outerward are ones.
6013 // Therefore check that all the dimensions but the numPackedDims inner most
6014 // ones are ones.
6015 return llvm::all_of(
6016 llvm::seq<int64_t>(0, packedRank - numPackedDims),
6017 [&packedShape](int64_t i) { return packedShape[i] == 1; });
6018}
6019
6020bool PackOp::isLikePad() {
6021 auto packedTensorType =
6022 llvm::cast<ShapedType>((*this)->getResultTypes().front());
6023 return isLikePadUnPad(*this, packedTensorType);
6024}
6025
6026::mlir::LogicalResult
6027PackOp::fold(FoldAdaptor adaptor,
6029 if (!hasPureTensorSemantics())
6030 return failure();
6031 std::optional<Attribute> paddingValue;
6032 if (auto pad = adaptor.getPaddingValue())
6033 paddingValue = pad;
6034 if (OpFoldResult reshapedSource = reshapeConstantSource(
6035 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6036 cast<TensorType>(getDestType()), paddingValue)) {
6037 results.push_back(reshapedSource);
6038 return success();
6039 }
6040 return failure();
6041}
6042
6043/// Folds a tensor.cast op into a consuming PackOp op if the
6044/// `tensor.cast` has source that is more static than the consuming op.
6045///
6046/// Example:
6047/// ```mlir
6048/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
6049/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
6050/// ```
6051///
6052/// folds into:
6053///
6054/// ```mlir
6055/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
6056/// ```
6059
6060 LogicalResult matchAndRewrite(PackOp op,
6061 PatternRewriter &rewriter) const override {
6062 // TODO: Support Memref PackOp. Temporarily return failure.
6063 if (!op.hasPureTensorSemantics())
6064 return failure();
6065
6067 return failure();
6068
6069 SmallVector<Type> newResultTypes(op->getResultTypes());
6070 SmallVector<Value> newOperands =
6072
6073 // Get the updated mixed-tile-sizes attribute.
6074 SmallVector<OpFoldResult> newMixedTileSizes =
6075 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
6076 if (llvm::any_of(newMixedTileSizes, isZeroInteger))
6077 return failure();
6078
6079 // Clone op.
6080 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6081 // this point. However, in practice, we use them for things that we'd like
6082 // to preserve. Implement a better abstraction.
6083 PackOp newOp =
6084 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6085 op.getInnerDimsPos(), newMixedTileSizes,
6086 op.getPaddingValue(), op.getOuterDimsPerm());
6087 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6088
6089 // Replace op.
6090 Value oldResult = op.getResult();
6091 Value newResult = newOp.getResult();
6093 (newResult.getType() != oldResult.getType())
6094 ? tensor::CastOp::create(rewriter, op->getLoc(),
6095 oldResult.getType(), newResult)
6096 : newResult;
6097
6098 rewriter.replaceOp(op, {replacement});
6099
6100 return success();
6101 }
6102};
6103
6104//===----------------------------------------------------------------------===//
6105// UnPackOp
6106//===----------------------------------------------------------------------===//
6107
6108void UnPackOp::getAsmResultNames(
6109 function_ref<void(Value, StringRef)> setNameFn) {
6110 if (!getResults().empty())
6111 setNameFn(getResult(), "unpack");
6112}
6113
6114// Custom parser for UnPackOp that handles the memref/tensor case distinction
6115ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6116 OpAsmParser::UnresolvedOperand source, dest;
6118 SmallVector<int64_t> staticTiles;
6119 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6120 Type sourceType, destType, resultType;
6121
6122 if (parser.parseOperand(source))
6123 return failure();
6124
6125 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6126 if (parser.parseEqual())
6127 return failure();
6128
6129 SmallVector<int64_t> outerDimsPermVec;
6131 int64_t value;
6132 if (parser.parseInteger(value))
6133 return failure();
6134 outerDimsPermVec.push_back(value);
6135 return success();
6136 }))
6137 return failure();
6138 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6139 }
6140
6141 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6142 return failure();
6143
6144 SmallVector<int64_t> innerDimsPosVec;
6146 int64_t value;
6147 if (parser.parseInteger(value))
6148 return failure();
6149 innerDimsPosVec.push_back(value);
6150 return success();
6151 }))
6152 return failure();
6153 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6154
6155 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6156 return failure();
6157
6158 DenseI64ArrayAttr staticTilesAttr;
6159 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6160 return failure();
6161 for (auto val : staticTilesAttr.asArrayRef())
6162 staticTiles.push_back(val);
6163
6164 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6165 return failure();
6166
6167 if (parser.parseOptionalAttrDict(result.attributes))
6168 return failure();
6169
6170 if (parser.parseColon() || parser.parseType(sourceType))
6171 return failure();
6172
6173 bool hasArrow = succeeded(parser.parseOptionalArrow());
6174 if (hasArrow) {
6175 if (parser.parseType(destType))
6176 return failure();
6177 }
6178
6179 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6180 if (!hasArrow) {
6181 return parser.emitError(parser.getCurrentLocation(),
6182 "pack/unpack requires '->' and destination type");
6183 }
6184
6185 if (!isMemRef)
6186 resultType = destType;
6187
6188 if (parser.resolveOperand(source, sourceType, result.operands) ||
6189 parser.resolveOperand(dest, destType, result.operands))
6190 return failure();
6191
6192 if (!dynamicTiles.empty() &&
6193 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6194 result.operands))
6195 return failure();
6196
6197 result.addAttribute("static_inner_tiles",
6198 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6199 result.addAttribute("inner_dims_pos", innerDimsPos);
6200 if (outerDimsPerm)
6201 result.addAttribute("outer_dims_perm", outerDimsPerm);
6202
6203 SmallVector<int32_t> segmentSizes = {
6204 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6205 result.addAttribute("operandSegmentSizes",
6206 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6207
6208 if (!isMemRef)
6209 result.addTypes(resultType);
6210
6211 return success();
6212}
6213
6214void UnPackOp::print(OpAsmPrinter &p) {
6215 p << " " << getSource();
6216
6217 if (!getOuterDimsPerm().empty()) {
6218 p << " outer_dims_perm = [";
6219 llvm::interleaveComma(getOuterDimsPerm(), p);
6220 p << "]";
6221 }
6222
6223 p << " inner_dims_pos = [";
6224 llvm::interleaveComma(getInnerDimsPos(), p);
6225 p << "]";
6226
6227 p << " inner_tiles = ";
6228 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6229
6230 p << " into " << getDest();
6231
6232 p.printOptionalAttrDict((*this)->getAttrs(),
6233 {"static_inner_tiles", "inner_dims_pos",
6234 "outer_dims_perm", "operandSegmentSizes"});
6235
6236 p << " : " << getSource().getType();
6237 p << " -> " << getDest().getType();
6238}
6239
6240LogicalResult
6241UnPackOp::reifyResultShapes(OpBuilder &builder,
6242 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
6243 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
6244}
6245
6246DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
6247 return getDimAndTileMappingImpl(*this);
6248}
6249
6250SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6251 return getMixedTilesImpl(*this);
6252}
6253
6254SmallVector<int64_t> UnPackOp::getStaticTiles() {
6255 return getStaticTilesImpl(*this);
6256}
6257
6258ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6259 ShapedType destType = getDestType();
6260 int64_t destRank = destType.getRank();
6261 return getSourceType().getShape().take_front(destRank);
6262}
6263
6264SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6265 auto innerDimsPos = getInnerDimsPos();
6266 SmallVector<int64_t> outerDims(getAllOuterDims());
6267 SmallVector<int64_t> res;
6268
6269 // Recover the original order of the outer dims.
6270 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6271 invertPermutationVector(outerDimPermInv);
6272 if (!outerDimPermInv.empty())
6273 applyPermutationToVector(outerDims, outerDimPermInv);
6274
6275 // Collect the outer dims corresponding to the tilled inner dims.
6276 for (auto index : innerDimsPos)
6277 res.push_back(outerDims[index]);
6278
6279 return res;
6280}
6281
6282LogicalResult UnPackOp::verify() {
6283 return commonVerifierPackAndUnPackOp(*this);
6284}
6285
6286Speculation::Speculatability UnPackOp::getSpeculatability() {
6287 if (!hasPureTensorSemantics())
6289 // See PackOp::getSpeculatability.
6292
6294}
6295
6296void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6297 Value dest, ArrayRef<int64_t> innerDimsPos,
6298 ArrayRef<OpFoldResult> innerTiles,
6299 ArrayRef<int64_t> outerDimsPerm) {
6300 assert(innerDimsPos.size() == innerTiles.size() &&
6301 "number of tile sizes specified must match the specified number of "
6302 "original dimensions to be tiled");
6303 SmallVector<int64_t> staticTileSizes;
6304 SmallVector<Value> dynamicTileSizes;
6305 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
6306 build(builder, state, dest.getType(), source, dest,
6307 outerDimsPerm.empty() ? nullptr
6308 : builder.getDenseI64ArrayAttr(outerDimsPerm),
6309 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
6310 builder.getDenseI64ArrayAttr(staticTileSizes));
6311}
6312
6313Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6314 Value source,
6315 ArrayRef<OpFoldResult> innerTileSizes,
6316 ArrayRef<int64_t> innerDimsPos,
6317 ArrayRef<int64_t> outerDimsPerm) {
6318 AffineExpr sym0, sym1;
6319 bindSymbols(b.getContext(), sym0, sym1);
6320 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6321 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6322 };
6323
6324 SmallVector<OpFoldResult> mixedSizes;
6325 auto srcType = llvm::cast<RankedTensorType>(source.getType());
6326 for (auto i :
6327 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6328 if (srcType.isDynamicDim(i))
6329 mixedSizes.push_back(
6330 tensor::DimOp::create(b, loc, source, i).getResult());
6331 else
6332 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6333 }
6334 if (!outerDimsPerm.empty()) {
6336 mixedSizes, invertPermutationVector(outerDimsPerm));
6337 }
6338
6339 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6340 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6341
6342 auto elemType = srcType.getElementType();
6343 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6344}
6345
6346UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
6347 Value transposedSource,
6348 ArrayRef<int64_t> innerPermutation,
6349 ArrayRef<int64_t> outerPermutation) {
6350 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
6351 *this, innerPermutation, outerPermutation);
6352 return UnPackOp::create(b, loc, transposedSource, getDest(),
6353 metadata.innerDimsPos, metadata.innerTiles,
6354 metadata.outerDimsPerm);
6355}
6356
6357/// Returns true if the `srcShape` or `destShape` is different from the one in
6358/// `op` and populates each with the inferred static shape.
6359static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
6360 SmallVectorImpl<int64_t> &destShape) {
6361 bool changeNeeded = false;
6362 srcShape.assign(op.getSourceType().getShape().begin(),
6363 op.getSourceType().getShape().end());
6364 destShape.assign(op.getDestType().getShape().begin(),
6365 op.getDestType().getShape().end());
6366 llvm::SmallSetVector<int64_t, 4> innerDims;
6367 innerDims.insert_range(op.getInnerDimsPos());
6368 SmallVector<int64_t> inverseOuterDimsPerm;
6369 if (!op.getOuterDimsPerm().empty())
6370 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
6371 int destRank = op.getDestRank();
6372 for (auto i : llvm::seq<int64_t>(0, destRank)) {
6373 if (innerDims.contains(i))
6374 continue;
6375 int64_t srcPos = i;
6376 int64_t destPos = i;
6377 if (!inverseOuterDimsPerm.empty())
6378 srcPos = inverseOuterDimsPerm[destPos];
6379 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6380 ShapedType::isDynamic(destShape[destPos])) {
6381 continue;
6382 }
6383 int64_t size = srcShape[srcPos];
6384 if (ShapedType::isDynamic(size))
6385 size = destShape[destPos];
6386 srcShape[srcPos] = size;
6387 destShape[destPos] = size;
6388 changeNeeded = true;
6389 }
6390 return changeNeeded;
6391}
6392
6393LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6394 PatternRewriter &rewriter) {
6395 // TODO: Support Memref UnPackOp. Temporarily return failure.
6396 if (!unPackOp.hasPureTensorSemantics())
6397 return failure();
6398
6399 /// unpack(pack(x)) -> x
6400 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6401 if (packOp.getSourceType() != unPackOp.getDestType())
6402 return failure();
6403 if (packOp.getPaddingValue() ||
6404 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
6405 !haveSameTiles(packOp, unPackOp))
6406 return failure();
6407 rewriter.replaceOp(unPackOp, packOp.getSource());
6408 return success();
6409 }
6410 /// unpack(destinationStyleOp(x)) -> unpack(x)
6411 if (auto dstStyleOp =
6412 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6413 auto destValue = cast<OpResult>(unPackOp.getDest());
6414 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6415 rewriter.modifyOpInPlace(unPackOp,
6416 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6417 return success();
6418 }
6419 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
6420 if (unPackOp->hasOneUse()) {
6421 auto extractSliceUser =
6422 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6423 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6424 OpBuilder::InsertionGuard g(rewriter);
6425 rewriter.setInsertionPoint(unPackOp);
6426 auto newDest = tensor::ExtractSliceOp::create(
6427 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6428 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6429 extractSliceUser.getMixedStrides());
6430 rewriter.modifyOpInPlace(unPackOp, [&]() {
6431 unPackOp.setDpsInitOperand(0, newDest);
6432 unPackOp.getResult().setType(newDest.getType());
6433 });
6434 rewriter.replaceOp(extractSliceUser, unPackOp);
6435 return success();
6436 }
6437 }
6438
6439 // Insert tensor.cast ops if static shape inference is available..
6440 SmallVector<int64_t> srcShape, destShape;
6441 if (inferStaticShape(unPackOp, srcShape, destShape)) {
6442 Location loc = unPackOp.getLoc();
6443 Value source = unPackOp.getSource();
6444 if (srcShape != unPackOp.getSourceType().getShape()) {
6445 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6446 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6447 unPackOp.getSource());
6448 }
6449 Value dest = unPackOp.getDest();
6450 if (destShape != unPackOp.getDestType().getShape()) {
6451 auto newDestType = unPackOp.getDestType().clone(destShape);
6452 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6453 unPackOp.getDest());
6454 }
6455 UnPackOp newOp = UnPackOp::create(
6456 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6457 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6458 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6459 unPackOp, unPackOp.getResult().getType(), newOp.getResult());
6460 return success();
6461 }
6462
6463 return failure();
6464}
6465
6466bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6467 // Rank-reduced folding is not supported.
6468 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6469 return false;
6470 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6471 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6472 return false;
6473 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6474 SmallVector<int64_t> outerShapeWithoutTranspose =
6476 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
6477 for (auto [pos, tileSize] :
6478 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6479 areOuterDimsTiled[pos] = true;
6480 if (unpackedTypeAfterFold.isDynamicDim(pos))
6481 return false;
6482 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6483 return false;
6484 if (ShapedType::isDynamic(tileSize))
6485 return false;
6486 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6487 unpackedTypeAfterFold.getDimSize(pos);
6488 if (paddingSize >= tileSize)
6489 return false;
6490 }
6491 // extract_slice must not affect dimensions that are not being unpacked
6492 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6493 if (areOuterDimsTiled[pos])
6494 continue;
6495 int64_t dim = outerShapeWithoutTranspose[pos];
6496 if (ShapedType::isDynamic(dim))
6497 return false;
6498 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6499 return false;
6500 }
6501 return true;
6502}
6503
6504bool UnPackOp::isLikeUnPad() {
6505 ShapedType packedTensorType = getSourceType();
6506 return isLikePadUnPad(*this, packedTensorType);
6507}
6508
6509::mlir::LogicalResult
6510UnPackOp::fold(FoldAdaptor adaptor,
6511 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6512 // TODO: Support Memref UnPackOp. Temporarily return failure.
6513 if (!hasPureTensorSemantics())
6514 return failure();
6515
6516 if (OpFoldResult reshapedSource = reshapeConstantSource(
6517 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6518 cast<TensorType>(getResult().getType()))) {
6519 results.push_back(reshapedSource);
6520 return success();
6521 }
6522 return failure();
6523}
6524
6525/// Folds a tensor.cast op into a consuming UnPackOp op if the
6526/// `tensor.cast` has source that is more static than the consuming op.
6527///
6528/// Example:
6529/// ```mlir
6530/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6531/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6532/// ```
6533///
6534/// folds into:
6535///
6536/// ```mlir
6537/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6538/// ```
6539struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6540 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6541
6542 LogicalResult matchAndRewrite(UnPackOp op,
6543 PatternRewriter &rewriter) const override {
6544 // TODO: Support Memref UnPackOp. Temporarily return failure.
6545 if (!op.hasPureTensorSemantics())
6546 return failure();
6547
6549 return failure();
6550
6551 SmallVector<Type> newResultTypes(op->getResultTypes());
6552 SmallVector<Value> newOperands =
6554 Value sourceTensor = newOperands[0];
6555
6556 // Get the updated mixed-tile-sizes attribute.
6558 rewriter, sourceTensor.getType(), op.getMixedTiles());
6559
6560 // Clone op.
6561 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6562 // this point. However, in practice, we use them for things that we'd like
6563 // to preserve. Implement a better abstraction.
6564 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6565 newOperands[1], op.getInnerDimsPos(),
6566 newMixedTileSizes, op.getOuterDimsPerm());
6567 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6568
6569 // Replace op.
6570 Value oldResult = op.getResult();
6571 Value newResult = newOp.getResult();
6573 (newResult.getType() != oldResult.getType())
6574 ? tensor::CastOp::create(rewriter, op->getLoc(),
6575 oldResult.getType(), newResult)
6576 : newResult;
6577
6578 rewriter.replaceOp(op, {replacement});
6579
6580 return success();
6581 }
6582};
6583
6584//===----------------------------------------------------------------------===//
6585// BatchReduceMatmulOp
6586//===----------------------------------------------------------------------===//
6587SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6589 utils::IteratorType::reduction, utils::IteratorType::parallel,
6590 utils::IteratorType::parallel, utils::IteratorType::reduction};
6591}
6592
6593SmallVector<AffineMap>
6594BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6595 AffineExpr d0, d1, d2, d3;
6596 SmallVector<AffineMap> indexingMaps;
6597 bindDims(context, d0, d1, d2, d3);
6598 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6599 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6600 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6601 return indexingMaps;
6602}
6603
6604bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6605 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6606 if (!maps)
6607 return false;
6608 if (maps.size() != 3)
6609 return false;
6610 auto positions = getAffineResultPositions(maps);
6611 if (failed(positions))
6612 return false;
6613 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6614 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6615 (*positions)[2] == SmallVector<int64_t>{1, 2};
6616}
6617unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6618
6619std::string BatchReduceMatmulOp::getLibraryCallName() {
6620 return generateLibraryCallName(getOperation());
6621}
6622
6623/// Check if the op has broadcast and/or transpose semantic. Returns true if
6624/// the user defined indexing maps are not equal to default map.
6625bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6626 SmallVector<AffineMap, 3> defaultMaps =
6627 getDefaultIndexingMaps(this->getContext());
6628 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6629 return defaultMaps != explicitMaps;
6630}
6631
6632/// Returns true if the given bcastMap map is a valid broadcast map. A valid
6633/// broadcast map must include K dimension.
6634/// TODO: Strict inclusion of K dimension in the broadcast map is not
6635/// necessary for both input matrices simultaneously. We can relax this
6636/// condition to have K dimension for one input matrix map and infer the K
6637/// dimension for other input matrix map from the one already having K
6638/// dimension.
6639bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6640 bool isLHS) {
6641 assert(bcastMap.getNumResults() < 3 &&
6642 "Expected less than 3 result dim expr.");
6643 bool isValid = false;
6644 enum Indices { batchPos, mPos, nPos, kPos };
6645 if (bcastMap.getNumResults() == 1) {
6646 AffineExpr expr = bcastMap.getResult(0);
6647 isValid = expr.isFunctionOfDim(kPos);
6648 } else if (bcastMap.getNumResults() == 2) {
6649 AffineExpr expr0 = bcastMap.getResult(0);
6650 AffineExpr expr1 = bcastMap.getResult(1);
6651 isValid =
6652 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6653 expr0.isFunctionOfDim(mPos)) &&
6654 expr1.isFunctionOfDim(kPos))
6655 : ((expr0.isFunctionOfDim(batchPos) &&
6656 expr1.isFunctionOfDim(kPos)) ||
6657 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6658 }
6659 return isValid;
6660}
6661
6662void BatchReduceMatmulOp::regionBuilder(
6663 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6664 function_ref<InFlightDiagnostic()> emitError) {
6665 if (emitError && block.getNumArguments() != 3) {
6666 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6667 << block.getNumArguments();
6668 return;
6669 }
6670 assert(block.getNumArguments() == 3 &&
6671 "BatchReduceMatmulOp regionBuilder expects 3 args");
6672 RegionBuilderHelper helper(b, block);
6673 SmallVector<Value> yields;
6674
6675 auto toType = block.getArgument(2).getType();
6676 Value castValA =
6677 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6678 Value castValB =
6679 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6680 Value mulVal =
6681 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
6682 if (!castValA || !castValB || !mulVal)
6683 return;
6684 Value addVal =
6685 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6686 if (!addVal)
6687 return;
6688 yields.push_back(addVal);
6689 helper.yieldOutputs(yields);
6690}
6691
6692ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6693 OperationState &result) {
6694 SmallVector<Attribute, 3> indexingMapsAttr;
6695 Attribute mapAttr;
6696 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6697 if (parser.parseEqual())
6698 return failure();
6699 if (parser.parseLSquare())
6700 return failure();
6701
6702 do {
6703 if (parser.parseAttribute(mapAttr))
6704 return failure();
6705 if (!isa<AffineMapAttr>(mapAttr)) {
6706 return parser.emitError(parser.getCurrentLocation(),
6707 "expected affine map attribute");
6708 }
6709 indexingMapsAttr.push_back(mapAttr);
6710
6711 if (parser.parseOptionalComma())
6712 break;
6713 } while (true);
6714
6715 if (parser.parseRSquare())
6716 return failure();
6717 }
6718 // Initialize indexingMaps, if not supplied explicitly.
6719 if (indexingMapsAttr.empty()) {
6720 indexingMapsAttr = llvm::map_to_vector(
6721 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6722 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6723 }
6724 result.addAttribute("indexing_maps",
6725 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6726 return ::parseNamedStructuredOp(parser, result,
6727 BatchReduceMatmulOp::getNumRegionArgs(),
6728 BatchReduceMatmulOp::getRegionBuilder());
6729}
6730
6731void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6732 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6733 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6734 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6735
6736 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6737 p << " indexing_maps = [";
6738 llvm::interleaveComma(getIndexingMaps(), p,
6739 [&](Attribute attr) { p.printAttribute(attr); });
6740 p << "]";
6741 }
6742
6743 SmallVector<StringRef, 3> elidedAttrs = {
6744 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6745 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6746 elidedAttrs);
6747}
6748
6749/// Verify the user defined indexing maps.
6750LogicalResult BatchReduceMatmulOp::verify() {
6751 // Verification of pure batch_reduce_matmul is handled by
6752 // verifyStructuredOpInterface().
6753 if (!hasUserDefinedMaps())
6754 return success();
6755
6756 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6758 return failure();
6759 }
6760 return success();
6761}
6762LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6763 SmallVectorImpl<OpFoldResult> &) {
6764 return memref::foldMemRefCast(*this);
6765}
6766void BatchReduceMatmulOp::getEffects(
6767 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6768 &effects) {
6769 if (hasPureTensorSemantics())
6770 return;
6771 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6772}
6773
6774Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6775 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6776}
6777
6778} // namespace linalg
6779} // namespace mlir
6780
6781//===----------------------------------------------------------------------===//
6782// LinalgDialect
6783//===----------------------------------------------------------------------===//
6784
6785void LinalgDialect::getCanonicalizationPatterns(
6786 RewritePatternSet &results) const {
6787 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6788 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6789}
6790
6791Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6792 Attribute value, Type type,
6793 Location loc) {
6794 return arith::ConstantOp::materialize(builder, value, type, loc);
6795}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static std::optional< TypedAttr > getScalarConstantAttrFromDenseSplat(Value input)
Definition LinalgOps.cpp:95
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:60
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:392
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:369
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:462
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
result_iterator result_begin()
Definition Operation.h:438
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
unsigned getNumOperands()
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_type_range getOperandTypes()
Definition Operation.h:422
result_iterator result_end()
Definition Operation.h:439
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, ArrayRef< OpFoldResult > mixedTiles)
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:241
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation vector.
Rewrite a broadcast of a dense splat constant into a dense splat constant of the broadcast output sha...
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Rewrite a transpose of a dense splat constant into a dense splat constant of the transposed output sh...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override