MLIR 23.0.0git
TosaOps.cpp
Go to the documentation of this file.
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <numeric>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40
41//===----------------------------------------------------------------------===//
42// Tosa dialect interface includes.
43//===----------------------------------------------------------------------===//
44
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49
50namespace {
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52
53//===----------------------------------------------------------------------===//
54// Dialect Function Inliner Interface.
55//===----------------------------------------------------------------------===//
56struct TosaInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
58
59 //===--------------------------------------------------------------------===//
60 // Analysis Hooks.
61 //===--------------------------------------------------------------------===//
62
63 /// All operations can be inlined by default.
64 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65 IRMapping &map) const final {
66 return true;
67 }
68
69 /// All regions with If and While parent operators can be inlined.
70 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71 IRMapping &map) const final {
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
74 }
75};
76
/// This class implements the bytecode interface for the Tosa dialect.
/// All attribute/type (de)serialization delegates to the generated helpers
/// included from TosaDialectBytecode.cpp.inc above.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  /// Read a dialect attribute via the generated bytecode reader.
  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  /// Write a dialect attribute via the generated bytecode writer.
  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  /// Read a dialect type via the generated bytecode reader.
  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  /// Write a dialect type via the generated bytecode writer.
  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  /// No version information is emitted yet.
  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  /// Versioned bytecode is rejected: reading a version always fails.
  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  /// Nothing to upgrade while versioning is unsupported.
  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};
122
123} // namespace
124
125//===----------------------------------------------------------------------===//
126// TOSA control flow support.
127//===----------------------------------------------------------------------===//
128
129/// Returns the while loop body.
130SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131 return {&getBodyGraph()};
132}
133
134//===----------------------------------------------------------------------===//
135// TOSA variable operator support.
136//===----------------------------------------------------------------------===//
137
139 return map_to_vector(shape, [](int64_t dim) {
140 return dim == -1 ? ShapedType::kDynamic : dim;
141 });
142}
143
144// returns type of variable op
145RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
146 Type elementType = variableOp.getType();
147 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
148 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
149 return RankedTensorType::get(shape, elementType);
150}
151
152//===----------------------------------------------------------------------===//
153// Tosa dialect initialization.
154//===----------------------------------------------------------------------===//
155
/// Register all TOSA types, operations, attributes, dialect interfaces, and
/// promised interface implementations.
void TosaDialect::initialize() {
  // Types generated from TableGen.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  // Operations generated from TableGen.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  // Attributes generated from TableGen.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  // Sharding interface implementations are registered lazily elsewhere.
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
179
180Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
181 Type type, Location loc) {
182 // Tosa dialect constants only support ElementsAttr unlike standard dialect
183 // constant which supports all attributes.
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return tosa::ConstShapeOp::create(builder, loc, type,
186 llvm::cast<DenseIntElementsAttr>(value));
187 }
188 if (llvm::isa<ElementsAttr>(value))
189 return tosa::ConstOp::create(builder, loc, type,
190 llvm::cast<ElementsAttr>(value));
191 return nullptr;
192}
193
194//===----------------------------------------------------------------------===//
195// Parsers and printers
196//===----------------------------------------------------------------------===//
197
198namespace {
199
200ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
201 DenseElementsAttr &varShapeAttr,
202 TypeAttr &typeAttr) {
203 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
205 return parser.emitError(parser.getCurrentLocation())
206 << "expected ranked type";
207
208 auto elementType = shapedType.getElementType();
209 typeAttr = TypeAttr::get(elementType);
210 ArrayRef<int64_t> shape = shapedType.getShape();
211 Builder builder(parser.getContext());
212 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
213 return success();
214 }
215 return parser.emitError(parser.getCurrentLocation())
216 << "expected shaped type";
217}
218
219} // namespace
220
221// parses the optional initial value or type for a tosa variable
222// with initial value:
223// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224//
225// without initial value:
226// tosa.variable @name : tensor<1x8xf32>
228 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229 Attribute &initialValueAttr) {
230 if (succeeded(parser.parseOptionalEqual())) {
231 if (failed(parser.parseAttribute(initialValueAttr))) {
232 return parser.emitError(parser.getCurrentLocation())
233 << "expected attribute";
234 }
235 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237 typeAttr);
238 }
239 return parser.emitError(parser.getCurrentLocation())
240 << "expected Typed attr";
241 }
242
243 initialValueAttr = nullptr;
244 Type parsedType;
245 if (failed(parser.parseColonType(parsedType))) {
246 return parser.emitError(parser.getCurrentLocation())
247 << "expected type after colon";
248 }
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
250}
251
253 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254 TypeAttr typeAttr, Attribute initialValueAttr) {
255 bool needsSpace = false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257 auto shape =
258 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
261 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
262 auto tensorTypeAttr = TypeAttr::get(tensorType);
263 p << ": ";
264 p.printAttribute(tensorTypeAttr);
265 needsSpace = true; // subsequent attr value needs a space separator
266 }
267 if (initialValueAttr) {
268 if (needsSpace)
269 p << ' ';
270 p << "= ";
271 p.printAttribute(initialValueAttr);
272 }
273}
274
275namespace {
276
// parse attributes with special handling for tosa enum attributes
//
// TOSA assembly lets certain enum-valued attributes be written as *bare*
// keywords (e.g. `rounding_mode = SINGLE_ROUND`) instead of the fully
// qualified attribute form. Each `if constexpr` branch below enables that
// shorthand for exactly one enum type, keyed by the attribute name; anything
// else falls through to the generic attribute parser at the bottom.
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &outAttrs) {
  // Every entry has the form `name = value`.
  llvm::StringRef name;
  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
    return failure();

  // special handling: rounding_mode accepts a *bare* RoundingMode enum
  // keyword.
  llvm::StringRef kw;
  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
        succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: mode accepts a *bare* ResizeMode enum keyword.
  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
    if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
  // keyword.
  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
    if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
      auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // special handling: block_size accepts a *bare* BlockSizeMode enum
  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeBlockSize(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid block_size value: " << kw;
      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // Default path: parse any normal attribute literal, including fully qualified
  // enum keyword
  Attribute attr;
  return parser.parseAttribute(attr, name, outAttrs);
}
344
345template <typename EnumType>
346ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
347 // parse operands
349 if (parser.parseCommaSeparatedList(
350 [&]() { return parser.parseOperand(operands.emplace_back()); }))
351 return failure();
352
353 // Parse { attr-dict } with special handling for enum bare token
354 NamedAttrList attrs;
355 if (succeeded(parser.parseOptionalLBrace()) &&
356 failed(parser.parseOptionalRBrace())) {
357 do {
358 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
359 return failure();
360 } while (succeeded(parser.parseOptionalComma()));
361 if (parser.parseRBrace())
362 return failure();
363 }
364
365 FunctionType fnTy;
366 if (parser.parseColonType(fnTy))
367 return failure();
368
369 // Resolve operands and types
370 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
371 parser.getCurrentLocation(),
372 result.operands)))
373 return failure();
374
375 result.addTypes(fnTy.getResults());
376 result.addAttributes(attrs);
377
378 return success();
379}
380
381void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
382 parser << namedAttr.getName().strref() << " = ";
383 auto attr = namedAttr.getValue();
384 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
385 parser << roundingModeAttr.getValue();
386 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
387 parser << resizeModeAttr.getValue();
388 } else if (auto nanPropagationModeAttr =
389 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
390 parser << nanPropagationModeAttr.getValue();
391 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
392 parser << blockSizeAttr.getValue();
393 } else {
394 parser.printAttribute(attr);
395 }
396}
397
398// print with special handling for default valued NanPropagationMode attribute
399void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
400 parser << " ";
401 parser.printOperands(op->getOperands());
402
403 NamedAttrList toPrint(op->getAttrs());
404 // remove default NanPropagate attribute
405 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
406 for (auto attr : op->getAttrs()) {
407 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
408 if (nanAttr.getValue() == kDefaultNanValue) {
409 // elide from toPrint
410 toPrint.erase(attr.getName());
411 break;
412 }
413 }
414 }
415
416 if (!toPrint.empty()) {
417 parser << " {";
418 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
419 printNamedAttr(parser, namedAttr);
420 });
421 parser << "}";
422 }
423
424 parser << " : ";
425 parser.printFunctionalType(op);
426}
427
428// print with special handling for enums: RoundingMode, ResizeMode
429void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
430 parser << " ";
431 parser.printOperands(op->getOperands());
432
433 if (!op->getAttrs().empty()) {
434 parser << " {";
435 llvm::interleaveComma(op->getAttrs(), parser,
436 [&](const NamedAttribute namedAttr) {
437 printNamedAttr(parser, namedAttr);
438 });
439 parser << "}";
440 }
441
442 parser << " : ";
443 parser.printFunctionalType(op);
444}
445
446} // namespace
447
// Rescale and ApplyScale accept a bare RoundingMode keyword for their
// rounding_mode attribute; Resize accepts a bare ResizeMode keyword for its
// mode attribute. All three print through the shared enum-aware printer.
ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

void RescaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

void ApplyScaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
}

void ResizeOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
471
// Ops carrying a nan_mode attribute: parsing accepts a bare
// NanPropagationMode keyword, and printing elides the attribute when it holds
// the default PROPAGATE value (see printWithNanPropagationHandling).
ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ArgMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaxPool2dOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ClampOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaximumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MinimumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ReduceMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ReduceMinOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
527
528ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
530 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
531}
532
533void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
534 printWithEnumHandling(parser, *this);
535}
536
537ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
539 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
540}
541
542void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
543 printWithEnumHandling(parser, *this);
544}
545
546ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
548 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
549}
550
551void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
552 printWithEnumHandling(parser, *this);
553}
554
555ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
557 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
558}
559
560void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
561 printWithEnumHandling(parser, *this);
562}
563
564//===----------------------------------------------------------------------===//
565// Tosa utilities.
566//===----------------------------------------------------------------------===//
567
/// Exact integer division: returns lhs / rhs only when lhs is wholly
/// divisible by rhs, std::nullopt otherwise. Also returns std::nullopt when
/// rhs is zero — the previous code evaluated `lhs % rhs` unconditionally,
/// which is undefined behavior for rhs == 0.
static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (rhs == 0 || lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
573
575 auto srcType = getElementTypeOrSelf(type);
576 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
577 srcType = getStorageElementTypeFromQuantized(quantType);
578 return srcType;
579}
580
584
585static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
586 Value valZp, StringRef name) {
588 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
589
590 bool bothInts =
591 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
592 bool sameBitWidth =
593 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
594
595 if (!bothInts || !sameBitWidth) {
596 return op->emitOpError()
597 << "expected " << name << " and " << name
598 << "_zp to both be integer of the same bitwidth, but got " << eType
599 << " vs. " << eZpType;
600 }
601 return success();
602}
603
604// Create a pad-const const tensor with value of `val` of required data-type
606 Value src, int32_t val) {
607 const auto srcType = getElementTypeOrSelf(src);
608 const auto srcElemType = getStorageElementTypeOrSelf(src);
609 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
610 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
611 const auto padConstAttr{
612 llvm::isa<FloatType>(srcElemType)
613 ? DenseElementsAttr::get(padConstEType,
614 builder.getFloatAttr(srcElemType, val))
615 : DenseElementsAttr::get(padConstEType,
616 builder.getIntegerAttr(srcElemType, val))};
617 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
618}
619
621 if (dyn_cast<tosa::mxint8Type>(type))
622 return 8;
623 return type.getIntOrFloatBitWidth();
624}
625
626// Update dim size if current dim is dynamic, otherwise raise an error if sizes
627// do not match
628LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
629 const int64_t newDim,
630 const StringRef operandName,
631 const StringRef dimName) {
632 if (ShapedType::isDynamic(currDim)) {
633 currDim = newDim;
634 return success();
635 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
636 return op->emitOpError("expected ")
637 << dimName << " of " << operandName << " to match size " << currDim
638 << ", got " << newDim;
639 }
640 return success();
641}
642
644 Operation *op, const int64_t inputSize, const int64_t kernelSize,
645 const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
646 const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
647 const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
648 const llvm::StringRef padAfterName) {
649 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
650 return success();
651
652 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
653
654 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
655 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
656 stride);
657 if (!calculatedOutSizeMinusOne.has_value())
658 return op->emitOpError("expected input_")
659 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
660 << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
661 << dimAxis << " to be wholly divisible by stride_" << dimAxis
662 << ", got (" << inputSize << " - 1 + " << padBefore << " + "
663 << padAfter << " - (" << kernelSize << " - 1) * " << dilation
664 << ") / " << stride;
665
666 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
667 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
668 return op->emitOpError("calculated output ")
669 << dimName << " did not match expected: "
670 << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
671
672 return success();
673}
674
675//===----------------------------------------------------------------------===//
676// TOSA Operator Verifiers.
677//===----------------------------------------------------------------------===//
678
/// Shared verifier for convolution-style ops: checks element-type agreement
/// between input/weight/bias/result (quantized types compared via their
/// storage types) and validates the input/weight zero points.
template <typename T>
static LogicalResult verifyConvOp(T op) {
  // NOTE(review): these dyn_cast results are dereferenced below without a
  // null check — presumably ODS guarantees tensor-typed operands; confirm.
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  // Float-ness is recorded before quantized unwrapping below.
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  // Compare quantized element types through their storage types.
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = getStorageElementTypeFromQuantized(quantType);

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    // for now, only enforce bias element type == result element type for
    // float types.
    op.emitOpError(
        "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  // f8 inputs/weights must use the same f8 flavor on both operands.
  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
          "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  // Either both must be float or both non-float.
  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // Zero points must match the element type of the value they offset.
  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  // Statically-known zero-point values get additional per-op validation.
  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}
759
760LogicalResult tosa::ConstOp::verify() {
761
762 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
763 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
764
765 if (!attrType || !outputType) {
766 emitOpError("expected tensors for attr/result type");
767 return failure();
768 }
769
770 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
771 outputType.getElementType())) {
772 if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
773 return success();
774 }
775
776 if (attrType.getElementType() != outputType.getElementType()) {
777 emitOpError("expected same attr/result element types");
778 return failure();
779 }
780
781 return success();
782}
783
784template <typename T>
785static LogicalResult verifyConvOpModes(T op) {
786 auto inputEType =
787 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
788
789 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
790 inputEType = getStorageElementTypeFromQuantized(quantType);
791
792 auto accType = op.getAccType();
793 if (inputEType.isInteger(8) && !accType.isInteger(32))
794 return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
795 << accType;
796
797 if (inputEType.isInteger(16) && !accType.isInteger(48))
798 return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
799 << accType;
800
801 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
802 !(accType.isF16() || accType.isF32()))
803 return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
804 << accType;
805
806 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
807 return op.emitOpError(
808 "accumulator type for f16 tensor is not f16/f32, got ")
809 << accType;
810
811 if (inputEType.isBF16() && !accType.isF32())
812 return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
813 << accType;
814
815 if (inputEType.isF32() && !accType.isF32())
816 return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
817 << accType;
818
819 auto resultEType =
820 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
821
822 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
823 resultEType = getStorageElementTypeFromQuantized(quantType);
824
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829// ERROR_IF functions.
830// ERROR_IF is a predicate that must set an error if the condition holds.
831//===----------------------------------------------------------------------===//
832
/// Spec ERROR_IF checks shared by convolution-style ops: validates pad,
/// stride, and dilation ranges, the output-size relationship for each spatial
/// dimension, and the bias-channel count. Checks involving unranked or
/// dynamic shapes are skipped.
template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    // Skip following checks if output is not ranked
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Per-op spatial layouts differ, so the dimension indices are selected with
  // `if constexpr` on the concrete op type.
  if (inputType && weightType) {
    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    // Skip following checks if bias is not ranked
    return success();

  // Bias must either broadcast (size 1) or match the output channel count.
  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    // Skip following checks if biasChannels or outputChannels is dynamic dim
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
935
936// Verify whether same type and shape of the given two types.
937static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
938 StringRef name1, Type type2,
939 StringRef name2) {
940 auto shapeType1 = dyn_cast<ShapedType>(type1);
941 auto shapeType2 = dyn_cast<ShapedType>(type2);
942 if (!shapeType1 || !shapeType2)
943 return failure();
944
945 auto elemType1 = shapeType1.getElementType();
946 auto elemType2 = shapeType2.getElementType();
947 if (elemType1 != elemType2)
948 return op->emitOpError()
949 << "require same element type for " << name1 << " (" << elemType1
950 << ") and " << name2 << " (" << elemType2 << ")";
951
952 if (failed(verifyCompatibleShape(type1, type2)))
953 return op->emitOpError()
954 << "require same shapes for " << name1 << " (" << type1 << ") and "
955 << name2 << " (" << type2 << ")";
956
957 return success();
958}
959
960// Verify whether same length, type, and shape of the given two tensor lists.
961static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
962 StringRef name1,
963 ValueRange list2,
964 StringRef name2) {
965 if (list1.size() != list2.size())
966 return op->emitOpError()
967 << "require same number of values in " << name1 << " ("
968 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
969
970 for (auto [type1, type2] :
971 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
972 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
973 return failure();
974 }
975
976 return success();
977}
978
979static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
980 ShapeAdaptor shapeAdaptor(type);
981 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
982 return success();
983
984 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
985}
986
// Verify a variable access op `op` against its declaring tosa.variable: the
// symbol must be found in the enclosing symbol table, and `type` (reported
// as `name` in diagnostics) must match the declared variable's element type
// and shape.
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
  Operation *symTableOp =
      op->template getParentWithTrait<OpTrait::SymbolTable>();
  if (!symTableOp)
    // If the operation is not inside the scope of a symbol table, we cannot
    // verify it against its declaration.
    return success();

  SymbolTable symTable(symTableOp);
  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

  // Verify prior declaration
  if (!varOp)
    return op->emitOpError("'")
           << op.getName() << "' has not been declared by 'tosa.variable'";

  // Verify type and shape against the declared variable type.
  auto variableType = getVariableType(varOp);
  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                 "the input tensor")
          .failed())
    return failure();
  return success();
}
1012
1013// verify that inType and outType have same element types
1014template <typename T>
1015static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
1016 StringRef aName = "input",
1017 StringRef bName = "output") {
1018 auto aTType = llvm::dyn_cast<TensorType>(aType);
1019 auto bTType = llvm::dyn_cast<TensorType>(bType);
1020 if (!aTType) {
1021 op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
1022 return failure();
1023 }
1024 if (!bTType) {
1025 op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
1026 return failure();
1027 }
1028 auto aElementType = aTType.getElementType();
1029 auto bElementType = bTType.getElementType();
1030 auto aQuantType =
1031 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1032 auto bQuantType =
1033 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1034 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1035 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1036 aElementType != bElementType) {
1037 // only check if both element types are int/index/float/UniformQuantized
1038 // eg, not sure how to check quant::QuantizedType
1039 // this happens in test_conv2d_q_grouped_convolution in
1040 // tfl-to-tosa-pipeline.mlir
1041 op.emitOpError("expect ")
1042 << aName << " and " << bName << " to have same element type, got "
1043 << aElementType << " and " << bElementType;
1044 return failure();
1045 }
1046 return success();
1047}
1048
// Verifier for tosa.argmax: result must have an integer element type, the
// reduction axis must lie within the input rank, and the result shape must
// equal the input shape with the axis dimension removed.
LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  // Ensure output has an integer (or index) element type.
  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  // Unranked input: nothing further can be checked.
  if (!inputType.hasRank())
    return success();

  // Ensure axis is within the tensor rank
  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  // Unranked result: shape cannot be compared.
  if (!resultType.hasRank())
    return success();

  // The output shape is the input shape with the reduced axis removed;
  // dynamic dimensions are treated as compatible with anything.
  const ArrayRef<int64_t> inputShape = inputType.getShape();
  const ArrayRef<int64_t> outputShape = resultType.getShape();
  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";

  return success();
}
1079
// Shared verifier for TOSA 2-D pooling ops: validates the kernel, stride and
// padding attribute values, then checks that any statically-known output
// dimensions match the sizes computed from input, kernel, stride and
// padding. Kernel/stride are ordered {y, x}; padding is
// {top, bottom, left, right}.
template <typename T>
static LogicalResult verifyPoolingOp(T op) {
  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  const llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  const llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  // Padding must be less than kernel size to avoid a divide-by-zero
  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  // Output-size checks need ranked input and output; skip otherwise.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  // Checks one spatial dimension: the padded input extent minus the kernel
  // must divide evenly by the stride, and the resulting size must match a
  // static output dimension. Dynamic input sizes are skipped.
  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: " << "calculated="
             << calculatedOutSize << ", expected=" << outputSize;

    return success();
  };

  // NHWC layout: dim 1 is height, dim 2 is width.
  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();

  return success();
}
1163
// Verifier for tosa.avg_pool2d: runs the shared pooling checks, validates
// the accumulator type against the input element type, checks that the
// zero-point operands match the input/output element types, and validates
// the zero-point values themselves.
LogicalResult tosa::AvgPool2dOp::verify() {
  if (failed(verifyPoolingOp(*this)))
    return failure();

  // For quantized types, compare against the underlying storage type.
  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());

  // Accumulator type is constrained by the input element type.
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  // Validate the zero-point values when they can be resolved statically.
  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
1206
// Verifier for tosa.clamp: input/output element types must match (comparing
// storage types for quantized tensors), the min_val/max_val attributes must
// have the same type as the element type, and min_val <= max_val must hold
// (with NaN bounds rejected for floats).
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = getStorageElementTypeFromQuantized(quantType);
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = getStorageElementTypeFromQuantized(quantType);
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  // Since dataTypeBitWidth is derived from inputETy itself, this check is
  // effectively "inputETy is an integer type".
  if (inputETy.isInteger(dataTypeBitWidth)) {
    // if input datatype is integer, check that the min_val/max_val attributes
    // are integer attributes, and that their type is the same as the input's
    // datatype
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    // Compare with the signedness matching the element type; i1 is compared
    // unsigned as well.
    const bool isUnsigned = inputETy.isUnsignedInteger();
    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    // otherwise, input datatype is float, check that the min_val/max_val
    // attributes share the same type and that their type is the same as the
    // input's datatype
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    // NaN bounds would make the comparison below meaningless.
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }

  return success();
}
1272
1273//===----------------------------------------------------------------------===//
1274// TOSA Operator Quantization Builders.
1275//===----------------------------------------------------------------------===//
1276
1277/// This builder is called on all convolution operators except TransposeConv,
1278/// which has specialized output shape semantics. The builder also defines the
1279/// bitwidth of the output given the bit width of the input & weight content.
1281 Type outputType, Value input, Value weight,
1282 Value bias, DenseI64ArrayAttr pad,
1283 DenseI64ArrayAttr stride,
1284 DenseI64ArrayAttr dilation,
1285 TypeAttr accType) {
1286 auto zps = createZPsAsConst(builder, input, weight);
1287 result.addOperands({input, weight, bias, zps.first, zps.second});
1288 result.addAttribute("pad", pad);
1289 result.addAttribute("stride", stride);
1290 result.addAttribute("dilation", dilation);
1291 result.addAttribute("acc_type", accType);
1292 Type finalOutputType = outputType;
1293 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1294 if (quantAttr) {
1295 finalOutputType =
1296 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1297 }
1298 result.addTypes(finalOutputType);
1299}
1300
1301/// Handles tosa.transpose_conv2d which has outpad and output shape
1302/// attributes.
1303static void
1305 Type outputType, Value input, Value weight,
1306 Value bias, DenseI64ArrayAttr outpad,
1307 DenseI64ArrayAttr stride, TypeAttr accType) {
1308 auto zps = createZPsAsConst(builder, input, weight);
1309 result.addOperands({input, weight, bias, zps.first, zps.second});
1310 result.addAttribute("out_pad", outpad);
1311 result.addAttribute("stride", stride);
1312 result.addAttribute("acc_type", accType);
1313 Type finalOutputType = outputType;
1314 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1315 if (quantAttr) {
1316 finalOutputType =
1317 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1318 }
1319 result.addTypes(finalOutputType);
1320}
1321
1322/// The tosa.matmul op is also intended to be generated where a fully_connected
1323/// op must be constructed where the weight is not a constant. In this case,
1324/// the fully_connected op must be expressed using matmul.
1325/// TODO: Add link to the leglization document explaining this.
1327 OperationState &result, Type outputType,
1328 Value a, Value b) {
1329 auto zps = createZPsAsConst(builder, a, b);
1330 result.addOperands({a, b, zps.first, zps.second});
1331
1332 Type finalOutputType{outputType};
1333 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1334 auto eType = getStorageElementTypeOrSelf(a.getType());
1335 auto inputBits = eType.getIntOrFloatBitWidth();
1336
1337 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1338 assert(outputShapedType && "Output must be a shaped type");
1339
1340 IntegerType accElementType;
1341 if (inputBits == 16)
1342 accElementType = builder.getIntegerType(48);
1343 else
1344 accElementType = builder.getI32Type();
1345
1346 finalOutputType = outputShapedType.clone(accElementType);
1347 }
1348 result.addTypes(finalOutputType);
1349}
1350
1351/// Both the tosa.avg_pool2d and unary ops use the same
1352/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1353/// has additional parameters not part of the unary ops.
1354static void
1356 Type outputType, Value input,
1357 DenseArrayAttr kernel, DenseArrayAttr stride,
1358 DenseArrayAttr pad, TypeAttr accType) {
1359 const Location loc{result.location};
1360 int64_t inputZp{0};
1361 int64_t outputZp{0};
1362
1363 if (auto quantAttr =
1364 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1365 inputZp = quantAttr.getInputZp();
1366 outputZp = quantAttr.getOutputZp();
1367 }
1368 const std::optional<Value> inputZpOp =
1369 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1370 if (!inputZpOp) {
1371 (void)emitError(
1372 loc,
1373 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1374 }
1375 const std::optional<Value> outputZpOp =
1376 createZeroPointTensor(builder, loc, outputType, outputZp);
1377 if (!outputZpOp) {
1378 (void)emitError(loc, "Failed to create output zero point tensor for "
1379 "quantized AVG_POOL2D op");
1380 }
1381
1382 if (inputZpOp && outputZpOp) {
1383 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1384 } else {
1385 // failed to create one or more zero points above: just add input as
1386 // operands this will trigger error in building the op because of missing
1387 // zero points
1388 result.addOperands({input});
1389 }
1390 result.addAttribute("kernel", kernel);
1391 result.addAttribute("stride", stride);
1392 result.addAttribute("pad", pad);
1393 result.addAttribute("acc_type", accType);
1394 result.types.push_back(outputType);
1395}
1396
1397/// This builder is called on single-parameter negate operator
1398/// to construct input and output zero points based on their
1399/// types.
1401 OperationState &result, Type outputType,
1402 Value input) {
1403 const Location loc{result.location};
1404 int64_t input1Zp{0};
1405 int64_t outputZp{0};
1406 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1407 if (quantAttr) {
1408 input1Zp = quantAttr.getInputZp();
1409 outputZp = quantAttr.getOutputZp();
1410 }
1411 const std::optional<Value> input1ZpOp =
1412 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1413 if (!input1ZpOp) {
1414 (void)emitError(
1415 loc, "Failed to create input1 zero point for quantized NEGATE op");
1416 }
1417
1418 const std::optional<Value> outputZpOp =
1419 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1420 if (!outputZpOp) {
1421 (void)emitError(
1422 loc, "Failed to create output zero point for quantized NEGATE op");
1423 }
1424
1425 if (input1ZpOp && outputZpOp) {
1426 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1427 } else {
1428 // failed to create one or more zero points above: just add input as
1429 // operands. This will trigger error in building the op because of
1430 // missing zero points
1431 result.addOperands({input});
1432 }
1433
1434 result.types.push_back(outputType);
1435}
1436
1437/// This builder is called on TOSA pad operator that needs to create its own
1438/// OptionalAttr quantization_attr parameter to scale the padding values
1439/// correctly. No pad_const is interpreted as zero-padding.
1441 Type outputType, Value input,
1442 Value paddings) {
1443 const Location loc{result.location};
1444 int32_t zp{0};
1445 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1446 if (quantAttr) {
1447 zp = static_cast<int32_t>(quantAttr.getInputZp());
1448 }
1449 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1450 result.addOperands({input, paddings, padConstOp});
1451 result.types.push_back(outputType);
1452}
1453
1455 StringRef name, Type variableType,
1456 Attribute initialValue) {
1457 const Location loc{result.location};
1458 auto nameAttr = builder.getStringAttr(name);
1459
1460 auto shapedType = dyn_cast<ShapedType>(variableType);
1461 if (!shapedType) {
1462 (void)emitError(loc, "variable type must be a shaped type");
1463 return;
1464 }
1465 if (!shapedType.hasRank()) {
1466 (void)emitError(loc, "variable type must be a ranked type");
1467 return;
1468 }
1469
1470 auto elementType = shapedType.getElementType();
1471 auto elementTypeAttr = TypeAttr::get(elementType);
1472 ArrayRef<int64_t> shape = shapedType.getShape();
1473 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1474
1475 result.addAttribute("sym_name", nameAttr);
1476 result.addAttribute("var_shape", varShapeAttr);
1477 result.addAttribute("type", elementTypeAttr);
1478 result.addAttribute("initial_value", initialValue);
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// TOSA Operator Return Type Inference.
1483//===----------------------------------------------------------------------===//
1484static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
1485 const int64_t dim2) {
1486 if (dim1 == 1)
1487 return dim2;
1488 if (dim2 == 1)
1489 return dim1;
1490
1491 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1492 return failure();
1493
1494 // Prefer static dimension over dynamic
1495 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1496}
1497
1498static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1499 SmallVector<int64_t> &outShape) {
1500 int64_t outRank = 0;
1501 for (int i = 0, e = operands.size(); i != e; ++i) {
1502 auto shape = operands.getShape(i);
1503 if (!shape.hasRank()) {
1504 // TODO(jennik): Update function to have better case handling for
1505 // invalid operands and for ranked tensors.
1506 return failure();
1507 }
1508 outRank = std::max<int64_t>(outRank, shape.getRank());
1509 }
1510
1511 outShape.resize(outRank, 1);
1512
1513 for (int i = 0, e = operands.size(); i != e; ++i) {
1514 auto shape = operands.getShape(i);
1515 auto rankDiff = outShape.size() - shape.getRank();
1516
1517 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1518 auto dim1 = outShape[i + rankDiff];
1519 auto dim2 = shape.getDimSize(i);
1520
1521 const FailureOr<int64_t> maybeResolvedDim =
1522 resolveBroadcastDim(dim1, dim2);
1523 if (failed(maybeResolvedDim))
1524 return failure();
1525 const int64_t resolvedDim = *maybeResolvedDim;
1526 outShape[i + rankDiff] = resolvedDim;
1527 }
1528 }
1529
1530 return success();
1531}
1532
1533LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1534 MLIRContext *context, ::std::optional<Location> location,
1535 ArgMaxOp::Adaptor adaptor,
1536 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1537 ShapeAdaptor inputShape(adaptor.getInput().getType());
1538 IntegerAttr axis = adaptor.getProperties().axis;
1539 int32_t axisVal = axis.getValue().getSExtValue();
1540
1541 if (!inputShape.hasRank()) {
1542 inferredReturnShapes.push_back(ShapedTypeComponents());
1543 return success();
1544 }
1545
1546 SmallVector<int64_t> outShape;
1547 outShape.reserve(inputShape.getRank() - 1);
1548 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1549 if (i == axisVal)
1550 continue;
1551 outShape.push_back(inputShape.getDimSize(i));
1552 }
1553
1554 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1555 return success();
1556}
1557
1558LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1559 MLIRContext *context, ::std::optional<Location> location,
1560 RFFT2dOp::Adaptor adaptor,
1561 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1562 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1563
1564 if (!inputShape.hasRank())
1565 return failure();
1566
1567 llvm::SmallVector<int64_t> outputShape;
1568 outputShape.resize(3, ShapedType::kDynamic);
1569 outputShape[0] = inputShape.getDimSize(0);
1570 outputShape[1] = inputShape.getDimSize(1);
1571 int64_t inWidth = inputShape.getDimSize(2);
1572
1573 // Note that we can support this calculation symbolically
1574 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1575 if (inWidth != ShapedType::kDynamic)
1576 outputShape[2] = inWidth / 2 + 1;
1577
1578 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1579 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1580
1581 return success();
1582}
1583
1584static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1585 const llvm::StringRef dimName) {
1586 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1587 if (!isPowerOfTwo)
1588 return op->emitOpError("expected ")
1589 << dimName << " to be a power of two, got " << dimSize;
1590
1591 return success();
1592}
1593
// Verifier for tosa.rfft2d: the two results must have matching shapes, the
// input height/width must be powers of two, batch/height must carry over to
// the output, and the output width must be input_width / 2 + 1.
LogicalResult tosa::RFFT2dOp::verify() {
  const auto outputTypes = getResultTypes();
  if (failed(verifyCompatibleShapes(outputTypes)))
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  // Remaining checks require a ranked input.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  if (!inputType)
    return success();

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
  if (!outputType)
    return success();

  // Batch and height input/output dimensions should match
  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
                                   outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  // Output width dimension expected to be input_width / 2 + 1
  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
               "expected output width to be equal to input_width / 2 + 1, got ")
           << outputWidth;

  return success();
}
1635
1636LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1637 MLIRContext *context, ::std::optional<Location> location,
1638 FFT2dOp::Adaptor adaptor,
1639 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1640 inferredReturnShapes.push_back(
1641 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1642 inferredReturnShapes.push_back(
1643 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1644 return success();
1645}
1646
1647LogicalResult tosa::FFT2dOp::verify() {
1648 const auto inputRealType =
1649 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1650 const auto inputImagType =
1651 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1652 if (!inputRealType || !inputImagType)
1653 return success();
1654
1655 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1656 return ShapedType::isDynamic(a) ? a : b;
1657 };
1658
1659 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1660 inputImagType.getDimSize(1));
1661 if (ShapedType::isStatic(height) &&
1662 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1663 return failure();
1664
1665 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1666 inputImagType.getDimSize(2));
1667 if (ShapedType::isStatic(width) &&
1668 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1669 return failure();
1670
1671 return success();
1672}
1673
// Shape inference for tosa.concat: non-axis dimensions are merged across all
// ranked operands (and must agree), while the axis dimension is the sum of
// the operands' axis sizes, or dynamic if any contributing size is unknown.
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the Operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      // First static size seen for this dimension wins; later operands must
      // agree with it.
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  // With no ranked operand only the element type can be inferred.
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
1738
1739LogicalResult tosa::ConcatOp::verify() {
1740 // check that each input has same element type as output
1741 auto outType = getOutput().getType();
1742 const Operation::operand_range inputList = getInput1();
1743
1744 // Check there is at least one input
1745 if (inputList.empty())
1746 return emitOpError("expect at least one input");
1747
1748 if (!llvm::all_of(inputList, [&](auto input) {
1749 return succeeded(verifySameElementTypes(
1750 *this, /* inType = */ input.getType(), outType));
1751 })) {
1752 return failure();
1753 }
1754
1755 const int32_t axis = getAxis();
1756 ShapeAdaptor firstRankedInputShape = nullptr;
1757 for (const auto &input : inputList) {
1758 const Type inputType = input.getType();
1759 ShapeAdaptor currShape(inputType);
1760 if (currShape.hasRank()) {
1761 firstRankedInputShape = currShape;
1762 // Check axis is in expected range
1763 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1764 return emitOpError("expect axis to be within range 0 < axis < "
1765 "rank(input1[firstRankedTensorIdx]), got ")
1766 << axis;
1767 break;
1768 }
1769 }
1770
1771 const auto allOperandsHasRank = [](const Value input) {
1772 return ShapeAdaptor(input.getType()).hasRank();
1773 };
1774 if (llvm::all_of(inputList, allOperandsHasRank)) {
1775 const int64_t firstInputRank = firstRankedInputShape.getRank();
1776
1777 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1778 const ShapeAdaptor inputShape(input.getType());
1779 const int64_t inputRank = inputShape.getRank();
1780 const size_t operandNum = index + 1;
1781
1782 // Check that each operand has the same rank
1783 if (inputRank != firstInputRank)
1784 return emitOpError(
1785 "expect all operands to have the same rank, but got ")
1786 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1787 << operandNum;
1788
1789 // Check non-axis dims match
1790 for (int i = 0; i < inputRank; i++) {
1791 const int64_t inputDim = inputShape.getDimSize(i);
1792 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1793 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1794 inputShape.isDynamicDim(i))
1795 continue;
1796 if (inputDim != firstInputDim)
1797 return emitOpError("expect all operand shapes to have the same sizes "
1798 "on non-axis dimensions, but got ")
1799 << inputDim << " vs " << firstInputDim << " at index " << i
1800 << " on operands 0 and " << operandNum;
1801 }
1802 }
1803
1804 const ShapeAdaptor outputShape(outType);
1805 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1806 return emitOpError("expect output rank to match inputs rank, got ")
1807 << outputShape.getRank() << " vs " << firstInputRank;
1808
1809 // ERROR_IF(axis_sum != shape[axis]);
1810 int64_t axisSum = 0;
1811 for (const auto &input : inputList) {
1812 const ShapeAdaptor inputShape(input.getType());
1813 if (inputShape.isDynamicDim(axis)) {
1814 // make axisSum negative to indicate invalid value
1815 axisSum = -1;
1816 break;
1817 }
1818 axisSum += inputShape.getDimSize(axis);
1819 }
1820
1821 if (axisSum >= 0 && outputShape.hasRank() &&
1822 !outputShape.isDynamicDim(axis) &&
1823 axisSum != outputShape.getDimSize(axis))
1824 return emitOpError("requires sum of axis dimensions of input1 "
1825 "equal to output axis dimension, got ")
1826 << axisSum << " and " << outputShape.getDimSize(axis);
1827 }
1828
1829 return success();
1830}
1831
1832LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1833 MLIRContext *context, ::std::optional<Location> location,
1834 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
1835 RegionRange regions,
1836 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1837 auto elementType = IntegerType::get(context, /*width=*/1);
1838
1840 if (resolveBroadcastShape(operands, outShape).failed()) {
1841 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1842 return success();
1843 }
1844
1845 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1846 return success();
1847}
1848
1849bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1850 if (l.size() != r.size() || l.size() != 1)
1851 return false;
1852 return succeeded(verifyCompatibleShape(l[0], r[0]));
1853}
1854
1855LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1856 MLIRContext *context, ::std::optional<Location> location,
1857 MatMulOp::Adaptor adaptor,
1858 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1859 ShapeAdaptor lhsShape(adaptor.getA().getType());
1860 ShapeAdaptor rhsShape(adaptor.getB().getType());
1861
1862 // All shapes are dynamic.
1863 SmallVector<int64_t> outShape;
1864 outShape.resize(3, ShapedType::kDynamic);
1865
1866 if (lhsShape.hasRank()) {
1867 outShape[0] = lhsShape.getDimSize(0);
1868 outShape[1] = lhsShape.getDimSize(1);
1869 }
1870
1871 if (rhsShape.hasRank()) {
1872 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1873 : outShape[0];
1874 outShape[2] = rhsShape.getDimSize(2);
1875 }
1876
1877 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1878 return success();
1879}
1880
LogicalResult MatMulOp::verify() {
  // Verifies that: both inputs are shaped tensors; quantization is applied to
  // both operands or neither (and with equal storage widths); each zero-point
  // operand matches its input's storage element type; and constant zero-point
  // values pass the per-operand zero-point validity checks.
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

  // Must be shaped tensor types
  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    // Mixing a quantized with a non-quantized operand is invalid.
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    }
    // both a and b have quantized element types
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }
  }

  // check a_zp and b_zp: compare against storage element types so quantized
  // and non-quantized inputs are handled uniformly.
  auto aEType = getStorageElementTypeOrSelf(aType);
  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;
  }

  auto bEType = getStorageElementTypeOrSelf(bType);
  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;
  }

  // When the zero point cannot be resolved to a constant (getAZeroPoint
  // fails), the value check is skipped rather than diagnosed — presumably the
  // helpers emit their own diagnostics on verify*ZeroPoint failure; they are
  // defined elsewhere in this file.
  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();

  return success();
}
1944
1945LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1946 MLIRContext *context, ::std::optional<Location> location,
1947 MatmulTBlockScaledOp::Adaptor adaptor,
1948 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1949 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1950
1951 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1952 if (aDataShape.hasRank()) {
1953 outShape[0] = aDataShape.getDimSize(0);
1954 outShape[1] = aDataShape.getDimSize(1);
1955 }
1956
1957 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1958 if (aScaleShape.hasRank()) {
1959 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1960 : outShape[0];
1961 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1962 : outShape[1];
1963 }
1964
1965 // If B batch size is 1, it is broadcast across A's batch size
1966 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1967 if (bDataShape.hasRank()) {
1968 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1969 if (bDataBatchSize != 1)
1970 outShape[0] =
1971 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1972 outShape[2] = bDataShape.getDimSize(1);
1973 }
1974
1975 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1976 if (bScaleShape.hasRank()) {
1977 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1978 if (bScaleBatchSize != 1)
1979 outShape[0] =
1980 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1981 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1982 : outShape[2];
1983 }
1984
1985 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1986 return success();
1987}
1988
LogicalResult MatmulTBlockScaledOp::verify() {
  // Verifies element-type agreement between the data operands and the mutual
  // shape constraints across data/scale operands and the result. The dim
  // names used below suggest a_data is [N, H, C] and b_data is [D, W, C] —
  // NOTE(review): confirm this layout against the TOSA spec.

  // Verify same input data types
  const Type aDataType = getAData().getType();
  const Type bDataType = getBData().getType();
  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
                                    "B_data")))
    return failure();

  // Verify input shape compatibility.
  // Each accumulator starts unknown (kDynamic) and is merged with every
  // operand dim that constrains it via tryUpdateDimOrFailure — a helper
  // defined elsewhere in this file; presumably it updates a dynamic
  // accumulator with a static dim and errors on a static mismatch.
  int64_t N = ShapedType::kDynamic;
  int64_t D = ShapedType::kDynamic;
  int64_t H = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;
  int64_t multiplesOfC = ShapedType::kDynamic;

  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
  if (aDataShape.hasRank()) {
    N = aDataShape.getDimSize(0);
    H = aDataShape.getDimSize(1);
    C = aDataShape.getDimSize(2);
  }

  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
  if (aScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
                                     "a_scale", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
                                     "a_scale", "height")))
      return failure();
    multiplesOfC = aScaleShape.getDimSize(2);
  }

  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
  if (bDataShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
                                     "b_data", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
                                     "b_data", "channels")))
      return failure();
    W = bDataShape.getDimSize(1);
  }

  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
  if (bScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
                                     "b_scale", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
                                     "b_scale", "width")) ||
        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
                                     bScaleShape.getDimSize(2), "b_scale",
                                     "C/block_size")))
      return failure();
  }

  // Verify batch size is broadcast compatible: D must equal N or be 1.
  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
    return emitOpError("expect B matrix batch size to be broadcast compatible "
                       "with A, got D=")
           << D << " vs N=" << N;

  // Verify C is a multiple of block size
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(C) && C % blockSize != 0)
    return emitOpError("expect C to be a multiple of block size, got C=")
           << C << ", block_size=" << blockSize;

  // Verify multiplesOfC is C / block size
  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
      multiplesOfC != C / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal C/block_size (")
           << C << "/" << blockSize << ")" << ", got " << multiplesOfC;

  // Verify output shape. If N is still unknown, D (the B batch size) is the
  // best remaining candidate for the output batch dim.
  N = ShapedType::isDynamic(N) ? D : N;
  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&
      failed(
          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
    // Build a readable diagnostic, printing dynamic dims as '?'.
    InFlightDiagnostic opError = emitOpError("expected output shape ");
    auto stringifyDim = [&](int64_t d) {
      if (ShapedType::isDynamic(d))
        opError << "?";
      else
        opError << d;
    };
    llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
    opError << " to be compatible with expected output shape ";
    llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
    return opError;
  }

  return success();
}
2085
2086LogicalResult tosa::PadOp::inferReturnTypeComponents(
2087 MLIRContext *context, ::std::optional<Location> location,
2088 PadOp::Adaptor adaptor,
2089 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2090 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2091 auto paddingRank =
2092 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2093 SmallVector<int64_t> outputShape;
2094
2095 // If the input rank is unknown, we can infer the output rank using the
2096 // padding shape's rank divided by 2.
2097 if (!inputShape.hasRank()) {
2098 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2099 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2100 return success();
2101 }
2102
2103 SmallVector<int64_t> paddingValues;
2104 // If the paddings value is not a constant, all dimensions must be dynamic.
2105 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2106 paddingValues)) {
2107 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2108 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2109 return success();
2110 }
2111
2112 outputShape.reserve(inputShape.getRank());
2113 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2114 if (inputShape.isDynamicDim(i)) {
2115 outputShape.push_back(ShapedType::kDynamic);
2116 continue;
2117 }
2118 auto padFront = paddingValues[i * 2];
2119 auto padBack = paddingValues[i * 2 + 1];
2120 if (padFront < 0 || padBack < 0) {
2121 // if either padding for dim i is -1, output dim is unknown
2122 outputShape.push_back(ShapedType::kDynamic);
2123 continue;
2124 }
2125
2126 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2127 }
2128
2129 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2130 return success();
2131}
2132
LogicalResult tosa::PadOp::verify() {
  // Verifies element-type agreement (input vs output, pad_const vs output),
  // rank agreement, and — when the padding folds to a constant — that each
  // static output dim equals input dim + pad_front + pad_back.
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  if (auto padConst = getPadConst()) {
    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
                               /* outType = */ getOutput().getType())
            .failed()) {
      return failure();
    }
  }

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  // Shape checks only apply when both sides are ranked.
  if (!inputType || !outputType)
    return success();

  if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
                              "output")))
    return failure();

  auto inputRank = inputType.getRank();
  DenseIntElementsAttr paddingAttr;
  // Non-constant padding cannot be checked further.
  if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
    return success();

  // The padding tensor holds a (front, back) pair per input dimension.
  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    // -1 is the only permitted negative value (dynamic padding).
    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    // Skip shape verification for dynamic input/output
    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    // NOTE(review): a -1 (dynamic) pad combined with static input/output dims
    // still reaches this arithmetic check and contributes -1 to the expected
    // size — confirm that is intended.
    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  }

  return success();
}
2200
2201LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2202 MLIRContext *context, ::std::optional<Location> location,
2203 SliceOp::Adaptor adaptor,
2204 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2205
2206 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2209
2210 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2211 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2212 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2213 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2214 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2215 return success();
2216 }
2217
2218 // if size[i] is -1, all remaining elements in dimension i are included
2219 // in the slice, similar to TF.
2220 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2221 // initialize outputShape to all unknown
2222 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2223 if (inputShape.hasRank()) {
2224 for (size_t i = 0; i < size.size(); i++) {
2225 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2226 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2227 start[i] < inputShape.getDimSize(i))) {
2228 // size[i] is not 0 and not < -1, and start[i] is in valid range
2229 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2230 // input shape has unknown dim[i] - only valid if size[i] > 0
2231 if (size[i] > 0) {
2232 outputShape[i] = size[i];
2233 }
2234 } else {
2235 // input shape has known dim[i]
2236 if (size[i] == -1) {
2237 outputShape[i] = inputShape.getDimSize(i) - start[i];
2238 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2239 // start[i] + size[i] is within bound of input shape's dim[i]
2240 outputShape[i] = size[i];
2241 }
2242 }
2243 }
2244 }
2245 } else {
2246 outputShape = convertToMlirShape(size);
2247 }
2248 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2249 return success();
2250}
2251
LogicalResult tosa::SliceOp::verify() {
  // Verifies element-type agreement, rank agreement between input1, start,
  // size and output, value ranges of constant start/size, and that
  // start + size fits within each static input dimension.
  const Value input = getInput1();
  const Value output = getOutput();
  if (verifySameElementTypes(*this, /* inType = */ input.getType(),
                             /* outType = */ output.getType())
          .failed())
    return failure();

  const Value start = getStart();
  const Value size = getSize();
  const ShapeAdaptor inputShape(input.getType());
  const ShapeAdaptor outputShape(output.getType());

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();
    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      return emitOpError(
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    // start/size are tosa shape values; their "rank" is their element count.
    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(start.getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(size.getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
  }

  // Constant start values must be non-negative. kInferableDimSize is the
  // sentinel for an inferable value (declared elsewhere in this file;
  // presumably -1 — confirm).
  SmallVector<int64_t> startValues;
  tosa::getConstShapeValues(start.getDefiningOp(), startValues);
  if (startValues.size()) {
    if (llvm::any_of(startValues, [](const int64_t v) {
          return v < 0 && v != kInferableDimSize;
        }))
      return emitOpError("start values must be non-negative, got [")
             << startValues << "]";
  }

  // The remaining checks require constant size values.
  SmallVector<int64_t> sizeValues;
  if (!tosa::getConstShapeValues(size.getDefiningOp(), sizeValues))
    return success();

  if (llvm::any_of(sizeValues, [](const int64_t v) {
        return v <= 0 && v != kInferableDimSize;
      }))
    return emitOpError("size values must be > 0, got [") << sizeValues << "]";
  // When no size is inferable, the output shape must be compatible with the
  // requested sizes.
  if (outputShape.hasRank()) {
    SmallVector<int64_t> outputDims;
    outputShape.getDims(outputDims);
    const bool hasNoInferableDims = llvm::all_of(
        sizeValues, [](const int64_t v) { return v != kInferableDimSize; });
    if (hasNoInferableDims &&
        failed(verifyCompatibleShape(outputDims, sizeValues)))
      return emitOpError("expected output shape to match size values, got ")
             << output.getType() << " vs [" << sizeValues << "]";
  }

  // Bounds check: start + size must not exceed any static input dimension.
  // Inferable start/size values and dynamic input dims are skipped.
  if (inputShape.hasRank() && startValues.size()) {
    SmallVector<int64_t> inputDims;
    inputShape.getDims(inputDims);
    for (const auto &[index, vals] :
         llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
      const auto &[start, size, inputDim] = vals;
      if (start == kInferableDimSize || size == kInferableDimSize ||
          ShapedType::isDynamic(inputDim))
        continue;
      if (start + size > inputDim)
        return emitOpError("start + size must be less than or equal to input "
                           "dimension size, got start=")
               << start << ", size=" << size
               << " vs input dim size=" << inputDim << " at dimension "
               << index;
    }
  }

  return success();
}
2332
2333LogicalResult tosa::MulOp::inferReturnTypeComponents(
2334 MLIRContext *context, ::std::optional<Location> location,
2335 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
2336 RegionRange regions,
2337 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2338 // mul op's output shape only depend on input1 and input2, not on shift
2339 ValueShapeRange twoInputs = operands.drop_back();
2341 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2342 inferredReturnShapes.push_back(ShapedTypeComponents());
2343 } else {
2344 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2345 }
2346 return success();
2347}
2348
2349LogicalResult tosa::MulOp::verify() {
2350 const Value output = getOutput();
2351 auto resElemType = getElementTypeOrSelf(output);
2352
2353 // Verify if the element type among operands and result match tosa
2354 // specification.
2355 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2356 IntegerType lhsIntType =
2357 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2358 IntegerType rhsIntType =
2359 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2360 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2361 return emitOpError("requires the same element type for all operands");
2362
2363 // Though the spec requires the element type of result to be i32, a more
2364 // relaxed way is provided at dialect level for easier cooperating with
2365 // other dialects.
2366 if (lhsIntType.getWidth() > resIntType.getWidth())
2367 return emitOpError("invalid data type size for operands or result");
2368
2369 } else {
2370 // For other supported type, the spec requires requires the same element
2371 // type for all operands (excludes `shift` operand) and results.
2372 for (int i = 0; i < 2; ++i) {
2373 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2374 return emitOpError(
2375 "requires the same element type for all operands and results");
2376 }
2377
2378 // verify shift has value 0 for non-integer types
2379 ElementsAttr shiftElem;
2380 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2381 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2382 if (shift != 0) {
2383 return emitOpError() << "require shift to be 0 for float type";
2384 }
2385 }
2386 }
2387
2388 // Verify the op has same ranks for all main operands (excludes extra operands
2389 // such as shift of mul op, so this is the only difference with the built-in
2390 // `SameOperandsAndResultRank` trait) and results types, if known.
2391 TypeRange operandTypes = getOperandTypes();
2392 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2393 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2394
2395 const bool aHasRank = aType.hasRank();
2396 const bool bHasRank = bType.hasRank();
2397 if (aHasRank && bHasRank) {
2398 const int64_t aRank = aType.getRank();
2399 const int64_t bRank = bType.getRank();
2400 if (aRank != bRank)
2401 return emitOpError("a and b operands don't have matching ranks, got ")
2402 << aRank << " and " << bRank;
2403
2404 // check for broadcast compatible shapes
2405 SmallVector<int64_t> resultShape;
2407 aType.getShape(), bType.getShape(), resultShape))
2408 return emitOpError("a and b operands don't have broadcast-compatible "
2409 "shapes, got ")
2410 << aType << " and " << bType;
2411 }
2412
2413 ShapedType resultType = cast<ShapedType>(output.getType());
2414 if (!resultType.hasRank())
2415 return success();
2416
2417 const int64_t resultRank = resultType.getRank();
2418 if (aHasRank && resultRank != aType.getRank())
2419 return emitOpError("result type has different rank than a, got ")
2420 << resultRank << " vs " << aType.getRank();
2421 if (bHasRank && resultRank != bType.getRank())
2422 return emitOpError("result type has different rank than b, got ")
2423 << resultRank << " vs " << bType.getRank();
2424
2425 return success();
2426}
2427
2428LogicalResult tosa::TableOp::inferReturnTypeComponents(
2429 MLIRContext *context, ::std::optional<Location> location,
2430 TableOp::Adaptor adaptor,
2431 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2432 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2433
2434 if (!inputShape.hasRank()) {
2435 inferredReturnShapes.push_back(ShapedTypeComponents());
2436 return success();
2437 }
2438
2439 inferredReturnShapes.resize(1);
2440 inputShape.getDims(inferredReturnShapes[0]);
2441 return success();
2442}
2443
2444LogicalResult tosa::TableOp::verify() {
2445 const TensorType inputType = getInput1().getType();
2446 const TensorType outputType = getOutput().getType();
2447
2448 if (!inputType.hasRank() || !outputType.hasRank())
2449 return success();
2450
2451 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2452 "result")))
2453 return failure();
2454
2455 auto inputDims = inputType.getShape();
2456 auto outputDims = outputType.getShape();
2457 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2458 int64_t dim = it.index();
2459 auto [inputDim, outputDim] = it.value();
2460 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2461 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2462 << " doesn't match dim(input, " << dim
2463 << ") = " << inputDim;
2464 }
2465 }
2466 return success();
2467}
2468
2469LogicalResult
2470tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2471 // Multiples must be constants.
2472 DenseIntElementsAttr multiplesAttr;
2473 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2474 return failure();
2475 multiples =
2476 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2477 [](const APInt &val) { return val.getSExtValue(); });
2478 return success();
2479}
2480
2481LogicalResult tosa::TileOp::inferReturnTypeComponents(
2482 MLIRContext *context, ::std::optional<Location> location,
2483 TileOp::Adaptor adaptor,
2484 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2485 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2486 SmallVector<int64_t> multiples;
2487 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2488 multiples)) {
2489 auto rank =
2490 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2491 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2492 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2493 return success();
2494 }
2495 multiples = convertToMlirShape(multiples);
2496
2497 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2498 SmallVector<int64_t> outputShape;
2499 if (!inputShape.hasRank()) {
2500 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2501 inferredReturnShapes.push_back(
2502 ShapedTypeComponents(outputShape, inputType));
2503 return success();
2504 }
2505 if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2506 return failure();
2507
2508 // Any non dynamic dimension can be multiplied to a known size.
2509 outputShape.reserve(multiples.size());
2510 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2511 if (multiples[i] == ShapedType::kDynamic) {
2512 outputShape.push_back(ShapedType::kDynamic);
2513 } else {
2514 int64_t dim = inputShape.getDimSize(i);
2515 if (dim != ShapedType::kDynamic)
2516 dim *= multiples[i];
2517 outputShape.push_back(dim);
2518 }
2519 }
2520
2521 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2522 return success();
2523}
2524
LogicalResult tosa::TileOp::verify() {
  // Verifies element-type agreement, that 'multiples' has one entry per
  // input/output dimension, and that each constant multiple is a positive
  // integer or -1.
  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    // Ranked input: 'multiples' must match the input rank, and (if ranked)
    // the output rank must match the input rank too.
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() &&
        failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
                                "output")))
      return failure();
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    // Unranked input: fall back to checking against the output rank.
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  // Value check only applies when 'multiples' folds to a constant; -1 is
  // permitted (it maps to a dynamic dimension in shape inference).
  SmallVector<int64_t> multiples;
  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
    return emitOpError(
        "expect element of 'multiples' to be positive integer or -1.");

  return success();
}
2559
2560bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2561 if (l.size() != r.size() || l.size() != 1)
2562 return false;
2563 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2564}
2565
2566LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2567 MLIRContext *context, ::std::optional<Location> location,
2568 ReshapeOp::Adaptor adaptor,
2569 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2570 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2571 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2572 llvm::SmallVector<int64_t> newShapeValue;
2573 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2574 newShapeValue)) {
2575 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2576 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2577 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2578 return success();
2579 }
2580 newShapeValue = convertToMlirShape(newShapeValue);
2581
2582 // We cannot infer from the total number of elements so we must take the
2583 // shape attribute as exact.
2584 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2585 inferredReturnShapes.push_back(
2586 ShapedTypeComponents(newShapeValue, inputType));
2587 return success();
2588 }
2589
2590 // Determine the number of elements covered by the slice of all static
2591 // dimensions. This allows us to infer the length of the remaining dynamic
2592 // dimension.
2593 int64_t numElements = inputShape.getNumElements();
2594 int64_t staticMul = 1;
2595 for (auto val : newShapeValue) {
2596 if (ShapedType::isStatic(val)) {
2597 staticMul *= val;
2598 }
2599 }
2600
2601 // Determine the length of the dynamic dimension.
2602 for (auto &val : newShapeValue) {
2603 if (ShapedType::isDynamic(val))
2604 val = numElements / staticMul;
2605 }
2606
2607 inferredReturnShapes.push_back(
2608 ShapedTypeComponents(newShapeValue, inputType));
2609 return success();
2610}
2611
2612llvm::LogicalResult tosa::ReshapeOp::verify() {
2613 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2614 /* outType = */ getOutput().getType())
2615 .failed()) {
2616 return failure();
2617 }
2618 TensorType inputType = getInput1().getType();
2619
2620 SmallVector<int64_t> shapeValues;
2621 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2622 // skip following checks if shape is not constant
2623 return mlir::success();
2624 }
2625
2626 int missingDims = llvm::count(shapeValues, kInferableDimSize);
2627 if (missingDims > 1)
2628 return emitOpError() << "expected at most one target dimension to be "
2630
2631 const auto outputType = dyn_cast<RankedTensorType>(getType());
2632 if (!outputType)
2633 return success();
2634
2635 if ((int64_t)shapeValues.size() != outputType.getRank())
2636 return emitOpError() << "new shape does not match result rank";
2637
2638 for (auto [newShapeDim, outputShapeDim] :
2639 zip(shapeValues, outputType.getShape())) {
2640 if (newShapeDim != kInferableDimSize &&
2641 newShapeDim != ShapedType::kDynamic &&
2642 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2643 return emitOpError() << "new shape is inconsistent with result shape";
2644
2645 if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
2646 return emitOpError() << "new shape has invalid tensor dimension size "
2647 << newShapeDim;
2648 }
2649
2650 if (inputType.hasStaticShape()) {
2651 int64_t inputElementsNum = inputType.getNumElements();
2652 if (outputType.hasStaticShape()) {
2653 int64_t outputElementsNum = outputType.getNumElements();
2654 if (inputElementsNum != outputElementsNum) {
2655 return emitOpError() << "cannot reshape " << inputElementsNum
2656 << " elements into " << outputElementsNum;
2657 }
2658 }
2659
2660 int64_t newShapeElementsNum =
2661 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2662 return (dim > 0) ? acc * dim : acc;
2663 });
2664 bool isStaticNewShape =
2665 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2666 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2667 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2668 return emitOpError() << "cannot reshape " << inputElementsNum
2669 << " elements into " << newShapeElementsNum;
2670 }
2671 }
2672
2673 return mlir::success();
2674}
2675
2676// return failure if val is not a constant
2677// set zp to -1 if val is non-zero float or val is not integer nor float
2678// otherwise set zp to val's constant value
2679static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2680 ElementsAttr zpAttr;
2681 if (!matchPattern(val, m_Constant(&zpAttr))) {
2682 return failure();
2683 }
2684
2685 Type zpElemType = zpAttr.getElementType();
2686
2687 if (llvm::isa<FloatType>(zpElemType)) {
2688 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2689 return 0;
2690 }
2691 // return non-zero value to trigger error check
2692 return -1;
2693 }
2694
2695 if (llvm::isa<IntegerType>(zpElemType)) {
2696 if (signExtend)
2697 return zpAttr.getValues<APInt>()[0].getSExtValue();
2698 return zpAttr.getValues<APInt>()[0].getZExtValue();
2699 }
2700
2701 // return non-zero value to trigger error check
2702 return -1;
2703}
2704
2705template <typename T>
2706static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2707 const std::string &operand) {
2708 Type zpElemType = getElementTypeOrSelf(val);
2709
2710 if (!zpElemType.isInteger(8) && zp != 0) {
2711 // convert operand to lower case for error message
2712 std::string lower = operand;
2713 llvm::transform(lower, lower.begin(), ::tolower);
2714 return op.emitOpError()
2715 << lower << " zero point must be zero for non-int8 integer types";
2716 }
2717
2718 return success();
2719}
2720
2721static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2722 const int64_t &zp,
2723 const std::string &operand) {
2724 bool isInputZp = (operand == "Input");
2725
2726 bool tensorUnsigned =
2727 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2728 StringRef tensorName = isInputZp ? "input" : "output";
2729
2730 Type zpElemType = getElementTypeOrSelf(zpVal);
2731
2732 if (zp != 0) {
2733 if (!zpElemType.isInteger(8) &&
2734 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2735 return op.emitOpError()
2736 << "expect " << tensorName << "_zp of 0, got " << zp;
2737 }
2738 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2739 return op.emitOpError() << "expect " << tensorName
2740 << "_zp of 0 or 32768 for unsigned int16 "
2741 << tensorName << ", got " << zp;
2742 }
2743 }
2744
2745 return success();
2746}
2747
// For each (Op, OperandName) pair below, generates two members:
//   - get<OperandName>ZeroPoint():    extracts the operand's constant zero
//     point via getZeroPoint (sign- vs zero-extended per SIGN_EXTEND).
//   - verify<OperandName>ZeroPoint(): validates the extracted value against
//     the operand's element type via verifyZeroPoint.
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

ZERO_POINT_HELPER(Conv2DOp, Input, true)
ZERO_POINT_HELPER(Conv2DOp, Weight, true)
ZERO_POINT_HELPER(Conv3DOp, Input, true)
ZERO_POINT_HELPER(Conv3DOp, Weight, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
ZERO_POINT_HELPER(MatMulOp, A, true)
ZERO_POINT_HELPER(MatMulOp, B, true)
ZERO_POINT_HELPER(NegateOp, Input1, true)
ZERO_POINT_HELPER(NegateOp, Output, true)
// Rescale sign-extends only when the corresponding tensor is signed; an
// unsigned tensor's zero point is zero-extended.
ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
#undef ZERO_POINT_HELPER
2773
2774LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2775 MLIRContext *context, ::std::optional<Location> location,
2776 TransposeOp::Adaptor adaptor,
2777 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2778 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2779
2780 // If input rank and permutation length is unknown, the output rank is
2781 // unknown.
2782 if (!inputShape.hasRank()) {
2783 inferredReturnShapes.push_back(ShapedTypeComponents());
2784 return success();
2785 }
2786
2787 const auto inputRank = inputShape.getRank();
2788
2789 // This would imply the number of permutations does not match the rank of
2790 // the input which is illegal.
2791 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2792 return failure();
2793 }
2794
2795 SmallVector<int64_t> outputShape;
2796 // Rank-0 means no permutations matter.
2797 if (inputRank == 0) {
2798 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2799 return success();
2800 }
2801
2802 // Check whether the input dimensions are all the same.
2803 bool allTheSame = true;
2804 for (int i = 1, s = inputRank; i < s; i++) {
2805 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2806 allTheSame = false;
2807 break;
2808 }
2809 }
2810
2811 // If all of the input dimensions are the same we don't care about the
2812 // permutation.
2813 if (allTheSame) {
2814 outputShape.resize(inputRank, inputShape.getDimSize(0));
2815 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2816 return success();
2817 }
2818
2819 outputShape.resize(inputRank, ShapedType::kDynamic);
2820
2821 // Constant permutation values must be within the input rank.
2822 if (llvm::any_of(adaptor.getPerms(),
2823 [inputRank](const auto i) { return i >= inputRank; }))
2824 return failure();
2825
2826 outputShape.reserve(inputRank);
2827 for (int i = 0, s = inputRank; i < s; i++) {
2828 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2829 }
2830
2831 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2832 return success();
2833}
2834
// Verifies transpose: element types match, 'perms' length equals both ranks,
// 'perms' is a valid permutation, the element count is preserved, and every
// static output dimension matches the permuted input dimension.
LogicalResult tosa::TransposeOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor inputShape(getInput1().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  const llvm::ArrayRef<int32_t> constantPerms = getPerms();

  // 'perms' must supply exactly one index per input dimension.
  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  // Transpose preserves rank.
  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  // ... and one index per output dimension as well.
  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  // Every index must be in range and the indices must form a permutation
  // (no duplicates).
  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::map_to_vector(
          constantPerms, [](int32_t v) -> int64_t { return v; })))
    return emitOpError() << "expected valid permutation indices";

  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                            "of elements, got "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  // Verify that the types of the input and output tensors are properly
  // permuted: static output dim i must equal static input dim perms[i].
  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }

  return success();
}
2901
// Reifies the result shape of a transpose: result dim `dim` takes its extent
// from input dim perms[dim] — static extents become index attributes, dynamic
// ones materialize tensor.dim ops on the input.
LogicalResult TransposeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

  const llvm::ArrayRef<int32_t> transposePerms = getPerms();

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
  // NOTE(review): this ranges over the permutation *values* rather than
  // 0..rank-1; since perms is a valid permutation (checked by the verifier)
  // each result index is still filled exactly once.
  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
      returnedDims[dim] =
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
              .getResult();
    else
      returnedDims[dim] =
          builder.getIndexAttr(inputType.getDimSize(dimInInput));
  }

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
  return success();
}
2925
2926LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2927 MLIRContext *context, ::std::optional<Location> location,
2928 GatherOp::Adaptor adaptor,
2929 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2930 llvm::SmallVector<int64_t> outputShape;
2931 outputShape.resize(3, ShapedType::kDynamic);
2932
2933 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2934 if (valuesShape.hasRank()) {
2935 outputShape[0] = valuesShape.getDimSize(0);
2936 outputShape[2] = valuesShape.getDimSize(2);
2937 }
2938
2939 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2940 if (indicesShape.hasRank()) {
2941 if (outputShape[0] == ShapedType::kDynamic)
2942 outputShape[0] = indicesShape.getDimSize(0);
2943 if (outputShape[1] == ShapedType::kDynamic)
2944 outputShape[1] = indicesShape.getDimSize(1);
2945 }
2946
2947 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2948 return success();
2949}
2950
// Verifies gather: element types of values/output match, and the shared
// extents N (values dim 0 / indices dim 0 / output dim 0), W (indices dim 1 /
// output dim 1) and C (values dim 2 / output dim 2) agree wherever static.
LogicalResult tosa::GatherOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  // Unknown extents stay dynamic and are skipped by the checks below.
  int64_t n = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    n = valuesShape.getDimSize(0);
    c = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    // Output must agree with every extent known statically.
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << n
                           << ", got " << outputN;

    if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        w != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << w
                           << ", got " << outputW;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << c
                           << ", got " << outputC;
  }
  return success();
}
2999
// Infers the resize result shape: dims 0 and 3 pass through from the input,
// while the height/width dims (1 and 2) are computed from the constant
// scale, offset and border operands.
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  // Dynamic spatial dims make the output extents uncomputable.
  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  // All three shape operands must be compile-time constants to infer.
  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
                                 scaleInt) ||
      !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
                                 offsetInt) ||
      !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
                                 borderInt)) {
    return failure();
  }

  // Compute the output shape based on attributes: scale, offset, and border.
  // scale is laid out [y_n, y_d, x_n, x_d]; offset and border are [y, x].
  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ",
        outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
3054
// Verifies resize: all scale values must be positive, and when the spatial
// dims are static, the output extents must equal the value computed from
// scale/offset/border with exact (whole) division.
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  SmallVector<int64_t> scaleValues;
  SmallVector<int64_t> offsetValues;
  SmallVector<int64_t> borderValues;
  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
      !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
      !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
    // Skip following checks if shape is not constant
    return success();
  }

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues;

  // scale is [y_n, y_d, x_n, x_d]; offset and border are [y, x].
  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  // Without ranked types there are no extents left to check.
  if (!inputType)
    return success();
  if (!outputType)
    return success();

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  // Don't check with input height that could be broadcast (ih != 1)
  // since Linalg, a consumer of TOSA, expects broadcasting support
  // in resize to be available. Taking the cautious approach for now,
  // we can consider removing support for broadcasting later.
  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  // Don't check with input width that could be broadcast (iw != 1)
  // since Linalg, a consumer of TOSA, expects broadcasting support
  // in resize to be available. Taking the cautious approach for now,
  // we can consider removing support for broadcasting later.
  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
3139
3140LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3141 MLIRContext *context, ::std::optional<Location> location,
3142 ScatterOp::Adaptor adaptor,
3143 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3144 llvm::SmallVector<int64_t> outputShape;
3145 outputShape.resize(3, ShapedType::kDynamic);
3146
3147 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3148 if (valuesInShape.hasRank()) {
3149 outputShape[0] = valuesInShape.getDimSize(0);
3150 outputShape[1] = valuesInShape.getDimSize(1);
3151 outputShape[2] = valuesInShape.getDimSize(2);
3152 }
3153
3154 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3155 if (indicesShape.hasRank()) {
3156 if (outputShape[0] == ShapedType::kDynamic)
3157 outputShape[0] = indicesShape.getDimSize(0);
3158 }
3159
3160 ShapeAdaptor inputShape(adaptor.getInput().getType());
3161 if (inputShape.hasRank()) {
3162 if (outputShape[0] == ShapedType::kDynamic)
3163 outputShape[0] = inputShape.getDimSize(0);
3164 if (outputShape[2] == ShapedType::kDynamic)
3165 outputShape[2] = inputShape.getDimSize(2);
3166 }
3167
3168 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3169 return success();
3170}
3171
// Verifies scatter: element types of values_in/input/values_out match, the
// shared extents N, K, W and C agree across operands wherever static, and
// K >= W so every index can address a distinct slot.
LogicalResult tosa::ScatterOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
                             /* outType = */ getValuesOut().getType())
          .failed() ||
      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                             /* outType = */ getValuesOut().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  // Unknown extents stay dynamic and are skipped by the checks below.
  int64_t n = ShapedType::kDynamic;
  int64_t k = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    n = valuesInShape.getDimSize(0);
    k = valuesInShape.getDimSize(1);
    c = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    // 'input' contributes N (dim 0), W (dim 1) and C (dim 2).
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (n == ShapedType::kDynamic)
      n = inputN;
    else if (inputN != ShapedType::kDynamic && n != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << n
                           << ", got " << inputN;
    if (w == ShapedType::kDynamic)
      w = inputW;
    else if (inputW != ShapedType::kDynamic && w != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << w
                           << ", got " << inputW;

    if (c == ShapedType::kDynamic)
      c = inputC;
    else if (inputC != ShapedType::kDynamic && c != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << c
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    // 'values_out' must agree with N (dim 0), K (dim 1) and C (dim 2).
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << n << ", got " << outputN;
    if (k == ShapedType::kDynamic)
      k = outputK;
    else if (outputK != ShapedType::kDynamic && k != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << k << ", got " << outputK;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << c << ", got " << outputC;
  }
  // ERROR_IF-style constraint: K must be at least W when both are static.
  if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
    return emitOpError() << "requires dimensions K >= W, got K=" << k
                         << " and W=" << w;

  return success();
}
3250
3251static LogicalResult ReduceInferReturnTypes(
3252 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3253 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3254 int64_t axisVal = axis.getValue().getSExtValue();
3255 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3256 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3257 return success();
3258 }
3259
3260 SmallVector<int64_t> outputShape;
3261 operandShape.getDims(outputShape);
3262 outputShape[axisVal] = 1;
3263 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3264 return success();
3265}
3266
// Declares isCompatibleReturnTypes for OP: exactly one type on each side,
// identical element types, and shapes that verifyCompatibleShape accepts.
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

// Declares inferReturnTypeComponents for a reduce OP by delegating to
// ReduceInferReturnTypes with the op's axis property, and also emits the
// compatible-return-type predicate above for the same OP.
#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
// Concat shares the compatibility predicate but not the reduce inference.
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
3299
3300template <typename T>
3301static LogicalResult verifyReduceOp(T op) {
3302 // All TOSA reduce Ops have input, output and axis.
3303 TensorType inputType = op.getInput().getType();
3304 TensorType outputType = op.getOutput().getType();
3305 int32_t reduceAxis = op.getAxis();
3306
3307 if (reduceAxis < 0) {
3308 op.emitOpError("reduce axis must not be negative");
3309 return failure();
3310 }
3311 if (inputType.hasRank()) {
3312 int64_t inputRank = inputType.getRank();
3313 // We allow for a special case where the input/output shape has rank 0 and
3314 // axis is also 0.
3315 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3316 op.emitOpError("expect input tensor rank (")
3317 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3318 << ")";
3319 return failure();
3320 }
3321 }
3322 if (outputType.hasRank()) {
3323 int64_t outputRank = outputType.getRank();
3324 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3325 op.emitOpError(
3326 "expect output tensor rank to be equal to input tensor rank");
3327 return failure();
3328 }
3329 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3330 op.emitOpError("expect output tensor rank (")
3331 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3332 << ")";
3333 return failure();
3334 }
3335 // We can only verify the reduced dimension size to be 1 if this is not
3336 // the special case of output rank == 0.
3337 if (outputRank != 0) {
3338 auto outputShape = outputType.getShape();
3339 if (!outputType.isDynamicDim(reduceAxis) &&
3340 outputShape[reduceAxis] != 1) {
3341 op.emitOpError("expect reduced dimension size to be 1, got ")
3342 << outputShape[reduceAxis];
3343 return failure();
3344 }
3345 }
3346 }
3347 return success();
3348}
3349
// All reduce ops share the verifyReduceOp template verifier.
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3356
3357static LogicalResult NAryInferReturnTypes(
3358 const ValueShapeRange &operands,
3359 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3361 if (resolveBroadcastShape(operands, outShape).failed()) {
3362 inferredReturnShapes.push_back(ShapedTypeComponents());
3363 } else {
3364 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3365 }
3366 return success();
3367}
3368
3369#define NARY_SHAPE_INFER(OP) \
3370 LogicalResult OP::inferReturnTypeComponents( \
3371 MLIRContext *context, ::std::optional<Location> location, \
3372 ValueShapeRange operands, DictionaryAttr attributes, \
3373 PropertyRef properties, RegionRange regions, \
3374 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3375 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3376 }
3377
3378NARY_SHAPE_INFER(tosa::AbsOp)
3379NARY_SHAPE_INFER(tosa::AddOp)
3380NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3381NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3382NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3383NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3384NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3385NARY_SHAPE_INFER(tosa::CastOp)
3386NARY_SHAPE_INFER(tosa::CeilOp)
3387NARY_SHAPE_INFER(tosa::ClampOp)
3388NARY_SHAPE_INFER(tosa::ClzOp)
3389NARY_SHAPE_INFER(tosa::CosOp)
3390NARY_SHAPE_INFER(tosa::ExpOp)
3391NARY_SHAPE_INFER(tosa::FloorOp)
3392NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3393NARY_SHAPE_INFER(tosa::GreaterOp)
3394NARY_SHAPE_INFER(tosa::IdentityOp)
3395NARY_SHAPE_INFER(tosa::IntDivOp)
3396NARY_SHAPE_INFER(tosa::LogOp)
3397NARY_SHAPE_INFER(tosa::LogicalAndOp)
3398NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3399NARY_SHAPE_INFER(tosa::LogicalNotOp)
3400NARY_SHAPE_INFER(tosa::LogicalOrOp)
3401NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3402NARY_SHAPE_INFER(tosa::LogicalXorOp)
3403NARY_SHAPE_INFER(tosa::MaximumOp)
3404NARY_SHAPE_INFER(tosa::MinimumOp)
3405NARY_SHAPE_INFER(tosa::PowOp)
3406NARY_SHAPE_INFER(tosa::ReciprocalOp)
3407NARY_SHAPE_INFER(tosa::ReverseOp)
3408NARY_SHAPE_INFER(tosa::RsqrtOp)
3409NARY_SHAPE_INFER(tosa::SinOp)
3410NARY_SHAPE_INFER(tosa::SelectOp)
3411NARY_SHAPE_INFER(tosa::SubOp)
3412NARY_SHAPE_INFER(tosa::TanhOp)
3413NARY_SHAPE_INFER(tosa::ErfOp)
3414NARY_SHAPE_INFER(tosa::SigmoidOp)
3415#undef PRED_SHAPE_INFER
3416
3417LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3418 MLIRContext *context, ::std::optional<Location> location,
3419 NegateOp::Adaptor adaptor,
3420 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3421 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3422 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3423 return success();
3424}
3425
// Verifies tosa.negate: input1 and output must agree in element type and
// shape, each zero-point operand must carry the same storage element type as
// the tensor it quantizes, and constant zero-point values must be legal.
LogicalResult tosa::NegateOp::verify() {
  // Verify same element type
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();
  if (verifySameElementTypes(*this, input1Type, outputType).failed())
    return failure();

  // Verify same shape
  const SmallVector<Type, 2> types = {input1Type, outputType};
  if (failed(verifyCompatibleShapes(types)))
    return emitOpError() << "requires the same shape for input1 and output";

  // Zero points are compared via the quantized storage type (or the type
  // itself for non-quantized tensors).
  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  // Only validate zero-point values that can be extracted as compile-time
  // constants; a failed extraction is not itself an error here.
  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
3465
// Shared NHWC shape inference for the pooling ops (AvgPool2d / MaxPool2d).
// Batch and channel dimensions pass straight through; each static spatial
// dimension follows the pooling formula (in + pad_lo + pad_hi - kernel) /
// stride + 1. Unknown dimensions stay dynamic.
static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Start from a fully dynamic rank-4 (NHWC) shape.
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  // If the input type is unranked, nothing can be refined beyond rank 4.
  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for pooling layer.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);

  // pad is laid out [top, bottom, left, right].
  if (ShapedType::isStatic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (ShapedType::isStatic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
3499
// Per-op adaptors for the shared convolution shape-inference driver. Each
// ConvInferShapeAdaptor specialization knows how to extract batch/spatial/
// channel dimensions and pad/stride/dilation from one op's adaptor; the base
// provides the common dim-refinement helper.
template <typename AdaptorT>

protected:
  // Overwrite `current` with `candidate` only while `current` is still
  // dynamic, i.e. refine a dimension from a secondary (e.g. scale) operand
  // without clobbering a value already established by the primary operand.
  static void updateIfDynamic(int64_t &current, int64_t candidate) {
    if (ShapedType::isDynamic(current))
      current = candidate;
  }
};
3510
// Conv2D: rank-4 NHWC input, OHWI weight, two spatial dimensions. The
// pad/stride/dilation attributes are dense arrays and always available.
template <>
class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
    : public ConvInferShapeAdaptorBase {
public:
  explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
      : adaptor(adaptor) {}

  // Fill the output batch dim and the input spatial sizes (H, W) from the
  // input operand, when it is ranked.
                       SmallVectorImpl<int64_t> &inputSpatial) {
    const ShapeAdaptor inputShape(adaptor.getInput().getType());
    if (!inputShape.hasRank())
      return;

    const int64_t outputBatch = inputShape.getDimSize(0);
    const int64_t inputHeight = inputShape.getDimSize(1);
    const int64_t inputWidth = inputShape.getDimSize(2);

    outputShape[0] = outputBatch;
    inputSpatial[0] = inputHeight;
    inputSpatial[1] = inputWidth;
  }

  // Fill the output channel dim and the kernel spatial sizes (KH, KW) from
  // the weight operand, when it is ranked.
                        SmallVectorImpl<int64_t> &weightSpatial) {
    const ShapeAdaptor weightShape(adaptor.getWeight().getType());
    if (!weightShape.hasRank())
      return;

    const int64_t outputChannels = weightShape.getDimSize(0);
    const int64_t kernelHeight = weightShape.getDimSize(1);
    const int64_t kernelWidth = weightShape.getDimSize(2);

    outputShape[3] = outputChannels;
    weightSpatial[0] = kernelHeight;
    weightSpatial[1] = kernelWidth;
  }

  int64_t getNumSpatialDims() const { return 2; }
  int64_t getOutputRank() const { return 4; }

  // Copy pad/stride/dilation out of the op's attributes; these are always
  // present for Conv2D, so this never fails.
                              SmallVector<int64_t> &strideValues,
                              SmallVector<int64_t> &dilationValues) {
    padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
    strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
    dilationValues.assign(adaptor.getDilation().begin(),
                          adaptor.getDilation().end());
    return success();
  }

private:
  Conv2DOp::Adaptor adaptor;
};
3564
// Conv2DBlockScaled: like Conv2D, but data and scale tensors each carry
// shape information, and pad/stride/dilation are shape operands that must be
// constant-folded before use.
template <>
class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
    : public ConvInferShapeAdaptorBase {
public:
  explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
      : adaptor(adaptor) {}

  // Derive batch and input spatial sizes from input_data first, then refine
  // any still-dynamic dims from input_scale.
                       SmallVectorImpl<int64_t> &inputSpatial) {
    const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
    if (inputDataShape.hasRank()) {
      const int64_t outputBatch = inputDataShape.getDimSize(0);
      const int64_t inputHeight = inputDataShape.getDimSize(1);
      const int64_t inputWidth = inputDataShape.getDimSize(2);

      outputShape[0] = outputBatch;
      inputSpatial[0] = inputHeight;
      inputSpatial[1] = inputWidth;
    }

    const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
    if (!inputScaleShape.hasRank())
      return;

    const int64_t scaleBatch = inputScaleShape.getDimSize(0);
    const int64_t scaleHeight = inputScaleShape.getDimSize(1);
    const int64_t scaleWidth = inputScaleShape.getDimSize(2);

    // Only fill dims that the data operand left dynamic.
    updateIfDynamic(outputShape[0], scaleBatch);
    updateIfDynamic(inputSpatial[0], scaleHeight);
    updateIfDynamic(inputSpatial[1], scaleWidth);
  }

  // Derive output channels and kernel spatial sizes from weight_data first,
  // then refine any still-dynamic dims from weight_scale.
                        SmallVectorImpl<int64_t> &weightSpatial) {
    const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
    if (weightDataShape.hasRank()) {
      const int64_t outputChannels = weightDataShape.getDimSize(0);
      const int64_t kernelHeight = weightDataShape.getDimSize(1);
      const int64_t kernelWidth = weightDataShape.getDimSize(2);

      outputShape[3] = outputChannels;
      weightSpatial[0] = kernelHeight;
      weightSpatial[1] = kernelWidth;
    }

    const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
    if (!weightScaleShape.hasRank())
      return;

    const int64_t scaleOutputChannels = weightScaleShape.getDimSize(0);
    const int64_t scaleKernelHeight = weightScaleShape.getDimSize(1);
    const int64_t scaleKernelWidth = weightScaleShape.getDimSize(2);

    updateIfDynamic(outputShape[3], scaleOutputChannels);
    updateIfDynamic(weightSpatial[0], scaleKernelHeight);
    updateIfDynamic(weightSpatial[1], scaleKernelWidth);
  }

  int64_t getNumSpatialDims() const { return 2; }
  int64_t getOutputRank() const { return 4; }

  // pad/stride/dilation are !tosa.shape operands here; fail (so the caller
  // keeps dynamic spatial dims) when any of them is not a constant shape.
                              SmallVector<int64_t> &strideValues,
                              SmallVector<int64_t> &dilationValues) {
    if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
                                   padValues) ||
        !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
                                   strideValues) ||
        !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
                                   dilationValues))
      return failure();
    return success();
  }

private:
  Conv2DBlockScaledOp::Adaptor adaptor;
};
3643
// Conv3D: rank-5 NDHWC input, three spatial dimensions (D, H, W), rank-5
// output. Attributes are dense arrays, as with Conv2D.
template <>
class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
    : public ConvInferShapeAdaptorBase {
public:
  explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
      : adaptor(adaptor) {}

  // Fill the output batch dim and the input spatial sizes (D, H, W).
                       SmallVectorImpl<int64_t> &inputSpatial) {
    const ShapeAdaptor inputShape(adaptor.getInput().getType());
    if (!inputShape.hasRank())
      return;

    const int64_t outputBatch = inputShape.getDimSize(0);
    const int64_t inputDepth = inputShape.getDimSize(1);
    const int64_t inputHeight = inputShape.getDimSize(2);
    const int64_t inputWidth = inputShape.getDimSize(3);

    outputShape[0] = outputBatch;
    inputSpatial[0] = inputDepth;
    inputSpatial[1] = inputHeight;
    inputSpatial[2] = inputWidth;
  }

  // Fill the output channel dim (last of 5) and kernel spatial sizes
  // (KD, KH, KW) from the weight operand.
                        SmallVectorImpl<int64_t> &weightSpatial) {
    const ShapeAdaptor weightShape(adaptor.getWeight().getType());
    if (!weightShape.hasRank())
      return;

    const int64_t outputChannels = weightShape.getDimSize(0);
    const int64_t kernelDepth = weightShape.getDimSize(1);
    const int64_t kernelHeight = weightShape.getDimSize(2);
    const int64_t kernelWidth = weightShape.getDimSize(3);

    outputShape[4] = outputChannels;
    weightSpatial[0] = kernelDepth;
    weightSpatial[1] = kernelHeight;
    weightSpatial[2] = kernelWidth;
  }

  int64_t getNumSpatialDims() const { return 3; }
  int64_t getOutputRank() const { return 5; }

  // Copy pad/stride/dilation attribute arrays; always succeeds for Conv3D.
                              SmallVector<int64_t> &strideValues,
                              SmallVector<int64_t> &dilationValues) {
    padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
    strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
    dilationValues.assign(adaptor.getDilation().begin(),
                          adaptor.getDilation().end());
    return success();
  }

private:
  Conv3DOp::Adaptor adaptor;
};
3701
// Shared convolution shape-inference driver. Uses the per-op
// ConvInferShapeAdaptor specialization to gather batch/channel/spatial dims
// and pad/stride/dilation, then computes each static output spatial dim via
// (in + pads - ((k - 1) * dilation + 1) - 1) / stride + 1. Dims that cannot
// be determined remain dynamic.
template <typename AdaptorT>
    AdaptorT adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
  // Start fully dynamic; the adaptor calls below refine what they can.
  llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
                                         ShapedType::kDynamic);
  llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
                                          ShapedType::kDynamic);
  llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
                                           ShapedType::kDynamic);

  convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
  convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);

  // A non-broadcast (size != 1) bias can pin the output channel dim if the
  // weight did not already provide it.
  const ShapeAdaptor biasShape = adaptor.getBias().getType();
  if (biasShape.hasRank()) {
    const int64_t biasSize = biasShape.getDimSize(0);
    if (biasSize != 1) {
      const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
      outputShape[outputChannelDim] =
          ShapedType::isDynamic(outputShape[outputChannelDim])
              ? biasSize
              : outputShape[outputChannelDim];
    }
  }

  // Without constant pad/stride/dilation the spatial dims cannot be
  // computed; return what we have so far.
  SmallVector<int64_t> padValues;
  SmallVector<int64_t> strideValues;
  SmallVector<int64_t> dilationValues;
  if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
                                                   dilationValues))) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  for (int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
    // Both the input and kernel extents must be static to compute this dim.
    if (!ShapedType::isStatic(inputSpatial[dim]) ||
        !ShapedType::isStatic(weightSpatial[dim]))
      continue;
    const int64_t inputSize =
        inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
    const int64_t filterSize =
        (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
    const int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
3753
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Delegate to the shared convolution driver; the 2-D NHWC specifics live
  // in the ConvInferShapeAdaptor<Conv2DOp::Adaptor> specialization.
  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
3760
3761LogicalResult Conv2DOp::verify() {
3762 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3763 verifyConvOpErrorIf(*this).failed())
3764 return failure();
3765 return success();
3766}
3767
LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Delegate to the shared convolution driver; the data/scale merging logic
  // lives in the ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
  // specialization.
  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
3774
// Verifies tosa.conv2d_block_scaled: element-type agreement between data /
// scale / bias operands, cross-operand shape consistency, block-size
// divisibility, attribute-value ranges, and (when statically known) the
// output spatial and channel sizes.
LogicalResult Conv2DBlockScaledOp::verify() {
  if (failed(verifySameElementTypes(*this, getInputData().getType(),
                                    getWeightData().getType(), "input_data",
                                    "weight_data")) ||
      failed(verifySameElementTypes(*this, getInputScale().getType(),
                                    getWeightScale().getType(), "input_scale",
                                    "weight_scale")) ||
      failed(verifySameElementTypes(*this, getBias().getType(),
                                    getOutput().getType(), "bias", "output")))
    return failure();

  // Verify input shape compatibility. Each dim starts dynamic and is filled
  // in from whichever operand provides it; tryUpdateDimOrFailure diagnoses
  // conflicting static values.
  int64_t N = ShapedType::kDynamic;
  int64_t IH = ShapedType::kDynamic;
  int64_t IW = ShapedType::kDynamic;
  int64_t IC = ShapedType::kDynamic;
  int64_t multiplesOfIC = ShapedType::kDynamic;
  int64_t OC = ShapedType::kDynamic;
  int64_t KH = ShapedType::kDynamic;
  int64_t KW = ShapedType::kDynamic;

  const ShapeAdaptor inputDataShape(getInputData().getType());
  if (inputDataShape.hasRank()) {
    N = inputDataShape.getDimSize(0);
    IH = inputDataShape.getDimSize(1);
    IW = inputDataShape.getDimSize(2);
    IC = inputDataShape.getDimSize(3);
  }

  const ShapeAdaptor inputScaleShape(getInputScale().getType());
  if (inputScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
                                     "input_scale", "batch size")) ||
        failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
                                     "input_scale", "input height")) ||
        failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
                                     "input_scale", "input width")))
      return failure();
    multiplesOfIC = inputScaleShape.getDimSize(3);
  }

  const ShapeAdaptor weightDataShape(getWeightData().getType());
  if (weightDataShape.hasRank()) {
    OC = weightDataShape.getDimSize(0);
    KH = weightDataShape.getDimSize(1);
    KW = weightDataShape.getDimSize(2);
    if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
                                     "weight_data", "input channels")))
      return failure();
  }

  const ShapeAdaptor weightScaleShape(getWeightScale().getType());
  if (weightScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
                                     "weight_scale", "output channels")) ||
        failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
                                     "weight_scale", "kernel height")) ||
        failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
                                     "weight_scale", "kernel width")) ||
        failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
                                     weightScaleShape.getDimSize(3),
                                     "weight_scale", "input channel blocks")))
      return failure();
  }

  // Verify IC is a multiple of block size
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
    return emitOpError("expect IC to be a multiple of block size, got IC=")
           << IC << ", block_size=" << blockSize;

  // Verify multiplesOfIC is IC / block size
  // NOTE(review): multiplesOfIC is read from getDimSize(3) above, but this
  // diagnostic says "dimension 2" — confirm which index is intended.
  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
      multiplesOfIC != IC / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal IC/block_size (")
           << IC << "/" << blockSize << ")"
           << ", got " << multiplesOfIC;

  // Verify pad/stride/dilation values. These are shape operands; checks run
  // only when the values are compile-time constants.
  SmallVector<int64_t> padValues;
  if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
    if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
      return emitOpError("expect all padding values to be >= 0, got ")
             << padValues;
  }

  SmallVector<int64_t> strideValues;
  if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
    if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
      return emitOpError("expect all stride values to be >= 1, got ")
             << strideValues;
  }

  SmallVector<int64_t> dilationValues;
  if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
                                dilationValues)) {
    if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
      return emitOpError("expect all dilation values to be >= 1, got ")
             << dilationValues;
  }

  // Verify output shape compatibility (requires all three parameter sets to
  // be constant and the output to be ranked).
  const ShapeAdaptor outputShape(getOutput().getType());
  if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
      outputShape.hasRank()) {
    if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
                                    padValues[0], padValues[1], strideValues[0],
                                    dilationValues[0], "height", "y", "top",
                                    "bottom")) ||
        failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
                                    padValues[2], padValues[3], strideValues[1],
                                    dilationValues[1], "width", "x", "left",
                                    "right")))
      return failure();
  }

  // Verify bias: must match the output channel count or broadcast (size 1).
  const ShapeAdaptor biasShape(getBias().getType());
  if (biasShape.hasRank() && outputShape.hasRank()) {
    const int64_t biasChannels = biasShape.getDimSize(0);
    const int64_t outputChannels =
        outputShape.getDimSize(outputShape.getRank() - 1);
    if (biasChannels == ShapedType::kDynamic ||
        outputChannels == ShapedType::kDynamic)
      // Skip following checks if biasChannels or outputChannels is dynamic dim
      return success();

    if (biasChannels != outputChannels && biasChannels != 1)
      return emitOpError(
                 "bias channels expected to be equal to output channels (")
             << outputChannels << ") or 1, got " << biasChannels;
  }

  return success();
}
3911
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Delegate to the shared convolution driver; the 3-D NDHWC specifics live
  // in the ConvInferShapeAdaptor<Conv3DOp::Adaptor> specialization.
  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
3918
3919LogicalResult Conv3DOp::verify() {
3920 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3921 verifyConvOpErrorIf(*this).failed())
3922 return failure();
3923 return success();
3924}
3925
3926LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3927 MLIRContext *context, ::std::optional<Location> location,
3928 AvgPool2dOp::Adaptor adaptor,
3929 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3930 ShapeAdaptor inputShape(adaptor.getInput().getType());
3931 const Properties &prop = adaptor.getProperties();
3932 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3933 inferredReturnShapes);
3934}
3935
3936LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3937 MLIRContext *context, ::std::optional<Location> location,
3938 MaxPool2dOp::Adaptor adaptor,
3939 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3940 ShapeAdaptor inputShape(adaptor.getInput().getType());
3941 const Properties &prop = adaptor.getProperties();
3942 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3943 inferredReturnShapes);
3944}
3945
3946LogicalResult MaxPool2dOp::verify() {
3947 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3948 /* outType = */ getOutput().getType())))
3949 return failure();
3950
3951 if (failed(verifyPoolingOp(*this)))
3952 return failure();
3953
3954 return success();
3955}
3956
// Shape inference for tosa.depthwise_conv2d. Output is NHWC; output
// channels are inputChannels * depthMultiplier (weight is laid out
// [KH, KW, C, M]); spatial dims use the standard dilated-conv formula.
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shapes describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    // Weight dim 2 also carries the input-channel count; use it only if the
    // input operand did not already provide a static value.
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (ShapedType::isStatic(inputChannels) &&
      ShapedType::isStatic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
    int64_t bc = biasShape.getDimSize(0);
    // A size-1 bias broadcasts, so it says nothing about output channels.
    if (bc != ShapedType::kDynamic && bc != 1)
      outputShape[3] = bc;
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  // pad layout is [top, bottom, left, right]; stride/dilation are [y, x].
  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
4027
4028LogicalResult DepthwiseConv2DOp::verify() {
4029 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4030 verifyConvOpErrorIf(*this).failed())
4031 return failure();
4032 return success();
4033}
4034
// Shape inference for tosa.transpose_conv2d. Output is NHWC; each static
// spatial dim is (in - 1) * stride + out_pad_lo + out_pad_hi + kernel (the
// transpose-conv output formula). Unknown dims remain dynamic.
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    // Only fill dims that are still dynamic (they all start dynamic here).
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shapes describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
    int64_t bc = biasShape.getDimSize(0);
    // A size-1 bias broadcasts and says nothing about output channels.
    if (bc != ShapedType::kDynamic && bc != 1)
      outputShape[3] = bc;
  }

  // out_pad layout is [top, bottom, left, right]; stride is [y, x].
  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
4094
// Verifies tosa.transpose_conv2d: shared conv checks, stride positivity,
// out_pad vs. kernel-size bounds, the static output-size equation
// O = (I - 1) * stride + out_pad_lo + out_pad_hi + K for each spatial dim,
// and bias-channel compatibility.
LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strides << "]";

  // out_pad entries may be negative, but only down to -(kernel dim) + 1;
  // beyond that the op would produce no output along that axis.
  const auto checkPadAgainstKernelDim =
      [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
             llvm::StringRef kernelDimName) -> LogicalResult {
    if (padValue <= -kernelDimSize)
      return emitOpError("expected ")
             << padName << " > -" << kernelDimName << ", but got: " << padName
             << "=" << padValue << " and " << kernelDimName << "="
             << kernelDimSize;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // Rest of the checks depend on the output type being a RankedTensorType
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  // With ranked input, weight, and output, the static output spatial sizes
  // must satisfy the transpose-conv equation exactly.
  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  // Bias must either broadcast (size 1) or match the output channel count.
  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());

  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if bias is dynamic
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
4211
4212LogicalResult RescaleOp::verify() {
4213 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
4214 if (!inputType) {
4215 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
4216 return failure();
4217 }
4218
4219 auto inputElementType =
4220 getStorageElementTypeOrSelf(inputType.getElementType());
4221 if (!mlir::isa<IntegerType>(inputElementType)) {
4222 emitOpError("expect input to have integer element type, got ")
4223 << inputElementType;
4224 return failure();
4225 }
4226
4227 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
4228 if (!outputType) {
4229 emitOpError("expect shaped tensor for output, got ")
4230 << getOutput().getType();
4231 return failure();
4232 }
4233
4234 auto outputElementType =
4235 getStorageElementTypeOrSelf(outputType.getElementType());
4236 if (!mlir::isa<IntegerType>(outputElementType)) {
4237 emitOpError("expect output to have integer element type, got ")
4238 << outputElementType;
4239 return failure();
4240 }
4241
4242 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4243 .failed())
4244 return failure();
4245
4246 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4247 .failed())
4248 return failure();
4249
4250 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4251 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4252 return failure();
4253
4254 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4255 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4256 return failure();
4257
4258 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
4259 if (!multiplierType) {
4260 emitOpError("expect shaped tensor for multiplier, got ")
4261 << getMultiplier().getType();
4262 return failure();
4263 }
4264
4265 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
4266 if (!shiftType) {
4267 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
4268 return failure();
4269 }
4270
4271 // multiplier element type must be i32 for scale32 = true
4272 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4273 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
4274 << multiplierType.getElementType();
4275 return failure();
4276 }
4277
4278 // multiplier element type must be i16 for scale32 = false
4279 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4281 "expect i16 element type for multiplier for scale32=false, got ")
4282 << multiplierType.getElementType();
4283 return failure();
4284 }
4285
4286 if (!inputType.hasRank())
4287 return success();
4288
4289 // multiplier/shift must have shape = {numChannels},
4290 // where numChannel is 1 if per_channel = false
4291 // otherwise numChannel is dimension in input shape's last axis
4292 int64_t numChannels = 1;
4293 if (getPerChannel()) {
4294 if (inputType.getRank() < 1) {
4295 emitOpError("requires input to be at least rank 1 when per_channel is "
4296 "true, but got rank ")
4297 << inputType.getRank();
4298 return failure();
4299 }
4300 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4301 }
4302
4303 if (!multiplierType.hasRank())
4304 return success();
4305
4306 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4307 // multiplier input has rank 1 by dialect definition
4308 if (multiplierShape[0] != ShapedType::kDynamic &&
4309 multiplierShape[0] != numChannels) {
4310 emitOpError("expect shape of { ")
4311 << numChannels << " } for multiplier input, got { "
4312 << multiplierShape[0] << " }";
4313 return failure();
4314 }
4315
4316 if (!shiftType.hasRank())
4317 return success();
4318
4319 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4320 // shift input has rank 1 by dialect definition
4321 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4322 emitOpError("expect shape of { ")
4323 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
4324 return failure();
4325 }
4326
4327 return success();
4328}
4329
4330LogicalResult RescaleOp::inferReturnTypeComponents(
4331 MLIRContext *context, ::std::optional<Location> location,
4332 RescaleOp::Adaptor adaptor,
4333 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4334 ShapeAdaptor inputShape(adaptor.getInput().getType());
4335 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4336 return success();
4337}
4338
4339LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4340 MLIRContext *context, ::std::optional<Location> location,
4341 CastFromBlockScaledOp::Adaptor adaptor,
4342 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4343 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4344 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4345 return success();
4346}
4347
4348LogicalResult CastFromBlockScaledOp::verify() {
4349 const Type inputDataType = getInputData().getType();
4350 const Type outputDataType = getResult().getType();
4351 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4352 return emitOpError() << "require compatible shapes for input_data ("
4353 << inputDataType << ") and " << "output_data ("
4354 << outputDataType << ")";
4355
4356 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4357
4358 if (inputDataShape.hasRank()) {
4359 const unsigned int blockSize =
4360 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4361 const int64_t inputDataLastDim =
4362 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4363 if (inputDataLastDim % blockSize != 0)
4364 return emitOpError() << "expect last dimension of input_data ("
4365 << inputDataLastDim
4366 << ") to be divisible by block_size (" << blockSize
4367 << ")";
4368
4369 const Type inputScaleType = getInputScale().getType();
4370 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4371
4372 if (inputScaleShape.hasRank()) {
4373 SmallVector<int64_t> inputDataDims, inputScaleDims;
4374 inputDataShape.getDims(inputDataDims);
4375 inputScaleShape.getDims(inputScaleDims);
4376
4377 if (inputDataDims.size() != inputScaleDims.size() ||
4379 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4380 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4381 return emitOpError()
4382 << "require compatible shapes for input_data (" << inputDataType
4383 << ") and " << "input_scale (" << inputScaleType
4384 << ") except for the last dimension";
4385
4386 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4387 inputScaleDims.back()};
4388 if (ShapedType::isStatic(inputDataLastDim) &&
4389 failed(verifyCompatibleDims(dimsToCheck)))
4390 return emitOpError()
4391 << "expect last dimension of input_scale ("
4392 << inputScaleDims.back()
4393 << ") to be equal to last dimension of input_data / block_size ("
4394 << inputDataDims.back() / blockSize << ")";
4395 }
4396 }
4397
4398 return success();
4399}
4400
4401LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4402 MLIRContext *context, ::std::optional<Location> location,
4403 CastToBlockScaledOp::Adaptor adaptor,
4404 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4405 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4406 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4407 if (!inputShape.hasRank())
4408 return success();
4409
4410 // Calculate output_scale shape if ranked input provided
4411 SmallVector<int64_t> outputScaleShape;
4412 inputShape.getDims(outputScaleShape);
4413 const int64_t lastDimLoc = inputShape.getRank() - 1;
4414 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4415 if (ShapedType::isStatic(lastDimSize)) {
4416 const unsigned int blockSize =
4417 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4418 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4419 }
4420 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4421 return success();
4422}
4423
4424LogicalResult CastToBlockScaledOp::verify() {
4425 const Type inputDataType = getInputData().getType();
4426 const Type outputDataType = getResult(0).getType();
4427 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4428 return emitOpError() << "require compatible shapes for input_data ("
4429 << inputDataType << ") and " << "output_data ("
4430 << outputDataType << ")";
4431
4432 const unsigned int blockSize =
4433 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4434 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4435 if (inputDataShape.hasRank()) {
4436 const int64_t inputDataLastDim =
4437 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4438 if (ShapedType::isStatic(inputDataLastDim) &&
4439 inputDataLastDim % blockSize != 0)
4440 return emitOpError() << "expect last dimension of input_data ("
4441 << inputDataLastDim
4442 << ") to be divisible by block_size (" << blockSize
4443 << ")";
4444 }
4445
4446 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4447 const Type outputScaleType = getResult(1).getType();
4448 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4449 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4450 SmallVector<int64_t> outputDataDims, outputScaleDims;
4451 outputDataShape.getDims(outputDataDims);
4452 outputScaleShape.getDims(outputScaleDims);
4453
4454 if (outputDataDims.size() != outputScaleDims.size() ||
4456 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4457 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4458 return emitOpError() << "require compatible shapes for output_data ("
4459 << outputDataType << ") and " << "output_scale ("
4460 << outputScaleType
4461 << ") except for the last dimension";
4462
4463 const int64_t outputDataLastDim = outputDataDims.back();
4464 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4465 outputScaleDims.back()};
4466 if (ShapedType::isStatic(outputDataLastDim) &&
4467 failed(verifyCompatibleDims(dimsToCheck)))
4468 return emitOpError()
4469 << "expect last dimension of output_scale ("
4470 << outputScaleDims.back()
4471 << ") to be equal to last dimension of output_data / block_size ("
4472 << outputDataDims.back() / blockSize << ")";
4473 }
4474
4475 return success();
4476}
4477
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Collect the tosa.yield terminators from both regions; each yield
  // contributes one candidate list of result types.
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  // Without any yield there is nothing to infer from.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  // Meet the knowledge across all yields so the inferred types are valid
  // regardless of which branch executes.
  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      // An empty meet means the types conflict; keep the prior knowledge.
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
4521
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Gather the tosa.yield terminators of the body region; they carry the
  // loop-carried result types.
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If not found this
  // tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  // Meet the knowledge from every yield so the inferred result types are
  // compatible with every loop iteration.
  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
4564
4565std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4566 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4567 return llvm::to_vector<4>(vt.getShape());
4568 return std::nullopt;
4569}
4570
4572 Block::BlockArgListType blocksArgs,
4573 ValueRange initializers,
4574 StringRef prefix = "") {
4575 assert(blocksArgs.size() == initializers.size() &&
4576 "expected same length of arguments and initializers");
4577 if (initializers.empty())
4578 return;
4579
4580 parser << prefix << '(';
4581 llvm::interleaveComma(
4582 llvm::zip(blocksArgs, initializers), parser,
4583 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4584 parser << ")";
4585}
4586
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  OpAsmParser::UnresolvedOperand cond;

  if (parser.parseOperand(cond))
    return failure();

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Parse the optional block arguments
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  // Parse a colon.
  if (failed(parser.parseColon()))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  // Parse the type of the condition operand
  Type condType;
  if (failed(parser.parseType(condType)))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  // Resolve operand with provided type
  if (failed(parser.resolveOperand(cond, condType, result.operands)))
    return failure();

  // Parse optional block arg types
  if (listResult.has_value()) {
    // When an assignment list was present the type must be a function type:
    // (block-argument types) -> (result types).
    FunctionType functionType;

    if (failed(parser.parseType(functionType)))
      return parser.emitError(parser.getCurrentLocation())
             << "expected list of types for block arguments "
             << "followed by arrow type and list of return types";

    result.addTypes(functionType.getResults());

    if (functionType.getNumInputs() != operands.size()) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected as many input types as operands " << "(expected "
             << operands.size() << " got " << functionType.getNumInputs()
             << ")";
    }

    // Resolve input operands.
    if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                      parser.getCurrentLocation(),
                                      result.operands)))
      return failure();
  } else {
    // Parse optional results type list.
    if (parser.parseOptionalArrowTypeList(result.types))
      return failure();
  }

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
4667
void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();

  // Print "(blockArg = inputOperand, ...)" pairing the 'then' region's entry
  // block arguments with the op's input list.
  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");
  p << " : ";
  p << getCondition().getType();

  // The block-argument types are only printed when there are inputs.
  if (!getInputList().empty()) {
    p << " (";
    llvm::interleaveComma(getInputList().getTypes(), p);
    p << ")";
  }
  p.printArrowTypeList(getResultTypes());
  p << " ";

  p.printRegion(getThenGraph());

  // Print the 'else' regions if it exists and has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}
4695
LogicalResult IfOp::verify() {
  // Both region entry blocks must accept exactly the values in input_list.
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  // MLIR will verify the absence of the terminator for us if otherwise.
  if (getThenGraph().front().mightHaveTerminator()) {
    auto thenYield =
        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
    // The yielded values must match the declared output_list types.
    if (thenYield && errorIfTypeOrShapeMismatch(
                         *this, thenYield.getInputs(), "'then_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  // MLIR will verify the absence of the terminator for us if otherwise.
  if (getElseGraph().front().mightHaveTerminator()) {
    auto elseYield =
        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
    if (elseYield && errorIfTypeOrShapeMismatch(
                         *this, elseYield.getInputs(), "'else_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  // The condition must be a single-element (size-1) tensor.
  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;

  return success();
}
4738
LogicalResult WhileOp::verify() {
  // Loop-carried values: input_list and output_list must agree in type/shape.
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();

  // Both region entry blocks must accept exactly the loop-carried values.
  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  // The body's yield feeds the next iteration, so it must match input_list.
  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                                "'body_graph' results",
                                                getInputList(), "'input_list'")
                         .failed())
      return failure();
  }

  // Condition block output must be a single element tensor with a single bool
  // value.
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
4791
4792LogicalResult ReverseOp::verify() {
4793 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4794 /* outType = */ getOutput().getType())
4795 .failed())
4796 return failure();
4797 TensorType inputType = getInput1().getType();
4798 TensorType outputType = getOutput().getType();
4799 int32_t reverseAxis = getAxis();
4800
4801 if (reverseAxis < 0)
4802 return emitOpError("expected non-negative reverse axis");
4803 if (inputType.hasRank()) {
4804 int64_t inputRank = inputType.getRank();
4805 // We allow for a special case where the input/output shape has rank 0 and
4806 // axis is also 0.
4807 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4808 return emitOpError("expect input tensor rank (")
4809 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4810 << ")";
4811 }
4812 if (outputType.hasRank()) {
4813 int64_t outputRank = outputType.getRank();
4814 if (inputType.hasRank() && outputRank != inputType.getRank())
4815 return emitOpError(
4816 "expect output tensor rank to be equal to input tensor rank");
4817 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4818 return emitOpError("expect output tensor rank (")
4819 << outputRank << ") to be larger than reverse axis ("
4820 << reverseAxis << ")";
4821 }
4822 return success();
4823}
4824
4825LogicalResult tosa::SelectOp::verify() {
4826 // verify input2 and input3 have same element type as output
4827 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4828 /* outType = */ getOutput().getType())
4829 .failed() ||
4830 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4831 /* outType = */ getOutput().getType())
4832 .failed()) {
4833 return failure();
4834 }
4835 // verify input1 has element type of bool
4836 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4837 if (!predicateType) {
4838 return emitOpError("expect shaped tensor for input1, got ")
4839 << getInput1().getType();
4840 }
4841 auto predicateElementType = predicateType.getElementType();
4842 if (!predicateElementType.isInteger(1)) {
4843 return emitOpError("expect element type of bool for input1, got ")
4844 << predicateElementType;
4845 }
4846
4847 return success();
4848}
4849
4850LogicalResult tosa::VariableReadOp::verify() {
4851 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4852 .failed())
4853 return failure();
4854
4855 return success();
4856}
4857
4858LogicalResult tosa::VariableWriteOp::verify() {
4859 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4860 .failed())
4861 return failure();
4862
4863 return success();
4864}
4865
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  // Parse the optional "(arg = init, ...)" assignment list.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  // The trailing type is a function type: operand types -> result types.
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands " << "(expected "
           << operands.size() << " got " << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  // Parse "<cond-region> do <body-region>" plus the optional attribute dict.
  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
4905
4906void WhileOp::print(OpAsmPrinter &parser) {
4907 printInitializationList(parser, getCondGraph().front().getArguments(),
4908 getInputList(), " ");
4909 parser << " : ";
4910 parser.printFunctionalType(getInputList().getTypes(),
4911 getResults().getTypes());
4912 parser << ' ';
4913 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4914 parser << " do ";
4915 parser.printRegion(getBodyGraph());
4916 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4917}
4918
4919// Create a rank-1 const tensor for zero point of the source tensor.
4920std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4921 Location loc,
4922 Type srcElemType,
4923 int64_t zp) {
4924 srcElemType = getStorageElementTypeOrSelf(srcElemType);
4925 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4926 if (llvm::isa<FloatType>(srcElemType)) {
4927 auto zpAttr = DenseElementsAttr::get(
4928 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4929 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4930 }
4931 if (llvm::isa<IntegerType>(srcElemType)) {
4932 auto zpAttr =
4933 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4934 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4935 }
4936 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4937 return std::nullopt;
4938}
4939
4940//===----------------------------------------------------------------------===//
4941// TOSA Shape and Shape Operators Helper functions.
4942//===----------------------------------------------------------------------===//
4943
4945 return mlir::isa<tosa::shapeType>(t);
4946}
4947
4948LogicalResult
4949mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4950 int rank) {
4951 if (rank < 0)
4952 return emitError() << "invalid rank (must be >= 0): " << rank;
4953 return success();
4954}
4955
4957 for (auto v : op->getOperands()) {
4958 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4959 Operation *definingOp = v.getDefiningOp();
4960 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4961 return op->emitOpError("shape operand is not compile time resolvable");
4962 }
4963 }
4964 }
4965 return success();
4966}
4967
4968LogicalResult
4970 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
4971 return failure();
4972
4973 // delegate function that returns rank of shape type
4974 auto getRank = [](const Type type) {
4975 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4976 };
4977 auto operandTypes = op->getOperandTypes();
4978 auto resultTypes = op->getResultTypes();
4979
4980 auto rank = getRank(*op->getOperandTypes().begin());
4981 for (auto type : operandTypes) {
4982 if (getRank(type) != rank) {
4983 return op->emitOpError("operands don't have matching ranks");
4984 }
4985 }
4986 for (auto type : resultTypes) {
4987 if (getRank(type) != rank) {
4988 return op->emitOpError("result shape has different rank than operands");
4989 }
4990 }
4991 return success();
4992}
4993
4994//===----------------------------------------------------------------------===//
4995// TOSA Shape Operators verify functions.
4996//===----------------------------------------------------------------------===//
4997
4998LogicalResult tosa::ConstShapeOp::verify() {
4999 // check one dimensional rank
5000 auto valuesRank = getValues().getType().getRank();
5001 if (valuesRank != 1)
5002 return emitOpError("expect elements in attribute values with rank 1");
5003 // check that number of elements in values attr equal to rank of result shape
5004 auto count = getValues().getNumElements();
5005 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
5006 if (count != rank && (count != 1 || rank != 0)) {
5007 return emitOpError("expect number of elements in attribute values (")
5008 << count << ") to be equal to the rank (" << rank
5009 << ") for the result shape type";
5010 }
5011 return success();
5012}
5013
5014LogicalResult tosa::DimOp::verify() {
5015 const tosa::shapeType outShapeType =
5016 cast<tosa::shapeType>(getResult().getType());
5017 if (outShapeType.getRank() != 1)
5018 return emitOpError("expect output shape type to contain one element, got ")
5019 << outShapeType;
5020
5021 const ShapeAdaptor inputType(getInput1().getType());
5022 if (inputType.hasRank()) {
5023 const int64_t inputRank = inputType.getRank();
5024 const int64_t axis = getAxisAttr().getInt();
5025 if (axis < 0 || axis >= inputRank)
5026 return emitOpError("expect axis to be in the range [0, ")
5027 << inputRank << "), got " << axis;
5028 }
5029 return success();
5030}
5031
5032LogicalResult tosa::ConcatShapeOp::verify() {
5033 const tosa::shapeType outShapeType =
5034 cast<tosa::shapeType>(getResult().getType());
5035 const int64_t outputRank = outShapeType.getRank();
5036 const Operation::operand_range inputList = getInput();
5037
5038 if (inputList.size() == 0)
5039 return emitOpError("requires at least one input shape");
5040
5041 if (llvm::any_of(inputList, [](Value v) {
5042 return cast<tosa::shapeType>(v.getType()).getRank() == 0;
5043 }))
5044 return emitOpError("requires all inputs shapes have a rank greater than 0");
5045
5046 const int64_t inputsRank =
5047 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
5048 const tosa::shapeType inShapeType =
5049 cast<tosa::shapeType>(input.getType());
5050 return acc + inShapeType.getRank();
5051 });
5052 if (outputRank != inputsRank)
5053 return emitOpError("requires output shape rank to be equal to the sum of "
5054 "the input shape ranks (")
5055 << inputsRank << "), got " << outputRank;
5056
5057 return success();
5058}
5059
LogicalResult tosa::SliceShapeOp::verify() {
  // Start index, when it folds to a constant, must be non-negative.
  std::optional<int32_t> start;
  DenseIntElementsAttr startAttr;
  if (matchPattern(getStart(), m_Constant(&startAttr)))
    start = startAttr.getValues<int32_t>()[0];
  if (start && start.value() < 0)
    return emitOpError("expected non-negative start index, got ")
           << start.value();

  // Size, when it folds to a constant, must be strictly positive.
  std::optional<int32_t> size;
  DenseIntElementsAttr sizeAttr;
  if (matchPattern(getSize(), m_Constant(&sizeAttr)))
    size = sizeAttr.getValues<int32_t>()[0];
  if (size && size.value() <= 0)
    return emitOpError("expected positive size, got ") << size.value();

  // The remaining checks all require a known size.
  if (!size)
    return success();

  // The result shape's rank must equal the constant size.
  const tosa::shapeType outShapeType =
      cast<tosa::shapeType>(getResult().getType());
  const int64_t outputRank = outShapeType.getRank();
  if (outputRank != size)
    return emitOpError(
               "expected output type size to be equal to size attribute, got ")
           << outputRank << " vs " << size.value();

  if (!start)
    return success();

  // The window [start, start + size) must fit within the input shape's rank.
  const tosa::shapeType inShapeType =
      cast<tosa::shapeType>(getInput().getType());
  const int64_t inputRank = inShapeType.getRank();
  const int64_t sliceSize = start.value() + size.value();
  if (sliceSize > inputRank)
    return emitOpError("expected start + size to be less than or equal to "
                       "input shape rank (")
           << inputRank << "), got " << sliceSize;

  return success();
}
5101
5102//===----------------------------------------------------------------------===//
5103// TOSA Attribute Definitions.
5104//===----------------------------------------------------------------------===//
5105
5106#define GET_ATTRDEF_CLASSES
5107#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5108
5109//===----------------------------------------------------------------------===//
5110// TOSA Type Definitions.
5111//===----------------------------------------------------------------------===//
5112#define GET_TYPEDEF_CLASSES
5113#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5114
5115//===----------------------------------------------------------------------===//
5116// TOSA Operator Definitions.
5117//===----------------------------------------------------------------------===//
5118
5119#define GET_OP_CLASSES
5120#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
true
Given two iterators into the same block, return "true" if a is before `b.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition TosaOps.cpp:1326
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
Definition TosaOps.cpp:1015
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3703
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition TosaOps.cpp:138
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3251
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition TosaOps.cpp:585
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition TosaOps.cpp:979
#define REDUCE_SHAPE_INFER(OP)
Definition TosaOps.cpp:3276
static LogicalResult verifyConvOp(T op)
Definition TosaOps.cpp:680
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition TosaOps.cpp:988
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3466
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition TosaOps.cpp:1440
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:568
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition TosaOps.cpp:1454
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
Definition TosaOps.cpp:643
static LogicalResult verifyReduceOp(T op)
Definition TosaOps.cpp:3301
#define NARY_SHAPE_INFER(OP)
Definition TosaOps.cpp:3369
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition TosaOps.cpp:2748
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition TosaOps.cpp:1304
static LogicalResult verifyConvOpErrorIf(T op)
Definition TosaOps.cpp:834
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition TosaOps.cpp:2679
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
Definition TosaOps.cpp:628
static LogicalResult verifyConvOpModes(T op)
Definition TosaOps.cpp:785
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3357
static Type getStorageElementTypeOrSelf(Type type)
Definition TosaOps.cpp:574
#define COMPATIBLE_RETURN_TYPES(OP)
Definition TosaOps.cpp:3267
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition TosaOps.cpp:1498
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition TosaOps.cpp:1400
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition TosaOps.cpp:1280
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition TosaOps.cpp:1355
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition TosaOps.cpp:937
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
Definition TosaOps.cpp:1484
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition TosaOps.cpp:2706
static LogicalResult verifyPoolingOp(T op)
Definition TosaOps.cpp:1081
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition TosaOps.cpp:1584
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static void updateIfDynamic(int64_t &current, int64_t candidate)
Definition TosaOps.cpp:3505
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:3598
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:3627
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:3572
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
Definition TosaOps.cpp:3569
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:3518
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:3533
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
Definition TosaOps.cpp:3515
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:3551
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:3668
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
Definition TosaOps.cpp:3648
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:3651
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:3688
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:197
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
This class indicates that op operates on tosa shape types.
Definition TosaOps.h:129
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OperandRange operand_range
Definition Operation.h:397
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
bool empty()
Definition Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition TosaOps.cpp:4969
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition TosaOps.cpp:4956
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition TosaOps.cpp:227
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition TosaOps.cpp:252
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:620
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition TosaOps.cpp:4920
bool isa_tosa_shape_type(mlir::Type t)
Definition TosaOps.cpp:4944
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition TosaOps.cpp:605
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45